
fix parameter sharding in testing script (#7496)
JackCaoG authored Jun 25, 2024
1 parent e2868f1 commit 4caead7
Showing 1 changed file with 4 additions and 4 deletions.
test/spmd/test_sharding_strategies.py
@@ -78,8 +78,6 @@
     axis_names=('data', 'fsdp', 'tensor'))
 
 data_sharding = (('data', 'fsdp'), 'tensor')
-# We assume parameters are stored in a decreasing order of dimension size
-parameter_sharding = ('tensor', 'fsdp')
 
 
 def gen_data(batch, d_emb):
@@ -164,8 +162,10 @@ def training_step(data):
 xm.mark_step()
 
 for name, layer in model.named_modules():
-  if 'linear' in name:
-    xs.mark_sharding(layer.weight, mesh, parameter_sharding)
+  if 'EMB2FF_linear' in name:
+    xs.mark_sharding(layer.weight, mesh, ('tensor', 'fsdp'))
+  elif 'FF2EMB_linear' in name:
+    xs.mark_sharding(layer.weight, mesh, ('fsdp', 'tensor'))
 
 optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
 
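Why the fix is needed: torch.nn.Linear stores its weight as (out_features, in_features), so the script's two projections have transposed layouts: EMB2FF_linear's weight is (d_ff, d_emb), while FF2EMB_linear's is (d_emb, d_ff). The removed comment assumed every parameter stores its larger dimension first, which FF2EMB_linear typically violates (d_ff is usually larger than d_emb), so the shared ('tensor', 'fsdp') spec sharded d_ff on 'tensor' for one layer but d_emb for the other. The per-layer specs restore a consistent mapping: d_ff on 'tensor' and d_emb on 'fsdp' in both weights. Below is a minimal sketch, not part of the commit, showing that mapping in isolation; the dimension sizes and the single-axis mesh shape are illustrative assumptions, while the axis names and mark_sharding calls follow the test script.

import numpy as np
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr

xr.use_spmd()  # enable the SPMD partitioner before creating XLA tensors
device = xm.xla_device()

d_emb, d_ff = 1024, 4096  # illustrative sizes, not taken from the script
emb2ff = nn.Linear(d_emb, d_ff).to(device)  # weight shape: (d_ff, d_emb)
ff2emb = nn.Linear(d_ff, d_emb).to(device)  # weight shape: (d_emb, d_ff)

num_devices = xr.global_runtime_device_count()
# Assumed mesh: all devices on the 'tensor' axis for brevity; the test
# script constructs its own (data, fsdp, tensor) mesh earlier in the file.
mesh = xs.Mesh(
    np.arange(num_devices), (1, 1, num_devices),
    axis_names=('data', 'fsdp', 'tensor'))

# Transposed specs for transposed weight layouts: d_ff lands on 'tensor'
# and d_emb on 'fsdp' in both layers.
xs.mark_sharding(emb2ff.weight, mesh, ('tensor', 'fsdp'))
xs.mark_sharding(ff2emb.weight, mesh, ('fsdp', 'tensor'))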
