[FSDPv2] Shard on the maximal dim of weights (#7134)
Summary:
This pull request makes FSDPv2 shard on the maximal dim of each weight instead of the 0th dim.

Test Plan:
XLA_USE_SPMD=1 PJRT_DEVICE=TPU python test/spmd/test_fsdp_v2.py
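
To illustrate the change, here is a minimal standalone sketch of the new dimension-selection logic (the function name and shapes below are made up for illustration and are not part of the commit):

# Illustrative only: mirrors the shard_maximal dimension choice on an example shape.
def spec_for(shape, shard_maximal):
  spec = [None] * len(shape)
  if not spec:  # scalar tensors stay replicated
    return tuple(spec)
  index = shape.index(max(shape)) if shard_maximal else 0
  spec[index] = "fsdp"
  return tuple(spec)

w = (1024, 4096)  # hypothetical Linear weight shape (out_features, in_features)
print(spec_for(w, shard_maximal=False))  # ('fsdp', None): old behavior, always dim 0
print(spec_for(w, shard_maximal=True))   # (None, 'fsdp'): new behavior, largest dim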
alanwaketan authored May 29, 2024
1 parent fb37312 commit 15fc0f1
Showing 2 changed files with 18 additions and 12 deletions.
6 changes: 3 additions & 3 deletions test/spmd/test_fsdp_v2.py
@@ -32,7 +32,7 @@ def test_fsdp_v2_basic(self):
 
     # Make sure all weights are sharded.
     if self.n_devices > 1:
-      annotation = '{devices=[%d,1]%s}' % (self.n_devices, ','.join(
+      annotation = '{devices=[1,%d]%s}' % (self.n_devices, ','.join(
           [str(i) for i in range(self.n_devices)]))
       self.assertEqual(annotation,
                        torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight))
@@ -147,9 +147,9 @@ def test_fsdp_v2_multi_slice(self):
     model = FSDPv2(model, mesh=mesh, extra_data_axis="data")
 
     # Make sure all weights are sharded.
-    annotation = '{devices=[2,1,2]0,2,1,3 last_tile_dim_replicate}'
+    annotation = '{devices=[1,2,2]0,2,1,3 last_tile_dim_replicate}'
     if self.n_devices == 8:
-      annotation = '{devices=[4,1,2]0,4,1,5,2,6,3,7 last_tile_dim_replicate}'
+      annotation = '{devices=[1,4,2]0,4,1,5,2,6,3,7 last_tile_dim_replicate}'
     self.assertEqual(annotation,
                      torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight))
     self.assertEqual(annotation,
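The annotation strings asserted above encode which tensor dimension is split across devices; '{devices=[1,4]...}' means dim 0 is kept whole and dim 1 is split four ways. A rough standalone sketch of how such a string can be produced follows (it assumes a host with 4 XLA devices and SPMD enabled; the string in the comment is what the test's format implies, not a captured log):

# Sketch only, not part of the commit: shard dim 1 of a 2D tensor over a 1D "fsdp" mesh.
import numpy as np
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr

xr.use_spmd()
num_devices = xr.global_runtime_device_count()  # assumed to be 4 here
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ('fsdp',))

w = torch.randn(128, 512).to(xm.xla_device())
xs.mark_sharding(w, mesh, (None, 'fsdp'))  # shard the larger dim, as FSDPv2 now does
# Expected to resemble '{devices=[1,4]0,1,2,3}', i.e. the '{devices=[1,%d]%s}' pattern above.
print(torch_xla._XLAC._get_xla_sharding_spec(w))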
24 changes: 15 additions & 9 deletions torch_xla/experimental/spmd_fully_sharded_data_parallel.py
@@ -13,19 +13,24 @@
 from torch_xla.distributed.fsdp.wrap import recursive_wrap
 
 
-def _prepare_spmd_partition_spec(param, extra_data_axis=None):
-  partition_spec = [None] * len(param.shape)
+def _prepare_spmd_partition_spec(param,
+                                 extra_data_axis=None,
+                                 shard_maximal=False):
+  shape = param.shape
+  partition_spec = [None] * len(shape)
   # Skip scalar tensors; they are replicated.
   if len(partition_spec) == 0:
     return partition_spec
 
-  # Only shard the 0th dimension of the parameter according to the
-  # fsdp axis of the mesh.
-  # TODO: should we shard on the maximal dim for param? Then we need
-  # another helper for the output.
-  partition_spec[0] = "fsdp"
+  # Shard the 0th dimension of the parameter along the fsdp axis of the
+  # mesh; if shard_maximal is set, shard the largest dimension instead.
+  index = 0
+  if shard_maximal:
+    index = shape.index(max(shape))
+
+  partition_spec[index] = "fsdp"
   if extra_data_axis:
-    partition_spec[0] = (extra_data_axis, "fsdp")
+    partition_spec[index] = (extra_data_axis, "fsdp")
   return tuple(partition_spec)
 
 
@@ -113,7 +118,8 @@ def __init__(
     for param in module.parameters():
       if torch_xla._XLAC._get_xla_sharding_spec(param) != "":
         continue
-      spmd.mark_sharding(param, mesh, _prepare_spmd_partition_spec(param))
+      spmd.mark_sharding(
+          param, mesh, _prepare_spmd_partition_spec(param, shard_maximal=True))
 
     # Register a backward hook to place optimization barrier to prevent
     # gigantic fusions on syncing the gradients.
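For reference, a small usage sketch of the updated helper (return values are inferred from the code above rather than from a captured run, and importing the private helper directly is for illustration only):

# Sketch only: what _prepare_spmd_partition_spec returns for a hypothetical weight.
import torch
from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
    _prepare_spmd_partition_spec)

w = torch.empty(1024, 4096)  # the helper only reads .shape
print(_prepare_spmd_partition_spec(w))                      # ('fsdp', None): default still shards dim 0
print(_prepare_spmd_partition_spec(w, shard_maximal=True))  # (None, 'fsdp'): dim 1 is the largest
print(_prepare_spmd_partition_spec(w, extra_data_axis="data",
                                   shard_maximal=True))     # (None, ('data', 'fsdp'))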
