diff --git a/test/spmd/test_fsdp_v2.py b/test/spmd/test_fsdp_v2.py
index 1429e377b18..cc161d0f1a3 100644
--- a/test/spmd/test_fsdp_v2.py
+++ b/test/spmd/test_fsdp_v2.py
@@ -32,7 +32,7 @@ def test_fsdp_v2_basic(self):
 
     # Make sure all weights are sharded.
     if self.n_devices > 1:
-      annotation = '{devices=[%d,1]%s}' % (self.n_devices, ','.join(
+      annotation = '{devices=[1,%d]%s}' % (self.n_devices, ','.join(
           [str(i) for i in range(self.n_devices)]))
       self.assertEqual(annotation,
                        torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight))
@@ -147,9 +147,9 @@ def test_fsdp_v2_multi_slice(self):
     model = FSDPv2(model, mesh=mesh, extra_data_axis="data")
 
     # Make sure all weights are sharded.
-    annotation = '{devices=[2,1,2]0,2,1,3 last_tile_dim_replicate}'
+    annotation = '{devices=[1,2,2]0,2,1,3 last_tile_dim_replicate}'
     if self.n_devices == 8:
-      annotation = '{devices=[4,1,2]0,4,1,5,2,6,3,7 last_tile_dim_replicate}'
+      annotation = '{devices=[1,4,2]0,4,1,5,2,6,3,7 last_tile_dim_replicate}'
     self.assertEqual(annotation,
                      torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight))
     self.assertEqual(annotation,
diff --git a/torch_xla/experimental/spmd_fully_sharded_data_parallel.py b/torch_xla/experimental/spmd_fully_sharded_data_parallel.py
index 994f7e77dbe..142f9bc7561 100644
--- a/torch_xla/experimental/spmd_fully_sharded_data_parallel.py
+++ b/torch_xla/experimental/spmd_fully_sharded_data_parallel.py
@@ -13,19 +13,24 @@
 from torch_xla.distributed.fsdp.wrap import recursive_wrap
 
 
-def _prepare_spmd_partition_spec(param, extra_data_axis=None):
-  partition_spec = [None] * len(param.shape)
+def _prepare_spmd_partition_spec(param,
+                                 extra_data_axis=None,
+                                 shard_maximal=False):
+  shape = param.shape
+  partition_spec = [None] * len(shape)
   # Skip scalar tensors and it replicated.
   if len(partition_spec) == 0:
     return partition_spec
 
-  # Only shard the 0th dimension of the parameter according to the
-  # fsdp axis of the mesh.
-  # TODO: should we shard on the maximal dim for param? Then we need
-  # another helper for the output.
-  partition_spec[0] = "fsdp"
+  # Shard the 0th dimension of the parameter according to the
+  # fsdp axis of the mesh, if shard_maximal is not specified.
+  index = 0
+  if shard_maximal:
+    index = shape.index(max(shape))
+
+  partition_spec[index] = "fsdp"
   if extra_data_axis:
-    partition_spec[0] = (extra_data_axis, "fsdp")
+    partition_spec[index] = (extra_data_axis, "fsdp")
   return tuple(partition_spec)
 
 
@@ -113,7 +118,8 @@ def __init__(
     for param in module.parameters():
      if torch_xla._XLAC._get_xla_sharding_spec(param) != "":
         continue
-      spmd.mark_sharding(param, mesh, _prepare_spmd_partition_spec(param))
+      spmd.mark_sharding(
+          param, mesh, _prepare_spmd_partition_spec(param, shard_maximal=True))
 
     # Register a backward hook to place optimization barrier to prevent
     # gigantic fusions on syncing the gradients.
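
For readers skimming the patch, below is a minimal, self-contained sketch of the sharding decision this change introduces. The helper name `sketch_partition_spec` is hypothetical and stands in for the patched `_prepare_spmd_partition_spec`; plain tuples stand in for `param.shape`, so no torch_xla install is needed to run it.

```python
def sketch_partition_spec(shape, extra_data_axis=None, shard_maximal=False):
  """Sketch of the partition-spec logic above, operating on a plain shape tuple."""
  partition_spec = [None] * len(shape)
  # Scalar tensors have an empty spec and stay replicated.
  if len(partition_spec) == 0:
    return tuple(partition_spec)

  # Default: shard dim 0; with shard_maximal, shard the largest dim instead.
  index = shape.index(max(shape)) if shard_maximal else 0
  partition_spec[index] = ("fsdp" if extra_data_axis is None else
                           (extra_data_axis, "fsdp"))
  return tuple(partition_spec)


# For a 2-D weight whose second dimension is the largest, the fsdp axis now
# lands on dim 1 rather than dim 0, which is why the test annotations above
# move from '{devices=[%d,1]...}' to '{devices=[1,%d]...}'.
print(sketch_partition_spec((64, 128), shard_maximal=True))  # (None, 'fsdp')
print(sketch_partition_spec((64, 128)))                      # ('fsdp', None)
print(sketch_partition_spec((64, 128), extra_data_axis="data",
                            shard_maximal=True))             # (None, ('data', 'fsdp'))
```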