diff --git a/test/spmd/test_sharding_strategies.py b/test/spmd/test_sharding_strategies.py index 4f869961f09..48a0fe4f03a 100644 --- a/test/spmd/test_sharding_strategies.py +++ b/test/spmd/test_sharding_strategies.py @@ -78,8 +78,6 @@ axis_names=('data', 'fsdp', 'tensor')) data_sharding = (('data', 'fsdp'), 'tensor') -# We assume parameters are stored in a decreasing order of dimension size -parameter_sharding = ('tensor', 'fsdp') def gen_data(batch, d_emb): @@ -164,8 +162,10 @@ def training_step(data): xm.mark_step() for name, layer in model.named_modules(): - if 'linear' in name: - xs.mark_sharding(layer.weight, mesh, parameter_sharding) + if 'EMB2FF_linear' in name: + xs.mark_sharding(layer.weight, mesh, ('tensor', 'fsdp')) + elif 'FF2EMB_linear' in name: + xs.mark_sharding(layer.weight, mesh, ('fsdp', 'tensor')) optimizer = torch.optim.SGD(model.parameters(), lr=0.1)