diff --git a/examples/decoder_only_model.py b/examples/decoder_only_model.py index 0cb8e640d98..de226ab1fff 100644 --- a/examples/decoder_only_model.py +++ b/examples/decoder_only_model.py @@ -7,13 +7,14 @@ from torch import nn +# the default config is intentionally kept low to make it runnable on a single TPU v2-8 core. @dataclass class DecoderOnlyConfig: - hidden_size: int = 1024 + hidden_size: int = 512 num_hidden_layers: int = 2 num_attention_heads: int = 8 num_key_value_heads: int = 4 - intermediate_size = 32 * 1024 + intermediate_size = 32 * 512 vocab_size = 3200 use_flash_attention = False