diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py
index a3cdabe0564..1473fd5f995 100644
--- a/test/spmd/test_xla_sharding.py
+++ b/test/spmd/test_xla_sharding.py
@@ -1272,7 +1272,7 @@ def test_spmd_all_reduce(self):
         f"all-reduce(f32[8,8]{{1,0}} %custom-call.2), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, to_apply=%AddComputation.3",
         hlo)
 
-    expected_x = torch.ones(8, 8) * 4
+    expected_x = torch.ones(8, 8) * self.n_devices
     self.assertTrue(torch.allclose(x.cpu(), expected_x))
 
   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
@@ -1280,10 +1280,11 @@ def test_spmd_all_reduce(self):
   def test_spmd_all_reduce_scale(self):
     xs.set_global_mesh(self._get_mesh((1, self.n_devices)))
     x = torch.ones(8, 8).to(xm.xla_device())
+    scale = 0.25
 
     # all reduce
     x = xs.enable_manual_sharding(x, (None, None)).global_tensor
-    x = torch_xla._XLAC._xla_spmd_all_reduce(xm.REDUCE_SUM, x, 0.25,
+    x = torch_xla._XLAC._xla_spmd_all_reduce(xm.REDUCE_SUM, x, scale,
                                              [self.device_ids])
     x = xs.disable_manual_sharding(x, (None, None), x.shape).global_tensor
 
@@ -1292,7 +1293,7 @@ def test_spmd_all_reduce_scale(self):
         f"all-reduce(f32[8,8]{{1,0}} %custom-call.2), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, to_apply=%AddComputation.3",
         hlo)
 
-    expected_x = torch.ones(8, 8)
+    expected_x = torch.ones(8, 8) * int(self.n_devices * scale)
     self.assertTrue(torch.allclose(x.cpu(), expected_x))
 
   def test_get_1d_mesh(self):
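
Note: the updated expectations follow directly from the all-reduce arithmetic. A minimal sketch of that arithmetic in plain Python, assuming a hypothetical run with n_devices = 4 (the tests read self.n_devices from the mesh):

    # Each device contributes a tensor of ones, so an all-reduce sum yields
    # n_devices in every element; the scaled variant multiplies that sum by scale.
    n_devices = 4                              # assumption for illustration only
    scale = 0.25                               # same value the test uses
    expected_unscaled = n_devices              # torch.ones(8, 8) * self.n_devices
    expected_scaled = int(n_devices * scale)   # torch.ones(8, 8) * int(self.n_devices * scale)
    assert (expected_unscaled, expected_scaled) == (4, 1)

Writing the expected values in terms of self.n_devices and scale keeps both tests valid on meshes of any size, rather than only on the 4-device configuration the previous hard-coded constants assumed.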