From e8334d4aeb3a9e081bae49f9cfd5e0cd4f4da5ae Mon Sep 17 00:00:00 2001
From: Wonjoo Lee
Date: Thu, 18 Jul 2024 05:24:40 +0000
Subject: [PATCH] Fix regressing unit tests

---
 test/dynamo/test_dynamo.py      |  2 --
 torch_xla/core/dynamo_bridge.py | 18 +++++++++---------
 2 files changed, 9 insertions(+), 11 deletions(-)

diff --git a/test/dynamo/test_dynamo.py b/test/dynamo/test_dynamo.py
index 6757003983e7..45a7293cc56a 100644
--- a/test/dynamo/test_dynamo.py
+++ b/test/dynamo/test_dynamo.py
@@ -362,7 +362,6 @@ def test_resnet18(self, initialize_on_cuda):
     met.clear_all()
     dynamo_resnet18 = torch.compile(device_resnet18, backend='openxla')
     for data, _ in loader:
-      # dynamo_resnet18 = torch.compile(device_resnet18, backend='openxla')
       output = dynamo_resnet18(data)
       output_cpu = resnet18(data.cpu())
       self.assertTrue(
@@ -408,7 +407,6 @@ def test_dynamic_shape_resnet18(self, initialize_on_cuda):
         'DynamoExtractCompiledGraph')
 
     loader_new_shape = self.get_loader(device, sample_count, batch_size=2)
-    diffs = []
     for data, _ in loader_new_shape:
       output_new_shape = dynamo_resnet18(data)
       output_cpu_new_shape = resnet18(data.cpu())
diff --git a/torch_xla/core/dynamo_bridge.py b/torch_xla/core/dynamo_bridge.py
index cf6662305096..5851c50cedb5 100644
--- a/torch_xla/core/dynamo_bridge.py
+++ b/torch_xla/core/dynamo_bridge.py
@@ -435,8 +435,9 @@ def extract_internal(xla_model: torch.fx.GraphModule):
   # We maintain a mapping of input shapes to outputs of extract_graph_helper.
   # When dynamic=True in torch.compile call, TorchDynamo will not trigger a
   # new recompile. Then TorchXLA needs to figure out if these input shapes have
-  # been seen before. The keys are tuple of input shapes, and values are a tuple
-  # of (xla_args_sharding_spec, args_and_out, graph_hash,
+  # been seen before.
+  # Keys: tuple of input shapes
+  # Values: tuple of (xla_args_sharding_spec, args_and_out, graph_hash,
   # arg_index_to_need_update_index, none_remover, graph_input_matcher,
   # dumb_return_handler, xla_args_need_update).
   input_shape_mappings: dict[tuple[int, ...], tuple[object, ...]] = {}
@@ -557,7 +558,6 @@ def optimized_mod(*args: tuple):
         print(
             '=================== OpenXLA Dynamo Compile Debug End =====================\n'
         )
-
 
   return optimized_mod
 
@@ -568,10 +568,9 @@ def __init__(self, module):
     self._unsupported_nodes = []
 
   def run_node(self, n: torch.fx.Node):
-    original_metrics = {
-        counter_name: metrics.counter_value(counter_name)
-        for counter_name in metrics.counter_names()
-    }
+    # We need to restore this metric count later, so save it in a separate variable
+    dynamo_extract_graph_helper_metric_count = metrics.counter_value(
+        'DynamoExtractCompiledGraph')
 
     metrics.clear_counters()
     result = super().run_node(n)
@@ -607,8 +606,9 @@ def all_tensors_on_xla_device(value):
     if not (result_is_supported and args_are_supported):
       self._unsupported_nodes.append(n)
 
-    for name, value in original_metrics.items():
-      torch_xla._XLAC._xla_increment_counter(name, value)
+    # Restore this metric counter
+    torch_xla._XLAC._xla_increment_counter(
+        'DynamoExtractCompiledGraph', dynamo_extract_graph_helper_metric_count)
 
     return result
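
The sketch below is not part of the patch; it only illustrates, in isolation, the save/clear/restore counter pattern that the new run_node relies on (save one counter before metrics.clear_counters(), restore it afterwards with torch_xla._XLAC._xla_increment_counter). run_with_isolated_counters is a hypothetical helper name, and the None guard is an assumption that metrics.counter_value returns None for a counter that has not fired yet.

import torch_xla
import torch_xla.debug.metrics as metrics

def run_with_isolated_counters(fn, *args, **kwargs):
  # Save the one counter we care about before wiping everything; in the patch
  # this is 'DynamoExtractCompiledGraph', which the tests assert on.
  saved = metrics.counter_value('DynamoExtractCompiledGraph')

  # Clearing counters lets the callee's metric activity be inspected in
  # isolation, but it also discards the count saved above.
  metrics.clear_counters()
  result = fn(*args, **kwargs)

  # Restore the saved count so later assertions still see the pre-clear value.
  # (Assumption: counter_value returned None if the counter never fired.)
  if saved is not None:
    torch_xla._XLAC._xla_increment_counter('DynamoExtractCompiledGraph', saved)
  return result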