diff --git a/test/dynamo/test_dynamo.py b/test/dynamo/test_dynamo.py
index 6757003983e7..84cf0065f609 100644
--- a/test/dynamo/test_dynamo.py
+++ b/test/dynamo/test_dynamo.py
@@ -362,7 +362,6 @@ def test_resnet18(self, initialize_on_cuda):
     met.clear_all()
     dynamo_resnet18 = torch.compile(device_resnet18, backend='openxla')
     for data, _ in loader:
-      # dynamo_resnet18 = torch.compile(device_resnet18, backend='openxla')
       output = dynamo_resnet18(data)
       output_cpu = resnet18(data.cpu())
       self.assertTrue(
@@ -375,7 +374,6 @@ def test_resnet18(self, initialize_on_cuda):
     self.assertEqual(
         met.metric_data('RunCachedGraphOutputData')[0], sample_count)
 
-  @skipOnTpu
   @parameterized.parameters(
       True,
       False,
@@ -408,14 +406,17 @@ def test_dynamic_shape_resnet18(self, initialize_on_cuda):
         'DynamoExtractCompiledGraph')
 
     loader_new_shape = self.get_loader(device, sample_count, batch_size=2)
-    diffs = []
     for data, _ in loader_new_shape:
       output_new_shape = dynamo_resnet18(data)
       output_cpu_new_shape = resnet18(data.cpu())
       # # TPU has some precision issues, skipping allclose check
-      # if not _is_on_tpu():
-      #   self.assertTrue(
-      #     torch.allclose(output_cpu_new_shape, output_new_shape.cpu(), rtol=1e-05, atol=1e-05))
+      if not _is_on_tpu():
+        self.assertTrue(
+            torch.allclose(
+                output_cpu_new_shape,
+                output_new_shape.cpu(),
+                rtol=1e-05,
+                atol=1e-05))
 
     self.assertEqual(
         met.counter_value('DynamoExtractCompiledGraph'),
diff --git a/torch_xla/core/dynamo_bridge.py b/torch_xla/core/dynamo_bridge.py
index cf6662305096..3ae84248724c 100644
--- a/torch_xla/core/dynamo_bridge.py
+++ b/torch_xla/core/dynamo_bridge.py
@@ -435,8 +435,9 @@ def extract_internal(xla_model: torch.fx.GraphModule):
   # We maintain a mapping of input shapes to outputs of extract_graph_helper.
   # When dynamic=True in torch.compile call, TorchDynamo will not trigger a
   # new recompile. Then TorchXLA needs to figure out if these input shapes have
-  # been seen before. The keys are tuple of input shapes, and values are a tuple
-  # of (xla_args_sharding_spec, args_and_out, graph_hash,
+  # been seen before.
+  # Keys: tuple of input shapes
+  # Values: tuple of (xla_args_sharding_spec, args_and_out, graph_hash,
   # arg_index_to_need_update_index, none_remover, graph_input_matcher,
   # dumb_return_handler, xla_args_need_update).
   input_shape_mappings: dict[tuple[int, ...], tuple[object, ...]] = {}
@@ -460,8 +461,21 @@ def optimized_mod(*args: tuple):
     nonlocal skip_checking_input_sharding_threashold
     nonlocal input_shape_mappings
 
+    # See [Note: Dynamo real-time input-shape cache look-up] above.
     if not torch._dynamo.config.assume_static_by_default:
-      # See [Note: Dynamo real-time input-shape cache look-up] above.
+      # TODO figure out why we need this mark_step or sync here.
+      # Without it, the results are off slightly.
+      xm.mark_step()
+      # input_tensors_to_sync = [
+      #     args[i] for i, x in enumerate(
+      #         torch_xla._XLAC._check_tensor_need_materialization(
+      #             [a for a in args if isinstance(a, torch.Tensor)])) if x
+      # ]
+      # if len(input_tensors_to_sync) > 0:
+      #   torch_xla._XLAC._xla_increment_counter('DynamoSyncInputExecuteTime', 1)
+      #   torch_xla._XLAC._xla_sync_multi(
+      #       input_tensors_to_sync, devices=[], wait=True, sync_xla_data=True)
+
       arg_input_shapes = []
       # When dynamic=True in torch.compile call, TorchDynamo will directly
       # call optimized_mod without compiling. Hence, xla_model's xla_args will
@@ -501,6 +515,7 @@ def optimized_mod(*args: tuple):
             torch_xla._XLAC._check_tensor_need_materialization(
                 [a for a in args if isinstance(a, torch.Tensor)])) if x
     ]
+
     if len(input_tensors_to_sync) > 0:
       torch_xla._XLAC._xla_increment_counter('DynamoSyncInputExecuteTime', 1)
       torch_xla._XLAC._xla_sync_multi(
          input_tensors_to_sync, devices=[], wait=True, sync_xla_data=True)
@@ -557,7 +572,6 @@ def optimized_mod(*args: tuple):
     print(
         '=================== OpenXLA Dynamo Compile Debug End =====================\n'
     )
-
   return optimized_mod
 
 
@@ -568,10 +582,9 @@ def __init__(self, module):
     self._unsupported_nodes = []
 
   def run_node(self, n: torch.fx.Node):
-    original_metrics = {
-        counter_name: metrics.counter_value(counter_name)
-        for counter_name in metrics.counter_names()
-    }
+    # We need to restore this metric count later, so save it in a separate variable
+    dynamo_extract_graph_helper_metric_count = metrics.counter_value(
+        'DynamoExtractCompiledGraph')
     metrics.clear_counters()
     result = super().run_node(n)
 
@@ -607,8 +620,9 @@ def all_tensors_on_xla_device(value):
     if not (result_is_supported and args_are_supported):
       self._unsupported_nodes.append(n)
 
-    for name, value in original_metrics.items():
-      torch_xla._XLAC._xla_increment_counter(name, value)
+    # Restore this metric counter
+    torch_xla._XLAC._xla_increment_counter(
+        'DynamoExtractCompiledGraph', dynamo_extract_graph_helper_metric_count)
 
     return result
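
Not part of the patch: the snippet below is an illustrative, hypothetical driver for the code path this diff touches. It assumes a working torch_xla installation and uses torchvision's resnet18 purely as an example model; the batch sizes are arbitrary. With dynamic=True, TorchDynamo calls optimized_mod directly for a new input shape instead of recompiling, and the input_shape_mappings cache in dynamo_bridge.py decides whether that shape has been seen before.

import torch
import torchvision
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torchvision.models.resnet18().to(device).eval()
# backend='openxla' and dynamic=True mirror the configuration exercised by
# test_dynamic_shape_resnet18 above.
compiled = torch.compile(model, backend='openxla', dynamic=True)

with torch.no_grad():
  # Two different batch sizes reach optimized_mod with different input shapes;
  # a repeated shape is served from the input-shape cache rather than going
  # through graph extraction again.
  for batch in (torch.randn(4, 3, 224, 224), torch.randn(2, 3, 224, 224)):
    output = compiled(batch.to(device))
    xm.mark_step()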