diff --git a/.github/workflows/torch_xla2.yml b/.github/workflows/torch_xla2.yml index 8099eb0bd2a..961bdd7101f 100644 --- a/.github/workflows/torch_xla2.yml +++ b/.github/workflows/torch_xla2.yml @@ -36,6 +36,7 @@ jobs: run: | pip install pytest absl-py jax[cpu] flatbuffers tensorflow pip install torch --index-url https://download.pytorch.org/whl/cpu + pip install -r dev-requirements.txt pip install -e . - name: Run tests working-directory: experimental/torch_xla2 diff --git a/experimental/torch_xla2/dev-requirements.txt b/experimental/torch_xla2/dev-requirements.txt index 6c6cb4d208e..4a32310fbda 100644 --- a/experimental/torch_xla2/dev-requirements.txt +++ b/experimental/torch_xla2/dev-requirements.txt @@ -1,8 +1,9 @@ --r requirements.txt +absl-py==2.0.0 +flatbuffers==23.5.26 +jax==0.4.23 +jaxlib==0.4.23 pytest -yapf -tabulate -transformers -tf-nightly ---pre -f https://download.pytorch.org/whl/nightly/torch_nightly.html -torchvision \ No newline at end of file +tensorflow +torch==2.2.1+cpu +immutabledict +sentencepiece \ No newline at end of file diff --git a/experimental/torch_xla2/test/gemma/__init__.py b/experimental/torch_xla2/test/gemma/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/experimental/torch_xla2/test/gemma/config.py b/experimental/torch_xla2/test/gemma/config.py new file mode 100644 index 00000000000..1ec3ef781c4 --- /dev/null +++ b/experimental/torch_xla2/test/gemma/config.py @@ -0,0 +1,86 @@ +# From: https://github.com/google/gemma_pytorch/blob/main/gemma/config.py + +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Gemma model config.""" + +import dataclasses +import immutabledict +import torch +from typing import Optional + + +# Keep a mapping from dtype strings to the supported torch dtypes. +_STR_DTYPE_TO_TORCH_DTYPE = immutabledict.immutabledict({ + 'float16': torch.float16, + 'float': torch.float32, + 'float32': torch.float32, + 'bfloat16': torch.bfloat16, +}) + + +@dataclasses.dataclass +class GemmaConfig: + # The number of tokens in the vocabulary. + vocab_size: int = 256000 + # The maximum sequence length that this model might ever be used with. + max_position_embeddings: int = 8192 + # The number of blocks in the model. + num_hidden_layers: int = 28 + # The number of attention heads used in the attention layers of the model. + num_attention_heads: int = 16 + # The number of key-value heads for implementing attention. + num_key_value_heads: int = 16 + # The hidden size of the model. + hidden_size: int = 3072 + # The dimension of the MLP representations. + intermediate_size: int = 24576 + # The number of head dimensions. + head_dim: int = 256 + # The epsilon used by the rms normalization layers. + rms_norm_eps: float = 1e-6 + # The dtype of the weights. + dtype: str = 'bfloat16' + # Whether a quantized version of the model is used. + quant: bool = False + # The path to the model tokenizer. 
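+    # (the unit test in this directory never loads this file; the vendored
+    # model below sets its tokenizer to None)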
+    tokenizer: Optional[str] = 'tokenizer/tokenizer.model'
+
+    def get_dtype(self) -> Optional[torch.dtype]:
+        """Gets the torch dtype from the config dtype string."""
+        return _STR_DTYPE_TO_TORCH_DTYPE.get(self.dtype, None)
+
+
+def get_config_for_7b() -> GemmaConfig:
+    return GemmaConfig()
+
+
+def get_config_for_2b() -> GemmaConfig:
+    return GemmaConfig(
+        num_hidden_layers=18,
+        num_attention_heads=8,
+        num_key_value_heads=1,
+        hidden_size=2048,
+        intermediate_size=16384
+    )
+
+
+def get_model_config(variant: str) -> GemmaConfig:
+    if variant == '7b':
+        return get_config_for_7b()
+    elif variant == '2b':
+        return get_config_for_2b()
+    raise ValueError(f'Invalid variant {variant}. Supported variants are "2b" '
+                     'and "7b"')
diff --git a/experimental/torch_xla2/test/gemma/model.py b/experimental/torch_xla2/test/gemma/model.py
new file mode 100644
index 00000000000..08bf41c36bd
--- /dev/null
+++ b/experimental/torch_xla2/test/gemma/model.py
@@ -0,0 +1,561 @@
+# From: https://raw.githubusercontent.com/google/gemma_pytorch/main/gemma/model.py
+
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only Gemma model implementation."""
+
+import re
+import torch
+from torch import nn
+import torch.nn.functional as F
+from typing import Any, List, Optional, Sequence, Tuple, Union
+
+from . import config as gemma_config
+from . import tokenizer
+
+
+class Sampler(nn.Module):
+
+    def __init__(self, vocab_size: int):
+        super().__init__()
+        self.vocab_size = vocab_size
+
+    @torch.no_grad()
+    def forward(
+        self,
+        embedding: torch.Tensor,
+        hidden_states: torch.Tensor,
+        output_positions: torch.Tensor,
+        temperatures: torch.Tensor,
+        top_ps: torch.Tensor,
+        top_ks: torch.Tensor,
+        embedding_bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        # Select the last element for each sequence.
+        # (batch_size, input_len, hidden_size) -> (batch_size, hidden_size)
+        hidden_states = hidden_states.index_select(
+            1, output_positions).squeeze(dim=1)
+        logits = torch.matmul(hidden_states, embedding.t())
+        if embedding_bias is not None:
+            logits += embedding_bias
+
+        if temperatures is None:
+            return torch.argmax(logits, dim=-1).squeeze(dim=-1)
+
+        # Apply temperature scaling.
+        logits.div_(temperatures.unsqueeze(dim=1))
+
+        # Calculate probabilities with softmax.
+        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
+        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
+
+        # Apply top-p, top-k.
+        probs_sum = torch.cumsum(probs_sort, dim=-1)
+        top_ps_mask = (probs_sum - probs_sort) > top_ps.unsqueeze(dim=1)
+        probs_sort = torch.where(top_ps_mask, 0, probs_sort)
+
+        top_ks_mask = torch.arange(probs_idx.shape[-1],
+                                   device=probs_idx.device)
+        top_ks_mask = top_ks_mask.expand(probs_idx.shape[0], -1)
+        top_ks_mask = top_ks_mask >= top_ks.unsqueeze(dim=1)
+        probs_sort = torch.where(top_ks_mask, 0, probs_sort)
+
+        # Re-normalization.
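+        # (rescale the surviving top-p/top-k mass so it sums to 1, then map the
+        # sorted probabilities back to vocabulary order before sampling)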
+ probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) + probs = torch.gather(probs_sort, + dim=-1, + index=torch.argsort(probs_idx, dim=-1)) + next_token_ids = torch.multinomial(probs, + num_samples=1, + replacement=True).squeeze(dim=-1) + return next_token_ids + + +def precompute_freqs_cis(dim: int, + end: int, + theta: float = 10000.0) -> torch.Tensor: + """Precomputes the frequency cis.""" + freqs = 1.0 / (theta**(torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) + freqs = torch.outer(t, freqs).float() + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + + +def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: + """Applies the rotary embedding to the query and key tensors.""" + x_ = torch.view_as_complex( + torch.stack(torch.chunk(x.transpose(1, 2).float(), 2, dim=-1), + dim=-1)) + x_out = torch.view_as_real(x_ * freqs_cis).type_as(x) + x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2) + x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2], + -1).transpose(1, 2) + return x_out + + +class Linear(nn.Module): + + def __init__(self, in_features: int, out_features: int, quant: bool): + super().__init__() + if quant: + self.weight = nn.Parameter( + torch.empty((out_features, in_features), dtype=torch.int8), + requires_grad=False, + ) + self.weight_scaler = nn.Parameter(torch.Tensor(out_features)) + else: + self.weight = nn.Parameter( + torch.empty((out_features, in_features)), + requires_grad=False, + ) + self.quant = quant + + def forward(self, x): + weight = self.weight + if self.quant: + weight = weight * self.weight_scaler.unsqueeze(-1) + output = F.linear(x, weight) + return output + + +class Embedding(nn.Module): + + def __init__(self, num_embeddings: int, embedding_dim: int, quant: bool): + super().__init__() + if quant: + self.weight = nn.Parameter( + torch.empty((num_embeddings, embedding_dim), dtype=torch.int8), + requires_grad=False, + ) + self.weight_scaler = nn.Parameter(torch.Tensor(num_embeddings)) + else: + self.weight = nn.Parameter( + torch.empty((num_embeddings, embedding_dim)), + requires_grad=False, + ) + self.quant = quant + + def forward(self, x): + weight = self.weight + if self.quant: + weight = weight * self.weight_scaler.unsqueeze(-1) + output = F.embedding(x, weight) + return output + + +class RMSNorm(torch.nn.Module): + + def __init__( + self, + dim: int, + eps: float = 1e-6, + add_unit_offset: bool = True, + ): + super().__init__() + self.eps = eps + self.add_unit_offset = add_unit_offset + self.weight = nn.Parameter(torch.zeros(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + x = self._norm(x.float()).type_as(x) + if self.add_unit_offset: + output = x * (1 + self.weight) + else: + output = x * self.weight + return output + + +class GemmaMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + quant: bool, + ): + super().__init__() + self.gate_proj = Linear(hidden_size, intermediate_size, quant) + self.up_proj = Linear(hidden_size, intermediate_size, quant) + self.down_proj = Linear(intermediate_size, hidden_size, quant) + + def forward(self, x): + gate = self.gate_proj(x) + gate = F.gelu(gate) + up = self.up_proj(x) + fuse = gate * up + outputs = self.down_proj(fuse) + return outputs + + +class GemmaAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: 
int, + quant: bool, + ): + super().__init__() + + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + self.hidden_size = hidden_size + self.head_dim = head_dim + + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + + self.scaling = self.head_dim**-0.5 + + self.qkv_proj = Linear( + self.hidden_size, + (self.num_heads + 2 * self.num_kv_heads) * self.head_dim, + quant=quant) + self.o_proj = Linear( + self.num_heads * self.head_dim, + self.hidden_size, + quant=quant) + + def forward( + self, + hidden_states: torch.Tensor, + freqs_cis: torch.Tensor, + kv_write_indices: torch.Tensor, + kv_cache: Tuple[torch.Tensor, torch.Tensor], + mask: torch.Tensor, + ) -> torch.Tensor: + hidden_states_shape = hidden_states.shape + assert len(hidden_states_shape) == 3 + + batch_size, input_len, _ = hidden_states_shape + + qkv = self.qkv_proj(hidden_states) + xq, xk, xv = qkv.split([self.q_size, self.kv_size, self.kv_size], + dim=-1) + + xq = xq.view(batch_size, -1, self.num_heads, self.head_dim) + xk = xk.view(batch_size, -1, self.num_kv_heads, self.head_dim) + xv = xv.view(batch_size, -1, self.num_kv_heads, self.head_dim) + + # Positional embedding. + xq = apply_rotary_emb(xq, freqs_cis=freqs_cis) + xk = apply_rotary_emb(xk, freqs_cis=freqs_cis) + + # Write new kv cache. + # [batch_size, input_len, n_local_kv_heads, head_dim] + k_cache, v_cache = kv_cache + k_cache.index_copy_(1, kv_write_indices, xk) + v_cache.index_copy_(1, kv_write_indices, xv) + + key = k_cache + value = v_cache + if self.num_kv_heads != self.num_heads: + # [batch_size, max_seq_len, n_local_heads, head_dim] + key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=2) + value = torch.repeat_interleave(value, + self.num_queries_per_kv, + dim=2) + + # [batch_size, n_local_heads, input_len, head_dim] + q = xq.transpose(1, 2) + # [batch_size, n_local_heads, max_seq_len, head_dim] + k = key.transpose(1, 2) + v = value.transpose(1, 2) + + # [batch_size, n_local_heads, input_len, max_seq_len] + scores = torch.matmul(q, k.transpose(2, 3)) * self.scaling + scores = scores + mask + scores = F.softmax(scores.float(), dim=-1).type_as(q) + + # [batch_size, n_local_heads, input_len, head_dim] + output = torch.matmul(scores, v) + + # [batch_size, input_len, hidden_dim] + output = (output.transpose(1, 2).contiguous().view( + batch_size, input_len, -1)) + output = self.o_proj(output) + return output + + +class GemmaDecoderLayer(nn.Module): + + def __init__( + self, + config: gemma_config.GemmaConfig, + ): + super().__init__() + self.self_attn = GemmaAttention( + hidden_size=config.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + head_dim=config.head_dim, + quant=config.quant, + ) + self.mlp = GemmaMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + quant=config.quant, + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + freqs_cis: torch.Tensor, + kv_write_indices: torch.Tensor, + kv_cache: Tuple[torch.Tensor, torch.Tensor], + mask: torch.Tensor, + ) -> torch.Tensor: + # Self Attention + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + 
hidden_states=hidden_states, + freqs_cis=freqs_cis, + kv_write_indices=kv_write_indices, + kv_cache=kv_cache, + mask=mask, + ) + hidden_states = residual + hidden_states + + # MLP + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class GemmaModel(nn.Module): + + def __init__(self, config: gemma_config.GemmaConfig): + super().__init__() + self.config = config + self.vocab_size = config.vocab_size + + self.layers = nn.ModuleList() + for _ in range(config.num_hidden_layers): + self.layers.append(GemmaDecoderLayer(config)) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + freqs_cis: torch.Tensor, + kv_write_indices: torch.Tensor, + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + mask: torch.Tensor, + ) -> torch.Tensor: + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states = layer( + hidden_states=hidden_states, + freqs_cis=freqs_cis, + kv_write_indices=kv_write_indices, + kv_cache=kv_caches[i], + mask=mask, + ) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class GemmaForCausalLM(nn.Module): + + def __init__( + self, + config: gemma_config.GemmaConfig, + ): + super().__init__() + self.config = config + assert config.hidden_size % config.num_attention_heads == 0 + + max_seq_len = config.max_position_embeddings + head_dim = config.head_dim + vocab_size = config.vocab_size + + self.tokenizer = None #tokenizer.Tokenizer(config.tokenizer) + self.embedder = Embedding(vocab_size, config.hidden_size, config.quant) + self.model = GemmaModel(config) + self.sampler = Sampler(vocab_size) + + # Pre-compute rotary embedding table. + rope_theta = getattr(config, 'rope_theta', 10000) + freqs_cis = precompute_freqs_cis(head_dim, + max_seq_len * 2, + theta=rope_theta) + self.register_buffer('freqs_cis', freqs_cis) + + @torch.no_grad() + def forward( + self, + input_token_ids: torch.Tensor, + input_positions: torch.Tensor, + kv_write_indices: torch.Tensor, + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + mask: torch.Tensor, + output_positions: torch.Tensor, + temperatures: torch.Tensor, + top_ps: torch.Tensor, + top_ks: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + freqs_cis = self.freqs_cis.index_select(0, input_positions) + kv_write_indices = input_positions + + # [batch_size, input_len, hidden_size] + hidden_states = self.embedder(input_token_ids) + # Gemma normalizes the embedding by sqrt(hidden_size). + hidden_states = hidden_states * (self.config.hidden_size**0.5) + + hidden_states = self.model( + hidden_states=hidden_states, + freqs_cis=freqs_cis, + kv_write_indices=kv_write_indices, + kv_caches=kv_caches, + mask=mask, + ) + embedder_weight = self.embedder.weight + if self.config.quant: + embedder_weight = ( + embedder_weight * self.embedder.weight_scaler.unsqueeze(-1)) + #next_tokens = self.sampler( + return hidden_states + # return next_tokens + + def generate( + self, + prompts: Union[str, Sequence[str]], + device: Any, + output_len: int = 100, + temperature: float = 0.95, + top_p: float = 1.0, + top_k: int = 100, + ) -> Union[str, Sequence[str]]: + """Generates responses for given prompts using Gemma model.""" + # If a single prompt is provided, treat it as a batch of 1. 
+ is_str_prompt = isinstance(prompts, str) + if is_str_prompt: + prompts = [prompts] + + batch_size = len(prompts) + prompt_tokens = [self.tokenizer.encode(prompt) for prompt in prompts] + min_prompt_len = min(len(p) for p in prompt_tokens) + max_prompt_len = max(len(p) for p in prompt_tokens) + max_seq_len = max_prompt_len + output_len + assert max_seq_len <= self.config.max_position_embeddings + + # build KV caches + kv_caches = [] + for _ in range(self.config.num_hidden_layers): + size = (batch_size, max_seq_len, self.config.num_key_value_heads, + self.config.head_dim) + dtype = self.config.get_dtype() + k_cache = torch.zeros(size=size, dtype=dtype, device=device) + v_cache = torch.zeros(size=size, dtype=dtype, device=device) + kv_caches.append((k_cache, v_cache)) + + # prepare inputs + token_ids_tensor = torch.full((batch_size, max_seq_len), + self.tokenizer.pad_id, dtype=torch.int64) + input_token_ids_tensor = torch.full((batch_size, min_prompt_len), + self.tokenizer.pad_id, + dtype=torch.int64) + for i, p in enumerate(prompt_tokens): + token_ids_tensor[i, :len(p)] = torch.tensor(p) + input_token_ids_tensor[i, :min_prompt_len] = torch.tensor( + p[:min_prompt_len]) + token_ids_tensor = token_ids_tensor.to(device) + input_token_ids_tensor = input_token_ids_tensor.to(device) + prompt_mask_tensor = token_ids_tensor != self.tokenizer.pad_id + input_positions_tensor = torch.arange(0, min_prompt_len, + dtype=torch.int64).to(device) + mask_tensor = torch.full((1, 1, max_seq_len, max_seq_len), + -2.3819763e38).to(torch.float) + mask_tensor = torch.triu(mask_tensor, diagonal=1).to(device) + curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor) + output_positions_tensor = torch.LongTensor([min_prompt_len - 1]).to( + device) + temperatures_tensor = torch.FloatTensor([temperature] * batch_size).to( + device) + top_ps_tensor = torch.FloatTensor([top_p] * batch_size).to(device) + top_ks_tensor = torch.LongTensor([top_k] * batch_size).to(device) + output_index = torch.tensor(min_prompt_len, dtype=torch.int64).to( + device) + + # Prefill up to min_prompt_len tokens, then treat other prefill as + # decode and ignore output. + for i in range(max_seq_len - min_prompt_len): + next_token_ids = self( + input_token_ids=input_token_ids_tensor, + input_positions=input_positions_tensor, + kv_write_indices=None, + kv_caches=kv_caches, + mask=curr_mask_tensor, + output_positions=output_positions_tensor, + temperatures=temperatures_tensor, + top_ps=top_ps_tensor, + top_ks=top_ks_tensor, + ) + + curr_prompt_mask = prompt_mask_tensor.index_select( + 1, output_index).squeeze(dim=1) + curr_token_ids = token_ids_tensor.index_select( + 1, output_index).squeeze(dim=1) + output_token_ids = torch.where(curr_prompt_mask, curr_token_ids, + next_token_ids).unsqueeze(dim=1) + token_ids_tensor.index_copy_(1, output_index, output_token_ids) + + input_token_ids_tensor = output_token_ids + input_positions_tensor = output_index.unsqueeze(dim=-1) + curr_mask_tensor = mask_tensor.index_select(2, + input_positions_tensor) + output_positions_tensor = torch.tensor(0, dtype=torch.int64).to( + device) + output_index = output_index + 1 + + # Detokenization. 
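+        # (keep only each sequence's generated continuation, truncate it at the
+        # first EOS token, and decode the remaining ids back to text)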
+        token_ids = token_ids_tensor.tolist()
+        results = []
+        for i, tokens in enumerate(token_ids):
+            trimmed_output = tokens[len(prompt_tokens[i]):len(prompt_tokens[i])
+                                    + output_len]
+            if self.tokenizer.eos_id in trimmed_output:
+                eos_index = trimmed_output.index(self.tokenizer.eos_id)
+                trimmed_output = trimmed_output[:eos_index]
+            results.append(self.tokenizer.decode(trimmed_output))
+
+        # If a string was provided as input, return a string as output.
+        return results[0] if is_str_prompt else results
+
+    def load_weights(self, model_path: str):
+        self.load_state_dict(
+            torch.load(
+                model_path, mmap=True, weights_only=True,
+            )['model_state_dict'],
+            strict=False,
+        )
\ No newline at end of file
diff --git a/experimental/torch_xla2/test/gemma/test_gemma.py b/experimental/torch_xla2/test/gemma/test_gemma.py
new file mode 100644
index 00000000000..bd0bb21dbb1
--- /dev/null
+++ b/experimental/torch_xla2/test/gemma/test_gemma.py
@@ -0,0 +1,86 @@
+import torch
+import unittest
+import torch_xla2
+from torch.utils import _pytree as pytree
+from . import config
+from . import model as gemma
+
+class GemmaTest(unittest.TestCase):
+
+  def setUp(self):
+    torch.manual_seed(0)
+
+  def test_gemma(self):
+    mconfig = config.GemmaConfig(
+        num_hidden_layers=3,
+        num_attention_heads=8,
+        num_key_value_heads=1,
+        hidden_size=256,
+        intermediate_size=16384,
+        dtype='float32'
+    )
+    model = gemma.GemmaForCausalLM(mconfig)
+    batch_size = 1
+    max_seq_len = 1000
+    min_prompt_len = 1000
+    device = 'cpu'
+    pad_id = -1
+    temperature = 0.8
+    top_k = 100
+    top_p = 1.0
+
+    # prepare inputs
+    token_ids_tensor = torch.randint(0, max_seq_len, (batch_size, max_seq_len),
+                                     dtype=torch.int64)
+
+    # build KV caches
+    kv_caches = []
+    for _ in range(model.config.num_hidden_layers):
+      size = (batch_size, max_seq_len, model.config.num_key_value_heads,
+              model.config.head_dim)
+      dtype = model.config.get_dtype()
+      k_cache = torch.zeros(size=size, dtype=dtype, device=device)
+      v_cache = torch.zeros(size=size, dtype=dtype, device=device)
+      kv_caches.append((k_cache, v_cache))
+
+    token_ids_tensor = token_ids_tensor.to(device)
+    prompt_mask_tensor = torch.ones_like(token_ids_tensor)
+    input_positions_tensor = torch.arange(0, min_prompt_len,
+                                          dtype=torch.int64).to(device)
+    mask_tensor = torch.full((1, 1, max_seq_len, max_seq_len),
+                             -2.3819763e38).to(torch.float)
+    mask_tensor = torch.triu(mask_tensor, diagonal=1).to(device)
+    curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor)
+    output_positions_tensor = torch.LongTensor([min_prompt_len - 1]).to(
+        device)
+    temperatures_tensor = torch.FloatTensor([temperature] * batch_size).to(
+        device)
+    top_ps_tensor = torch.FloatTensor([top_p] * batch_size).to(device)
+    top_ks_tensor = torch.LongTensor([top_k] * batch_size).to(device)
+    output_index = torch.tensor(min_prompt_len, dtype=torch.int64).to(
+        device)
+
+    inputs = (
+        token_ids_tensor,
+        input_positions_tensor,
+        None,  # kv_write_indices
+        kv_caches,
+        mask_tensor,
+        output_positions_tensor,
+        temperatures_tensor,
+        top_ps_tensor,
+        top_ks_tensor,
+    )
+
+    weights, jax_func = torch_xla2.extract_jax(model)
+    inputs_jax = pytree.tree_map_only(
+        torch.Tensor, torch_xla2.tensor.move_to_device, inputs)
+
+    import jax
+    print(jax.jit(jax_func)(weights, inputs_jax))
+
+
+if __name__ == '__main__':
+  unittest.main()
+
+
diff --git a/experimental/torch_xla2/test/gemma/tokenizer.py b/experimental/torch_xla2/test/gemma/tokenizer.py
new file mode 100644
index 00000000000..d421b012816
--- /dev/null
+++ 
b/experimental/torch_xla2/test/gemma/tokenizer.py @@ -0,0 +1,48 @@ +# From: https://github.com/google/gemma_pytorch/blob/main/gemma/tokenizer.py + +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from typing import List, Optional + +from sentencepiece import SentencePieceProcessor + + +class Tokenizer: + + def __init__(self, model_path: Optional[str]): + # Reload tokenizer. + assert os.path.isfile(model_path), model_path + self.sp_model = SentencePieceProcessor(model_file=model_path) + + # BOS / EOS token IDs. + self.n_words: int = self.sp_model.vocab_size() + self.bos_id: int = self.sp_model.bos_id() + self.eos_id: int = self.sp_model.eos_id() + self.pad_id: int = self.sp_model.pad_id() + assert self.sp_model.vocab_size() == self.sp_model.get_piece_size() + + def encode(self, s: str, bos: bool = True, eos: bool = False) -> List[int]: + """Converts a string into a list of tokens.""" + assert isinstance(s, str) + t = self.sp_model.encode(s) + if bos: + t = [self.bos_id] + t + if eos: + t = t + [self.eos_id] + return t + + def decode(self, t: List[int]) -> str: + """Converts a list of tokens into a string.""" + return self.sp_model.decode(t) diff --git a/experimental/torch_xla2/torch_xla2/ops.py b/experimental/torch_xla2/torch_xla2/ops.py index 2a094c98180..334c63bb423 100644 --- a/experimental/torch_xla2/torch_xla2/ops.py +++ b/experimental/torch_xla2/torch_xla2/ops.py @@ -239,6 +239,11 @@ def _aten_div(x, y, rounding_mode=""): res = jnp.trunc(res) return res +@op(torch.ops.aten.div_, is_jax_func=False) +def _aten_div_(x, y, rounding_mode=""): + x._elem = _aten_div(x._elem, y._elem, rounding_mode) + return x + @op(torch.ops.aten.true_divide) def _aten_true_divide(x, y):