Update tests
wonjoolee95 committed Jul 16, 2024
1 parent d0b6440 commit f2ef03f
Showing 2 changed files with 60 additions and 48 deletions.
70 changes: 38 additions & 32 deletions test/dynamo/test_dynamo.py
@@ -333,8 +333,8 @@ def test_simple_model_with_different_input_shape(self, initialize_on_cuda):
rtol=1e-05,
atol=1e-05))

def get_loader(self, device, sample_count):
batch_size = xu.getenv_as('BATCH_SIZE', int, defval=4)
def get_loader(self, device, sample_count, batch_size=4):
batch_size = xu.getenv_as('BATCH_SIZE', int, batch_size)
loader = xu.SampleGenerator(
data=(torch.randn(batch_size, 3, 224, 224, device=device),
torch.zeros(batch_size, dtype=torch.int64, device=device)),
@@ -360,8 +360,9 @@ def test_resnet18(self, initialize_on_cuda):
xm.mark_step()
xm.wait_device_ops()
met.clear_all()
dynamo_resnet18 = torch.compile(device_resnet18, backend='openxla')
for data, _ in loader:
dynamo_resnet18 = torch.compile(device_resnet18, backend='openxla')
# dynamo_resnet18 = torch.compile(device_resnet18, backend='openxla')
output = dynamo_resnet18(data)
output_cpu = resnet18(data.cpu())
self.assertTrue(
@@ -374,6 +375,40 @@ def test_resnet18(self, initialize_on_cuda):
self.assertEqual(
met.metric_data('RunCachedGraphOutputData')[0], sample_count)

@skipOnTpu
@parameterized.parameters(
True,
False,
)
def test_dynamic_shape_resnet18(self, initialize_on_cuda):
device = self._choose_proper_device(initialize_on_cuda)
sample_count = xu.getenv_as('SAMPLE_COUNT', int, defval=10)
loader = self.get_loader(device, sample_count, batch_size=4)
resnet18 = torchvision.models.resnet18()
resnet18.eval()
device_resnet18 = torchvision.models.resnet18()
device_resnet18.load_state_dict(resnet18.state_dict())
device_resnet18.to(device)
device_resnet18.eval()
# materialize the fake data for test purposes
xm.mark_step()
xm.wait_device_ops()
met.clear_all()
dynamo_resnet18 = torch.compile(
device_resnet18, backend='openxla', dynamic=True)
for data, _ in loader:
output = dynamo_resnet18(data)
output_cpu = resnet18(data.cpu())
self.assertTrue(
torch.allclose(output_cpu, output.cpu(), rtol=1e-05, atol=1e-05))

loader_new_shape = self.get_loader(device, sample_count, batch_size=8)
for data, _ in loader_new_shape:
output = dynamo_resnet18(data)
output_cpu = resnet18(data.cpu())
self.assertTrue(
torch.allclose(output_cpu, output.cpu(), rtol=1e-05, atol=1e-05))
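
For reference, a minimal sketch of the pattern this new test exercises, assuming torch_xla is installed and an XLA device is available; the toy scale function below is a hypothetical stand-in for resnet18, not part of the test suite:

import torch
import torch_xla.core.xla_model as xm

def scale(x):
  return x * 2.0

# Compile once with dynamic=True, then feed the same compiled callable
# inputs of different batch sizes, mirroring the batch_size=4 and
# batch_size=8 loaders in the test above.
compiled_scale = torch.compile(scale, backend='openxla', dynamic=True)
device = xm.xla_device()
out_small = compiled_scale(torch.randn(4, 3, device=device))
out_large = compiled_scale(torch.randn(8, 3, device=device))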

def test_resnet18_lazy_vs_dynamo(self):
sample_count = xu.getenv_as('SAMPLE_COUNT', int, defval=10)
device = torch_xla.device()
@@ -769,35 +804,6 @@ def foo(x):

self.assertEqual(expected, actual.cpu())

class DynamoDynamicShapeTest(unittest.TestCase):

def simple_add(self, a, b):
c = a + b
return c

def test_dynamic_shape(self):
met.clear_all()
device = xm.xla_device()

compiled_fn = torch.compile(
self.simple_add, backend="openxla", fullgraph=True, dynamic=True)

a_cpu = torch.randn(3, 4)
a_xla = a_cpu.to(device)
b_cpu = torch.ones(4)
b_xla = b_cpu.to(device)
res_cpu = compiled_fn(a_cpu, b_cpu)
res_xla = compiled_fn(a_xla, b_xla)
self.assertTrue(torch.all(torch.eq(res_cpu, res_xla.cpu())))

c_cpu = torch.randn(5, 6)
c_xla = c_cpu.to(device)
d_cpu = torch.ones(6)
d_xla = d_cpu.to(device)
res_cpu_2 = compiled_fn(c_cpu, d_cpu)
res_xla_2 = compiled_fn(c_xla, d_xla)
self.assertTrue(torch.all(torch.eq(res_cpu_2, res_xla_2.cpu())))


if __name__ == '__main__':
test = unittest.main()
38 changes: 22 additions & 16 deletions torch_xla/core/dynamo_bridge.py
@@ -433,9 +433,9 @@ def extract_internal(xla_model: torch.fx.GraphModule):

# [Note: Dynamo real-time input-shape cache look-up]
# We maintain a mapping of input shapes to outputs of extract_graph_helper.
# When dynamic=True is passed to the torch.compile call, TorchDynamo will not
# trigger a new recompile. Then TorchXLA needs to figure out if these input
# shapes have been seen before.
input_shape_mappings = {}

(xla_args_sharding_spec, args_and_out, graph_hash,
@@ -457,31 +457,34 @@ def optimized_mod(*args: tuple):
nonlocal xla_args_need_update
nonlocal skip_checking_input_sharding_threashold

# When dynamic=True is passed to the torch.compile call, TorchDynamo will
# directly call optimized_mod without compiling. Hence, xla_model's xla_args
# will be different from optimized_mod's args, so we manually set them here.
xla_model.xla_args = args

# See [Note: Dynamo real-time input-shape cache look-up] above.
arg_input_shapes = []
for arg in args:
arg_input_shapes.append(tuple(arg.shape))
if isinstance(arg, torch.Tensor):
arg_input_shapes.append(tuple(arg.shape))
arg_input_shapes = tuple(arg_input_shapes)
if arg_input_shapes in input_shape_mappings:
  (xla_args_sharding_spec, args_and_out, graph_hash,
   arg_index_to_need_update_index, none_remover, graph_input_matcher,
   dumb_return_handler,
   xla_args_need_update) = input_shape_mappings[arg_input_shapes]
else:
  (xla_args_sharding_spec, args_and_out, graph_hash,
   arg_index_to_need_update_index, none_remover, graph_input_matcher,
   dumb_return_handler,
   xla_args_need_update) = extract_graph_helper(xla_model)
  input_shape_mappings[arg_input_shapes] = (xla_args_sharding_spec,
                                            args_and_out, graph_hash,
                                            arg_index_to_need_update_index,
                                            none_remover,
                                            graph_input_matcher,
                                            dumb_return_handler,
                                            xla_args_need_update)
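
To illustrate the cache look-up note above, here is a minimal standalone sketch of a shape-keyed cache; _graph_cache, _lookup_or_extract, and extract_fn are hypothetical stand-ins for input_shape_mappings and extract_graph_helper, not the bridge's actual helpers:

import torch

_graph_cache = {}  # tuple of per-tensor input shapes -> extracted graph


def _lookup_or_extract(args, extract_fn):
  # Key the cache on the shapes of the tensor arguments only; with
  # dynamic=True, Dynamo may also pass plain ints (symbolic sizes).
  key = tuple(tuple(a.shape) for a in args if isinstance(a, torch.Tensor))
  if key not in _graph_cache:
    _graph_cache[key] = extract_fn(args)  # trace/compile for this shape set
  return _graph_cache[key]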

original_device: torch.device = _get_input_arg_device(args)
is_cuda_args: bool = False
@@ -645,6 +648,7 @@ def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args):
with torch_xla.experimental.eager_mode_context(False):
return extract_compiled_graph_helper(xla_model, xla_args)


def extract_compiled_graph_helper(xla_model: torch.fx.GraphModule, xla_args):
if _args_on_cuda(xla_args):
xla_args = tuple(_maybe_move_tensors_to_device(xla_args, xm.xla_device()))
@@ -673,7 +677,9 @@ def extract_compiled_graph_helper(xla_model: torch.fx.GraphModule, xla_args):
self_args.append(buffer)

# When dynamic_shape=True, TorchDynamo will pass us shapes as integers. We want to deal with the tensors only for now, so keep them separately.
all_xla_args = [
    xla_arg for xla_arg in xla_args if isinstance(xla_arg, torch.Tensor)
] + self_args
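
A hypothetical argument tuple, just to illustrate the filter above:

# With dynamic=True, Dynamo may pass symbolic sizes as plain Python ints
# alongside the tensors (the leading 8 below stands in for such a size).
example_args = (8, torch.randn(8, 3), torch.randn(8, 3))
tensors_only = [a for a in example_args if isinstance(a, torch.Tensor)]
# tensors_only keeps the two tensors; the int is dropped.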

for xla_arg in xla_args:
if isinstance(xla_arg, torch.Tensor) and xla_arg.device.type != 'xla':
