From 858359e02ac5cd10cd22a2c3ad02155599b2e071 Mon Sep 17 00:00:00 2001
From: Wonjoo Lee
Date: Fri, 19 Jul 2024 21:43:33 +0000
Subject: [PATCH] Clean up some code

---
 torch_xla/core/dynamo_bridge.py | 45 ++++++++++++++++++---------------
 1 file changed, 24 insertions(+), 21 deletions(-)

diff --git a/torch_xla/core/dynamo_bridge.py b/torch_xla/core/dynamo_bridge.py
index 3ae84248724c..4e094e5f7753 100644
--- a/torch_xla/core/dynamo_bridge.py
+++ b/torch_xla/core/dynamo_bridge.py
@@ -171,6 +171,15 @@ def _maybe_move_tensors_to_device(tensors: tuple,
   return tuple(moved_tensors)
 
 
+# Given a list of args, returns a tuple of all the args' shapes.
+def _get_arg_input_shapes(args):
+  arg_input_shapes = []
+  for arg in args:
+    if isinstance(arg, torch.Tensor):
+      arg_input_shapes.append(tuple(arg.shape))
+  return tuple(arg_input_shapes)
+
+
 class Deduper:
 
   def __init__(self):
@@ -445,6 +454,19 @@ def extract_internal(xla_model: torch.fx.GraphModule):
   (xla_args_sharding_spec, args_and_out, graph_hash,
    arg_index_to_need_update_index, none_remover, graph_input_matcher,
    dumb_return_handler, xla_args_need_update) = extract_graph_helper(xla_model)
+
+  # If dynamic=True, we cache the result to avoid recompilation.
+  if not torch._dynamo.config.assume_static_by_default:
+    arg_input_shapes = _get_arg_input_shapes(xla_model.xla_args)
+    print(
+        f'[WONJOO] extract_internal before optimized_mod, {arg_input_shapes=}')
+    input_shape_mappings[arg_input_shapes] = (xla_args_sharding_spec,
+                                              args_and_out, graph_hash,
+                                              arg_index_to_need_update_index,
+                                              none_remover, graph_input_matcher,
+                                              dumb_return_handler,
+                                              xla_args_need_update)
+
   skip_checking_input_sharding_threashold = xu.getenv_as(
       'XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD', int, 5)
 
@@ -463,34 +485,15 @@ def optimized_mod(*args: tuple):
 
     # See [Note: Dynamo real-time input-shape cache look-up] above.
     if not torch._dynamo.config.assume_static_by_default:
-      # TODO figure out why we need this mark_step or sync here.
-      # Without it, the results are off slightly.
-      xm.mark_step()
-      # input_tensors_to_sync = [
-      #     args[i] for i, x in enumerate(
-      #         torch_xla._XLAC._check_tensor_need_materialization(
-      #             [a for a in args if isinstance(a, torch.Tensor)])) if x
-      # ]
-      # if len(input_tensors_to_sync) > 0:
-      #   torch_xla._XLAC._xla_increment_counter('DynamoSyncInputExecuteTime', 1)
-      #   torch_xla._XLAC._xla_sync_multi(
-      #       input_tensors_to_sync, devices=[], wait=True, sync_xla_data=True)
-
-      arg_input_shapes = []
       # When dynamic=True in torch.compile call, TorchDynamo will directly
       # call optimized_mod without compiling. Hence, xla_model's xla_args will
       # be different from optimized_mod's args. So we manually set them here.
       xla_model.xla_args = args
-      for arg in args:
-        if isinstance(arg, torch.Tensor):
-          arg_input_shapes.append(tuple(arg.shape))
-      arg_input_shapes = tuple(arg_input_shapes)
+      arg_input_shapes = _get_arg_input_shapes(xla_model.xla_args)
 
       if arg_input_shapes in input_shape_mappings:
         (xla_args_sharding_spec, args_and_out, graph_hash,
          arg_index_to_need_update_index, none_remover, graph_input_matcher,
          dumb_return_handler,
          xla_args_need_update) = input_shape_mappings[arg_input_shapes]
       else:
+        xm.mark_step()
         (xla_args_sharding_spec, args_and_out, graph_hash,
          arg_index_to_need_update_index, none_remover, graph_input_matcher,
          dumb_return_handler,