diff --git a/experimental/torch_xla2/examples/train_llama/utils.py b/experimental/torch_xla2/examples/train_llama/utils.py
index 2bb05fa54a4..cf796992ab7 100644
--- a/experimental/torch_xla2/examples/train_llama/utils.py
+++ b/experimental/torch_xla2/examples/train_llama/utils.py
@@ -16,10 +16,10 @@
 
 SEQLEN = 8192
 BATCH = 8
-global_axis: Tuple[str, str] = ('devices', 'fsdp')
+global_axis: Tuple[str, ...] = ('fsdp', )
 num_global_devices = jax.device_count()
 num_local_devices = jax.local_device_count()
-num_partitions = (num_global_devices//num_local_devices, num_local_devices)
+num_partitions = (num_global_devices, )
 
 #SEQLEN = 512
 import torch
@@ -55,16 +55,6 @@ def sharded_device_put(tensor, sharding):
     return jax.device_put(tensor, sharding)
 
   shape = tensor.shape
-  if shape[0] == 1:
-    #hotfix weight.shape ~= (1, 6144, 4096) during num_layers == 1 case
-    #NOTE: maybe 'addressable_devices_indices_map' should ignore empty dimensions (dimension size == 1)
-    mesh = jax.sharding.Mesh(
-        mesh_utils.create_device_mesh(num_partitions),
-        axis_names=global_axis,
-    )
-
-    sharding = jax.sharding.NamedSharding(mesh, P(None, *(global_axis[len(shape)-1:0:-1])))
-
   x_split = [jax.device_put(tensor[i], device) for device, i in sharding.addressable_devices_indices_map(shape).items()]
   return jax.make_array_from_single_device_arrays(shape, sharding, x_split)
 
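Note on the mesh change: the sketch below illustrates how the patched 1-D 'fsdp' mesh is built and applied, matching the new global_axis = ('fsdp',) and num_partitions = (num_global_devices,). It is a minimal sketch under stated assumptions, not code from the patch: the helper make_fsdp_sharding and the toy tensor are hypothetical, and it exercises only the single-host jax.device_put fast path of sharded_device_put.

import jax
import numpy as np
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def make_fsdp_sharding():
  # Hypothetical helper: one mesh axis spanning every device, i.e. the
  # patched num_partitions = (num_global_devices,) with axis name 'fsdp',
  # rather than the old 2-D ('devices', 'fsdp') mesh.
  devices = mesh_utils.create_device_mesh((jax.device_count(),))
  mesh = Mesh(devices, axis_names=('fsdp',))
  # Shard dimension 0 across 'fsdp'; all other dimensions stay replicated.
  return NamedSharding(mesh, P('fsdp'))

sharding = make_fsdp_sharding()
x = np.arange(8 * 4, dtype=np.float32).reshape(8, 4)  # dim 0 divisible by device count
y = jax.device_put(x, sharding)  # the single-host path of sharded_device_put
print(y.sharding)

The removed hotfix rebuilt the sharding with a leading None in the PartitionSpec so a size-1 leading dimension (weights shaped ~(1, 6144, 4096) when num_layers == 1) was never split; the patch implies that the flat 'fsdp' mesh makes that special case unnecessary.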