Add test for Gemma model (pytorch#6631)
qihqi authored and amithrm committed Mar 1, 2024
1 parent 9314284 commit c0bc50f
Showing 8 changed files with 795 additions and 7 deletions.
1 change: 1 addition & 0 deletions .github/workflows/torch_xla2.yml
@@ -36,6 +36,7 @@ jobs:
        run: |
          pip install pytest absl-py jax[cpu] flatbuffers tensorflow
          pip install torch --index-url https://download.pytorch.org/whl/cpu
          pip install -r dev-requirements.txt
          pip install -e .
      - name: Run tests
        working-directory: experimental/torch_xla2
15 changes: 8 additions & 7 deletions experimental/torch_xla2/dev-requirements.txt
@@ -1,8 +1,9 @@
-r requirements.txt
absl-py==2.0.0
flatbuffers==23.5.26
jax==0.4.23
jaxlib==0.4.23
pytest
yapf
tabulate
transformers
tf-nightly
--pre -f https://download.pytorch.org/whl/nightly/torch_nightly.html
torchvision
tensorflow
torch==2.2.1+cpu
immutabledict
sentencepiece
Empty file.
86 changes: 86 additions & 0 deletions experimental/torch_xla2/test/gemma/config.py
@@ -0,0 +1,86 @@
# From: https://github.com/google/gemma_pytorch/blob/main/gemma/config.py

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Gemma model config."""

import dataclasses
import immutabledict
import torch
from typing import Optional


# Keep a mapping from dtype strings to the supported torch dtypes.
_STR_DTYPE_TO_TORCH_DTYPE = immutabledict.immutabledict({
    'float16': torch.float16,
    'float': torch.float32,
    'float32': torch.float32,
    'bfloat16': torch.bfloat16,
})


@dataclasses.dataclass
class GemmaConfig:
    # The number of tokens in the vocabulary.
    vocab_size: int = 256000
    # The maximum sequence length that this model might ever be used with.
    max_position_embeddings: int = 8192
    # The number of blocks in the model.
    num_hidden_layers: int = 28
    # The number of attention heads used in the attention layers of the model.
    num_attention_heads: int = 16
    # The number of key-value heads for implementing attention.
    num_key_value_heads: int = 16
    # The hidden size of the model.
    hidden_size: int = 3072
    # The dimension of the MLP representations.
    intermediate_size: int = 24576
    # The number of head dimensions.
    head_dim: int = 256
    # The epsilon used by the rms normalization layers.
    rms_norm_eps: float = 1e-6
    # The dtype of the weights.
    dtype: str = 'bfloat16'
    # Whether a quantized version of the model is used.
    quant: bool = False
    # The path to the model tokenizer.
    tokenizer: Optional[str] = 'tokenizer/tokenizer.model'

    def get_dtype(self) -> Optional[torch.dtype]:
        """Gets the torch dtype from the config dtype string."""
        return _STR_DTYPE_TO_TORCH_DTYPE.get(self.dtype, None)


def get_config_for_7b() -> GemmaConfig:
    return GemmaConfig()


def get_config_for_2b() -> GemmaConfig:
    return GemmaConfig(
        num_hidden_layers=18,
        num_attention_heads=8,
        num_key_value_heads=1,
        hidden_size=2048,
        intermediate_size=16384
    )


def get_model_config(variant: str) -> GemmaConfig:
    if variant == '7b':
        return get_config_for_7b()
    elif variant == '2b':
        return get_config_for_2b()
    raise ValueError(f'Invalid variant {variant}. Supported variants are "2b" '
                     'and "7b"')
