
Commit 3f22daa
Remove use_dynamo_custom_op flag from unit tests (#7287)
wonjoolee95 committed Jun 15, 2024
1 parent cc55cc4 commit 3f22daa
Showing 1 changed file with 1 addition and 48 deletions.
49 changes: 1 addition & 48 deletions test/spmd/test_dynamo_spmd.py
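
For readers skimming the diff below: the change drops the use_dynamo_custom_op keyword argument, so the tests annotate sharding through the plain mark_sharding call. A minimal sketch of that usage, assuming the torch_xla SPMD helpers available around this release (torch_xla.distributed.spmd and torch_xla.runtime); the mesh shape and tensor sizes here are illustrative, not taken from the test:

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr

# Assumption: this torch_xla build supports the SPMD execution mode.
xr.use_spmd()

num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.array(range(num_devices)), (1, num_devices))

w = torch.randn(128, 128).to(xm.xla_device())
# After this commit, tests call mark_sharding without use_dynamo_custom_op:
xs.mark_sharding(w, mesh, (1, 0))
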
@@ -30,8 +30,7 @@ def __init__(self, mesh=None):
 
   def forward(self, x):
     if self.mesh and 'xla' in str(self.fc2.weight.device):
-      xs.mark_sharding(
-          self.fc2.weight, self.mesh, (1, 0), use_dynamo_custom_op=True)
+      xs.mark_sharding(self.fc2.weight, self.mesh, (1, 0))
     y = self.relu(self.fc1(x))
     z = self.fc2(y)
     return self.fc3(z)
@@ -184,52 +183,6 @@ def test_dynamo_input_sharding_threashold(self):
     else:
       del os.environ['XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD']
 
-  def test_dynamo_spmd_mark_sharding_outside_of_compile(self):
-    device = xm.xla_device()
-    linear = SimpleLinear().to(device)
-    linear.eval()
-    xla_x = torch.randn(1, 128, device=device)
-    xs.mark_sharding(
-        linear.fc2.weight,
-        self._get_mesh((1, self.n_devices)), (1, 0),
-        use_dynamo_custom_op=True)
-    xla_res = linear(xla_x)
-    xm.mark_step()
-
-    dynamo_linear = torch.compile(linear, backend="openxla")
-    dynamo_res = dynamo_linear(xla_x)
-    torch.allclose(xla_res.cpu(), dynamo_res.cpu())
-
-    # Ensure that another run with same input does not trigger additional compilation
-    compile_count = met.metric_data('CompileTime')[0]
-    dynamo_res = dynamo_linear(xla_x)
-    self.assertEqual(met.metric_data('CompileTime')[0], compile_count)
-
-  # https://github.com/pytorch/xla/pull/6921#issuecomment-2062106737
-  @unittest.skip("Failing in CI")
-  def test_mark_sharding_inside_compile(self):
-    met.clear_counters()
-    device = xm.xla_device()
-    mesh = self._get_mesh((1, self.n_devices))
-
-    # Passing this `mesh` as a parameter to `SimpleLinear` will call the dynamo custom op
-    # variant of mark_sharding inside the forward function.
-    linear = SimpleLinear(mesh=mesh).to(device)
-    linear.eval()
-
-    xla_x = torch.randn(1, 128, device=device)
-    xla_res = linear(xla_x)
-    xm.mark_step()
-
-    dynamo_linear = torch.compile(linear, backend="openxla")
-    dynamo_res = dynamo_linear(xla_x)
-    torch.allclose(xla_res.cpu(), dynamo_res.cpu())
-
-    # Ensure that another run with same input does not trigger additional compilation
-    compile_count = met.metric_data('CompileTime')[0]
-    dynamo_res = dynamo_linear(xla_x)
-    self.assertEqual(met.metric_data('CompileTime')[0], compile_count)
-
   def test_dynamo_spmd_basic_with_dynamo_mark_sharding(self):
     device = xm.xla_device()
     linear = SimpleLinear().to(device)
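
The tests that remain in the file keep the compile-and-recheck pattern visible in the deleted blocks above. A standalone sketch of that pattern, assuming SimpleLinear as defined in this test file and a working torch_xla runtime; the assert calls are added here for clarity and are not verbatim from the tests:

import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

device = xm.xla_device()
linear = SimpleLinear().to(device)  # SimpleLinear is defined in this test file
linear.eval()
xla_x = torch.randn(1, 128, device=device)

xla_res = linear(xla_x)  # lazy-tensor baseline result
xm.mark_step()

dynamo_linear = torch.compile(linear, backend="openxla")
dynamo_res = dynamo_linear(xla_x)
assert torch.allclose(xla_res.cpu(), dynamo_res.cpu())

# A second run with the same input should not trigger another XLA compilation.
compile_count = met.metric_data('CompileTime')[0]
dynamo_res = dynamo_linear(xla_x)
assert met.metric_data('CompileTime')[0] == compile_count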
