
Misc fix for jax export, ran formatter (pytorch#6516)
qihqi authored and amithrm committed Mar 1, 2024
1 parent 9e6dc14 commit ac9aff1
Showing 10 changed files with 558 additions and 454 deletions.
105 changes: 51 additions & 54 deletions experimental/torch_xla2/test/llama/llama_model.py
@@ -50,41 +50,46 @@ def from_name(cls, name: str):
       return cls(**transformer_configs[name])
     # fuzzy search
     config = [
-        config
-        for config in transformer_configs
+        config for config in transformer_configs
         if config in str(name).upper() or config in str(name)
     ]
     assert len(config) == 1, name
     return cls(**transformer_configs[config[0]])
 
 
 transformer_configs = {
-    "CodeLlama-7b-Python-hf": dict(
-        block_size=16384,
-        vocab_size=32000,
-        n_layer=32,
-        dim=4096,
-        rope_base=1000000,
-    ),
-    "7B": dict(n_layer=32, n_head=32, dim=4096),
-    "13B": dict(n_layer=40, n_head=40, dim=5120),
-    "30B": dict(n_layer=60, n_head=52, dim=6656),
-    "34B": dict(
-        n_layer=48,
-        n_head=64,
-        dim=8192,
-        vocab_size=32000,
-        n_local_heads=8,
-        intermediate_size=22016,
-        rope_base=1000000,
-    ),  # CodeLlama-34B-Python-hf
-    "70B": dict(
-        n_layer=80,
-        n_head=64,
-        dim=8192,
-        n_local_heads=8,
-        intermediate_size=28672,
-    ),
+    "CodeLlama-7b-Python-hf":
+        dict(
+            block_size=16384,
+            vocab_size=32000,
+            n_layer=32,
+            dim=4096,
+            rope_base=1000000,
+        ),
+    "7B":
+        dict(n_layer=32, n_head=32, dim=4096),
+    "13B":
+        dict(n_layer=40, n_head=40, dim=5120),
+    "30B":
+        dict(n_layer=60, n_head=52, dim=6656),
+    "34B":
+        dict(
+            n_layer=48,
+            n_head=64,
+            dim=8192,
+            vocab_size=32000,
+            n_local_heads=8,
+            intermediate_size=22016,
+            rope_base=1000000,
+        ),  # CodeLlama-34B-Python-hf
+    "70B":
+        dict(
+            n_layer=80,
+            n_head=64,
+            dim=8192,
+            n_local_heads=8,
+            intermediate_size=28672,
+        ),
 }
 
 
@@ -123,8 +128,7 @@ def __init__(self, config: ModelArgs) -> None:
 
     self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
     self.layers = nn.ModuleList(
-        TransformerBlock(config) for _ in range(config.n_layer)
-    )
+        TransformerBlock(config) for _ in range(config.n_layer))
     self.norm = RMSNorm(config.dim, eps=config.norm_eps)
     self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
 
@@ -134,28 +138,24 @@ def __init__(self, config: ModelArgs) -> None:
     self.max_seq_length = -1
 
   def setup_caches(self, max_batch_size, max_seq_length):
-    if (
-        self.max_seq_length >= max_seq_length
-        and self.max_batch_size >= max_batch_size
-    ):
+    if (self.max_seq_length >= max_seq_length and
+        self.max_batch_size >= max_batch_size):
       return
     head_dim = self.config.dim // self.config.n_head
     max_seq_length = find_multiple(max_seq_length, 8)
     self.max_seq_length = max_seq_length
     self.max_batch_size = max_batch_size
     for b in self.layers:
-      b.attention.kv_cache = KVCache(
-          max_batch_size, max_seq_length, self.config.n_local_heads, head_dim
-      )
+      b.attention.kv_cache = KVCache(max_batch_size, max_seq_length,
+                                     self.config.n_local_heads, head_dim)
 
     self.freqs_cis = precompute_freqs_cis(
         self.config.block_size,
         self.config.dim // self.config.n_head,
         self.config.rope_base,
     )
     self.causal_mask = torch.tril(
-        torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)
-    )
+        torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))
 
   def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
     assert self.freqs_cis is not None, "Caches must be initialized first"
@@ -183,9 +183,8 @@ def __init__(self, config: ModelArgs) -> None:
     self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
     self.attention_norm = RMSNorm(config.dim, config.norm_eps)
 
-  def forward(
-      self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor
-  ) -> Tensor:
+  def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor,
+              mask: Tensor) -> Tensor:
     h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
     out = h + self.feed_forward(self.ffn_norm(h))
     return out
@@ -197,9 +196,8 @@ def __init__(self, config: ModelArgs):
     super().__init__()
     assert config.dim % config.n_head == 0
 
-    total_head_dim = (
-        config.n_head + 2 * config.n_local_heads
-    ) * config.head_dim
+    total_head_dim = (config.n_head +
+                      2 * config.n_local_heads) * config.head_dim
     # key, query, value projections for all heads, but in a batch
     self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
     self.wo = nn.Linear(config.dim, config.dim, bias=False)
@@ -283,12 +281,11 @@ def forward(self, x: Tensor) -> Tensor:
     return output * self.weight
 
 
-def precompute_freqs_cis(
-    seq_len: int, n_elem: int, base: int = 10000
-) -> Tensor:
+def precompute_freqs_cis(seq_len: int,
+                         n_elem: int,
+                         base: int = 10000) -> Tensor:
   freqs = 1.0 / (
-      base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
-  )
+      base**(torch.arange(0, n_elem, 2)[:(n_elem // 2)].float() / n_elem))
   t = torch.arange(seq_len, device=freqs.device)
   freqs = torch.outer(t, freqs)
   freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
@@ -301,10 +298,10 @@ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
   freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
   x_out2 = torch.stack(
       [
-          xshaped[..., 0] * freqs_cis[..., 0]
-          - xshaped[..., 1] * freqs_cis[..., 1],
-          xshaped[..., 1] * freqs_cis[..., 0]
-          + xshaped[..., 0] * freqs_cis[..., 1],
+          xshaped[..., 0] * freqs_cis[..., 0] -
+          xshaped[..., 1] * freqs_cis[..., 1],
+          xshaped[..., 1] * freqs_cis[..., 0] +
+          xshaped[..., 0] * freqs_cis[..., 1],
       ],
       -1,
  )
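As context for the rotary-embedding hunks above: `precompute_freqs_cis` builds per-position complex phases, and the pairwise multiply/add in `apply_rotary_emb` is just the real-valued form of multiplying feature pairs by those phases. Below is a minimal standalone sketch of that equivalence; it is not part of this commit, and the shapes and variable names are made up for illustration.

import torch

# Hypothetical shapes for the example only: (batch, seq, heads, head_dim).
bsz, seqlen, n_head, head_dim = 2, 5, 4, 8
base = 10000
x = torch.randn(bsz, seqlen, n_head, head_dim)

# Same construction as precompute_freqs_cis, kept in complex form here.
freqs = 1.0 / (base**(torch.arange(0, head_dim, 2).float() / head_dim))
angles = torch.outer(torch.arange(seqlen).float(), freqs)  # (seq, head_dim // 2)
freqs_cis = torch.polar(torch.ones_like(angles), angles)   # complex64 phases

# Real-arithmetic rotation, mirroring apply_rotary_emb.
xs = x.reshape(bsz, seqlen, n_head, head_dim // 2, 2)
fc = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
fc = fc.view(1, seqlen, 1, head_dim // 2, 2)
rot_real = torch.stack(
    [
        xs[..., 0] * fc[..., 0] - xs[..., 1] * fc[..., 1],
        xs[..., 1] * fc[..., 0] + xs[..., 0] * fc[..., 1],
    ],
    -1,
).flatten(3)

# Equivalent complex-multiplication form.
xc = torch.view_as_complex(x.reshape(bsz, seqlen, n_head, head_dim // 2, 2))
rot_complex = torch.view_as_real(xc * freqs_cis.view(1, seqlen, 1, -1)).flatten(3)

assert torch.allclose(rot_real, rot_complex, atol=1e-6)

The assertion holds because (a + bi)(c + di) expands to exactly the two stacked terms computed in the real-arithmetic version.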
3 changes: 1 addition & 2 deletions experimental/torch_xla2/test/llama/test_llama.py
@@ -38,8 +38,7 @@ def test_can_run(self):
     @jax.jit
     def m_func_jit(weights, buffer, args, causal_mask, freqs_cis):
       weights, buffer, args, causal_mask, freqs_cis = tensor.wrap(
-          (weights, buffer, args, causal_mask, freqs_cis)
-      )
+          (weights, buffer, args, causal_mask, freqs_cis))
       m_func.stateless_model.freqs_cis = freqs_cis
       m_func.stateless_model.causal_mask = causal_mask
       res = m_func(weights, buffer, *args)
39 changes: 20 additions & 19 deletions experimental/torch_xla2/test/test_conv.py
@@ -3,6 +3,7 @@
 import torch_xla2
 from . import test_base
 
+
 class CustomConv1(torch.nn.Module):
 
   def __init__(
@@ -53,27 +54,27 @@ def forward(self, x):
 
 class ConvTest(test_base.TestCase):
 
-    def test_conv1(self):
-        m = CustomConv1()
-        arg = torch.randn((20, 1, 50))
-        res = m(arg)
+  def test_conv1(self):
+    m = CustomConv1()
+    arg = torch.randn((20, 1, 50))
+    res = m(arg)
 
-        jax_weights, jax_func = torch_xla2.extract_jax(m)
-        arg = torch_xla2.tensor.t2j(arg)
-        res2 = jax_func(jax_weights, (arg, ))
-        res2_torch = torch_xla2.tensor.j2t(res2)
-        self.assertTrue(torch.allclose(res, res2_torch))
+    jax_weights, jax_func = torch_xla2.extract_jax(m)
+    arg = torch_xla2.tensor.t2j(arg)
+    res2 = jax_func(jax_weights, (arg,))
+    res2_torch = torch_xla2.tensor.j2t(res2)
+    self.assertTrue(torch.allclose(res, res2_torch))
 
-    def test_conv2(self):
-        m = CustomConv2()
-        arg = torch.randn((20, 4, 50, 100))
-        res = m(arg)
-        jax_weights, jax_func = torch_xla2.extract_jax(m)
-        arg = torch_xla2.tensor.t2j(arg)
-        res2 = jax_func(jax_weights, (arg, ))
-        res2_torch = torch_xla2.tensor.j2t(res2)
-        self.assertTrue(torch.allclose(res, res2_torch, atol=1e-4, rtol=1e-4))
+  def test_conv2(self):
+    m = CustomConv2()
+    arg = torch.randn((20, 4, 50, 100))
+    res = m(arg)
+    jax_weights, jax_func = torch_xla2.extract_jax(m)
+    arg = torch_xla2.tensor.t2j(arg)
+    res2 = jax_func(jax_weights, (arg,))
+    res2_torch = torch_xla2.tensor.j2t(res2)
+    self.assertTrue(torch.allclose(res, res2_torch, atol=1e-4, rtol=1e-4))
 
 
 if __name__ == '__main__':
-    test_base.main()
+  test_base.main()