Fix regressing unit tests
wonjoolee95 committed Jul 19, 2024
1 parent ae0e393 commit fcd08bb
Showing 2 changed files with 31 additions and 16 deletions.
13 changes: 7 additions & 6 deletions test/dynamo/test_dynamo.py
@@ -362,7 +362,6 @@ def test_resnet18(self, initialize_on_cuda):
met.clear_all()
dynamo_resnet18 = torch.compile(device_resnet18, backend='openxla')
for data, _ in loader:
# dynamo_resnet18 = torch.compile(device_resnet18, backend='openxla')
output = dynamo_resnet18(data)
output_cpu = resnet18(data.cpu())
self.assertTrue(
@@ -375,7 +374,6 @@ def test_resnet18(self, initialize_on_cuda):
self.assertEqual(
met.metric_data('RunCachedGraphOutputData')[0], sample_count)

@skipOnTpu
@parameterized.parameters(
True,
False,
@@ -408,14 +406,17 @@ def test_dynamic_shape_resnet18(self, initialize_on_cuda):
'DynamoExtractCompiledGraph')

loader_new_shape = self.get_loader(device, sample_count, batch_size=2)
diffs = []
for data, _ in loader_new_shape:
output_new_shape = dynamo_resnet18(data)
output_cpu_new_shape = resnet18(data.cpu())
# # TPU has some precision issues, skipping allclose check
# if not _is_on_tpu():
# self.assertTrue(
# torch.allclose(output_cpu_new_shape, output_new_shape.cpu(), rtol=1e-05, atol=1e-05))
if not _is_on_tpu():
self.assertTrue(
torch.allclose(
output_cpu_new_shape,
output_new_shape.cpu(),
rtol=1e-05,
atol=1e-05))

self.assertEqual(
met.counter_value('DynamoExtractCompiledGraph'),
34 changes: 24 additions & 10 deletions torch_xla/core/dynamo_bridge.py
@@ -435,8 +435,9 @@ def extract_internal(xla_model: torch.fx.GraphModule):
# We maintain a mapping of input shapes to outputs of extract_graph_helper.
# When dynamic=True in torch.compile call, TorchDynamo will not trigger a
# new recompile. Then TorchXLA needs to figure out if these input shapes have
# been seen before. The keys are tuple of input shapes, and values are a tuple
# of (xla_args_sharding_spec, args_and_out, graph_hash,
# been seen before.
# Keys: tuple of input shapes
# Values: tuple of (xla_args_sharding_spec, args_and_out, graph_hash,
# arg_index_to_need_update_index, none_remover, graph_input_matcher,
# dumb_return_handler, xla_args_need_update).
input_shape_mappings: dict[tuple[int, ...], tuple[object, ...]] = {}
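As a standalone illustration of the shape-keyed cache described in the note above, here is a minimal sketch (not part of this diff); compile_for_shapes and lookup_or_compile are hypothetical stand-ins for extract_graph_helper and optimized_mod, and the stored payload is simplified.

from typing import Any, Dict, Tuple

import torch

input_shape_mappings: Dict[Tuple[Tuple[int, ...], ...], Any] = {}


def compile_for_shapes(args):
  # Stand-in for extract_graph_helper: the real cache stores the tuple
  # (xla_args_sharding_spec, args_and_out, graph_hash, ...).
  return {'compiled_for': tuple(tuple(a.shape) for a in args)}


def lookup_or_compile(*args):
  # Key the cache on the tuple of input shapes, as the note describes.
  key = tuple(tuple(a.shape) for a in args)
  if key not in input_shape_mappings:
    input_shape_mappings[key] = compile_for_shapes(args)
  return input_shape_mappings[key]


lookup_or_compile(torch.randn(4, 3))  # first call for this shape: "compile"
lookup_or_compile(torch.randn(4, 3))  # same shape: cache hit, no re-extraction
lookup_or_compile(torch.randn(2, 3))  # new shape: one more cache entry
assert len(input_shape_mappings) == 2

The same shape tuple hitting the cache is what lets TorchXLA skip re-extraction when TorchDynamo does not recompile under dynamic=True.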
@@ -460,8 +461,21 @@ def optimized_mod(*args: tuple):
nonlocal skip_checking_input_sharding_threashold
nonlocal input_shape_mappings

# See [Note: Dynamo real-time input-shape cache look-up] above.
if not torch._dynamo.config.assume_static_by_default:
# See [Note: Dynamo real-time input-shape cache look-up] above.
# TODO figure out why we need this mark_step or sync here.
# Without it, the results are off slightly.
xm.mark_step()
# input_tensors_to_sync = [
# args[i] for i, x in enumerate(
# torch_xla._XLAC._check_tensor_need_materialization(
# [a for a in args if isinstance(a, torch.Tensor)])) if x
# ]
# if len(input_tensors_to_sync) > 0:
# torch_xla._XLAC._xla_increment_counter('DynamoSyncInputExecuteTime', 1)
# torch_xla._XLAC._xla_sync_multi(
# input_tensors_to_sync, devices=[], wait=True, sync_xla_data=True)

arg_input_shapes = []
# When dynamic=True in torch.compile call, TorchDynamo will directly
# call optimized_mod without compiling. Hence, xla_model's xla_args will
@@ -501,6 +515,7 @@ def optimized_mod(*args: tuple):
torch_xla._XLAC._check_tensor_need_materialization(
[a for a in args if isinstance(a, torch.Tensor)])) if x
]

if len(input_tensors_to_sync) > 0:
torch_xla._XLAC._xla_increment_counter('DynamoSyncInputExecuteTime', 1)
torch_xla._XLAC._xla_sync_multi(
@@ -557,7 +572,6 @@ def optimized_mod(*args: tuple):
print(
'=================== OpenXLA Dynamo Compile Debug End =====================\n'
)

return optimized_mod


@@ -568,10 +582,9 @@ def __init__(self, module):
self._unsupported_nodes = []

def run_node(self, n: torch.fx.Node):
original_metrics = {
counter_name: metrics.counter_value(counter_name)
for counter_name in metrics.counter_names()
}
# We need to restore this metric count later, so save it in a separate variable
dynamo_extract_graph_helper_metric_count = metrics.counter_value(
'DynamoExtractCompiledGraph')

metrics.clear_counters()
result = super().run_node(n)
@@ -607,8 +620,9 @@ def all_tensors_on_xla_device(value):
if not (result_is_supported and args_are_supported):
self._unsupported_nodes.append(n)

for name, value in original_metrics.items():
torch_xla._XLAC._xla_increment_counter(name, value)
# Restore this metric counter
torch_xla._XLAC._xla_increment_counter(
'DynamoExtractCompiledGraph', dynamo_extract_graph_helper_metric_count)

return result

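For context, a minimal, self-contained sketch (not part of this diff) of the save/clear/restore pattern run_node applies to the 'DynamoExtractCompiledGraph' counter; the plain counters dict here is a stand-in for torch_xla's metrics API (metrics.counter_value, metrics.clear_counters, _xla_increment_counter).

# The `counters` dict stands in for torch_xla's metrics counters.
counters = {'DynamoExtractCompiledGraph': 3, 'SomeOtherCounter': 7}


def run_node_with_clean_counters(fn, *args):
  # Save the counter that must survive, as the diff above does.
  saved = counters.get('DynamoExtractCompiledGraph', 0)
  counters.clear()  # every other counter starts from zero for this node
  try:
    return fn(*args)
  finally:
    # Restore only the preserved counter afterwards.
    counters['DynamoExtractCompiledGraph'] = (
        counters.get('DynamoExtractCompiledGraph', 0) + saved)


result = run_node_with_clean_counters(lambda x: x + 1, 41)
assert result == 42
assert counters == {'DynamoExtractCompiledGraph': 3}

Restoring only the one counter (instead of the whole counter snapshot the old code kept) is the behavioral change this hunk introduces.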