diff --git a/test/spmd/test_dynamo_spmd.py b/test/spmd/test_dynamo_spmd.py index d1f6cdc3dce..eca6b5889f0 100644 --- a/test/spmd/test_dynamo_spmd.py +++ b/test/spmd/test_dynamo_spmd.py @@ -30,8 +30,7 @@ def __init__(self, mesh=None): def forward(self, x): if self.mesh and 'xla' in str(self.fc2.weight.device): - xs.mark_sharding( - self.fc2.weight, self.mesh, (1, 0), use_dynamo_custom_op=True) + xs.mark_sharding(self.fc2.weight, self.mesh, (1, 0)) y = self.relu(self.fc1(x)) z = self.fc2(y) return self.fc3(z) @@ -184,52 +183,6 @@ def test_dynamo_input_sharding_threashold(self): else: del os.environ['XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD'] - def test_dynamo_spmd_mark_sharding_outside_of_compile(self): - device = xm.xla_device() - linear = SimpleLinear().to(device) - linear.eval() - xla_x = torch.randn(1, 128, device=device) - xs.mark_sharding( - linear.fc2.weight, - self._get_mesh((1, self.n_devices)), (1, 0), - use_dynamo_custom_op=True) - xla_res = linear(xla_x) - xm.mark_step() - - dynamo_linear = torch.compile(linear, backend="openxla") - dynamo_res = dynamo_linear(xla_x) - torch.allclose(xla_res.cpu(), dynamo_res.cpu()) - - # Ensure that another run with same input does not trigger additional compilation - compile_count = met.metric_data('CompileTime')[0] - dynamo_res = dynamo_linear(xla_x) - self.assertEqual(met.metric_data('CompileTime')[0], compile_count) - - # https://github.com/pytorch/xla/pull/6921#issuecomment-2062106737 - @unittest.skip("Failing in CI") - def test_mark_sharding_inside_compile(self): - met.clear_counters() - device = xm.xla_device() - mesh = self._get_mesh((1, self.n_devices)) - - # Passing this `mesh` as a parameter to `SimpleLinear` will call the dynamo custom op - # variant of mark_sharding inside the forward function. - linear = SimpleLinear(mesh=mesh).to(device) - linear.eval() - - xla_x = torch.randn(1, 128, device=device) - xla_res = linear(xla_x) - xm.mark_step() - - dynamo_linear = torch.compile(linear, backend="openxla") - dynamo_res = dynamo_linear(xla_x) - torch.allclose(xla_res.cpu(), dynamo_res.cpu()) - - # Ensure that another run with same input does not trigger additional compilation - compile_count = met.metric_data('CompileTime')[0] - dynamo_res = dynamo_linear(xla_x) - self.assertEqual(met.metric_data('CompileTime')[0], compile_count) - def test_dynamo_spmd_basic_with_dynamo_mark_sharding(self): device = xm.xla_device() linear = SimpleLinear().to(device)