
Commit

[torch_xla2] LLama3 example - fixed sharding configuration on LLama3 example (#7973)
zmelumian committed Sep 9, 2024
1 parent 12e5958 commit ef6b370
Showing 1 changed file with 2 additions and 12 deletions.
14 changes: 2 additions & 12 deletions experimental/torch_xla2/examples/train_llama/utils.py
@@ -16,10 +16,10 @@
 
 SEQLEN = 8192
 BATCH = 8
-global_axis: Tuple[str, str] = ('devices', 'fsdp')
+global_axis: Tuple[str, str] = ('fsdp', )
 num_global_devices = jax.device_count()
 num_local_devices = jax.local_device_count()
-num_partitions = (num_global_devices//num_local_devices, num_local_devices)
+num_partitions = (num_global_devices, )
 #SEQLEN = 512
 
 import torch
@@ -55,16 +55,6 @@ def sharded_device_put(tensor, sharding):
         return jax.device_put(tensor, sharding)
 
     shape = tensor.shape
-    if shape[0] == 1:
-        #hotfix weight.shape ~= (1, 6144, 4096) during num_layers == 1 case
-        #NOTE: maybe 'addressable_devices_indices_map' should ignore empty dimensions (dimension size == 1)
-        mesh = jax.sharding.Mesh(
-            mesh_utils.create_device_mesh(num_partitions),
-            axis_names=global_axis,
-        )
-
-        sharding = jax.sharding.NamedSharding(mesh, P(None, *(global_axis[len(shape)-1:0:-1])))
-
     x_split = [jax.device_put(tensor[i], device) for device, i in sharding.addressable_devices_indices_map(shape).items()]
     return jax.make_array_from_single_device_arrays(shape, sharding, x_split)
 
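For context: the commit replaces the two-axis ('devices', 'fsdp') mesh layout with a single global 'fsdp' axis spanning every device, which appears to make the size-one-dimension hotfix above unnecessary. Below is a minimal sketch, not taken from the diff, of how such a single-axis configuration is typically turned into a device mesh and a NamedSharding; the weight shape and variable usage here are illustrative assumptions, not code from the example.

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Mirrors the new configuration: one named axis covering all global devices.
global_axis = ('fsdp', )
num_partitions = (jax.device_count(), )

# Build a 1-D device mesh and shard a tensor along its first dimension
# over the 'fsdp' axis; remaining dimensions stay replicated.
mesh = Mesh(mesh_utils.create_device_mesh(num_partitions), axis_names=global_axis)
sharding = NamedSharding(mesh, P('fsdp'))

# Illustrative weight shape; the leading dimension must be divisible
# by the number of devices for this sharding to apply.
weight = jnp.zeros((8192, 4096))
weight = jax.device_put(weight, sharding)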
