Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add test for Gemma model #6631

Merged
merged 1 commit into from
Feb 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/torch_xla2.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ jobs:
run: |
pip install pytest absl-py jax[cpu] flatbuffers tensorflow
pip install torch --index-url https://download.pytorch.org/whl/cpu
pip install -r dev-requirements.txt
pip install -e .
- name: Run tests
working-directory: experimental/torch_xla2
Expand Down
15 changes: 8 additions & 7 deletions experimental/torch_xla2/dev-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
-r requirements.txt
absl-py==2.0.0
flatbuffers==23.5.26
jax==0.4.23
jaxlib==0.4.23
pytest
yapf
tabulate
transformers
tf-nightly
--pre -f https://download.pytorch.org/whl/nightly/torch_nightly.html
torchvision
tensorflow
torch==2.2.1+cpu
immutabledict
sentencepiece
Empty file.
86 changes: 86 additions & 0 deletions experimental/torch_xla2/test/gemma/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# From: https://github.com/google/gemma_pytorch/blob/main/gemma/config.py

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Gemma model config."""

import dataclasses
import immutabledict
import torch
from typing import Optional


# Read-only lookup table translating the dtype names accepted in
# `GemmaConfig.dtype` into torch dtype objects. Both 'float' and 'float32'
# map to torch.float32.
_STR_DTYPE_TO_TORCH_DTYPE = immutabledict.immutabledict(
    {
        'float16': torch.float16,
        'float': torch.float32,
        'float32': torch.float32,
        'bfloat16': torch.bfloat16,
    }
)


@dataclasses.dataclass
class GemmaConfig:
    """Hyperparameters and settings describing a Gemma model.

    The field defaults are the values returned unchanged by
    `get_config_for_7b()`.
    """

    # Vocabulary size (number of distinct tokens).
    vocab_size: int = 256000
    # Longest sequence length the model might ever be used with.
    max_position_embeddings: int = 8192
    # How many blocks the model stacks.
    num_hidden_layers: int = 28
    # Attention head count in each attention layer.
    num_attention_heads: int = 16
    # Key-value head count used when implementing attention.
    num_key_value_heads: int = 16
    # Hidden size of the model.
    hidden_size: int = 3072
    # Dimension of the MLP representations.
    intermediate_size: int = 24576
    # Dimension of each attention head.
    head_dim: int = 256
    # Epsilon used by the RMS normalization layers.
    rms_norm_eps: float = 1e-6
    # Name of the weight dtype; resolved to a torch dtype by `get_dtype()`.
    dtype: str = 'bfloat16'
    # True when a quantized version of the model is used.
    quant: bool = False
    # Path to the model tokenizer.
    tokenizer: Optional[str] = 'tokenizer/tokenizer.model'

    def get_dtype(self) -> Optional[torch.dtype]:
        """Look up the torch dtype named by `self.dtype`.

        Returns:
            The matching torch dtype, or None when the name is not present
            in the module's dtype lookup table.
        """
        # dict-style .get() already yields None for a missing key.
        return _STR_DTYPE_TO_TORCH_DTYPE.get(self.dtype)


def get_config_for_7b() -> GemmaConfig:
    """Build the config for the 7b variant, which is `GemmaConfig`'s defaults."""
    config = GemmaConfig()
    return config


def get_config_for_2b() -> GemmaConfig:
    """Build the config for the 2b variant.

    Only the fields listed below differ from the `GemmaConfig` defaults;
    every other field keeps its default value.
    """
    overrides = {
        'num_hidden_layers': 18,
        'num_attention_heads': 8,
        'num_key_value_heads': 1,
        'hidden_size': 2048,
        'intermediate_size': 16384,
    }
    return GemmaConfig(**overrides)


def get_model_config(variant: str) -> GemmaConfig:
    """Return the Gemma config for the named model variant.

    Args:
        variant: Model variant name; '7b' and '2b' are supported.

    Returns:
        The `GemmaConfig` produced by the matching `get_config_for_*`
        helper.

    Raises:
        ValueError: If `variant` is not one of the supported names.
    """
    if variant == '7b':
        return get_config_for_7b()
    if variant == '2b':
        return get_config_for_2b()
    # Raise rather than return the exception: returning it would hand callers
    # a ValueError instance where they expect a config object.
    raise ValueError(
        f'Invalid variant {variant}. Supported variants are "2b" and "7b"')
Loading
Loading