Commit

Clean up some code
wonjoolee95 committed Jul 19, 2024
1 parent fcd08bb commit 858359e
Showing 1 changed file with 24 additions and 21 deletions.
45 changes: 24 additions & 21 deletions torch_xla/core/dynamo_bridge.py
@@ -171,6 +171,15 @@ def _maybe_move_tensors_to_device(tensors: tuple,
  return tuple(moved_tensors)


# Given a list of args, returns a tuple of all the args' shapes.
def _get_arg_input_shapes(args):
  arg_input_shapes = []
  for arg in args:
    if isinstance(arg, torch.Tensor):
      arg_input_shapes.append(tuple(arg.shape))
  return tuple(arg_input_shapes)


class Deduper:

  def __init__(self):
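
The new _get_arg_input_shapes helper collects the shape of every tensor argument into a hashable tuple so it can serve as a dictionary key. A minimal standalone sketch of its behavior (the example inputs below are illustrative, not from the commit):

import torch

# Copy of the helper from the hunk above, exercised with made-up arguments.
def _get_arg_input_shapes(args):
  arg_input_shapes = []
  for arg in args:
    if isinstance(arg, torch.Tensor):
      arg_input_shapes.append(tuple(arg.shape))
  return tuple(arg_input_shapes)

args = (torch.zeros(2, 3), 1.5, torch.zeros(4))
print(_get_arg_input_shapes(args))  # ((2, 3), (4,)) -- non-tensor args are skipped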
@@ -445,6 +454,19 @@ def extract_internal(xla_model: torch.fx.GraphModule):
  (xla_args_sharding_spec, args_and_out, graph_hash,
   arg_index_to_need_update_index, none_remover, graph_input_matcher,
   dumb_return_handler, xla_args_need_update) = extract_graph_helper(xla_model)

  # If dynamic=True, we cache the result to avoid recompilation.
  if not torch._dynamo.config.assume_static_by_default:
    arg_input_shapes = _get_arg_input_shapes(xla_model.xla_args)
    print(
        f'[WONJOO] extract_internal before optimized_mod, {arg_input_shapes=}')
    input_shape_mappings[arg_input_shapes] = (xla_args_sharding_spec,
                                              args_and_out, graph_hash,
                                              arg_index_to_need_update_index,
                                              none_remover, graph_input_matcher,
                                              dumb_return_handler,
                                              xla_args_need_update)

  skip_checking_input_sharding_threashold = xu.getenv_as(
      'XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD', int, 5)
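
The added block above is a shape-keyed cache: when dynamic=True (assume_static_by_default is off), the artifacts produced by extract_graph_helper are stored in input_shape_mappings under the tuple of input shapes, so a later call with the same shapes can reuse them instead of re-extracting the graph. A condensed sketch of that pattern using the names from the diff (the wrapper function compile_or_reuse is hypothetical):

input_shape_mappings = {}

def compile_or_reuse(xla_model):
  # Key the cache by the shapes of the model's current XLA args.
  key = _get_arg_input_shapes(xla_model.xla_args)
  if key not in input_shape_mappings:
    # First time these input shapes are seen: extract/compile and remember.
    input_shape_mappings[key] = extract_graph_helper(xla_model)
  return input_shape_mappings[key]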

@@ -463,34 +485,15 @@ def optimized_mod(*args: tuple):

    # See [Note: Dynamo real-time input-shape cache look-up] above.
    if not torch._dynamo.config.assume_static_by_default:
      # TODO figure out why we need this mark_step or sync here.
      # Without it, the results are off slightly.
      xm.mark_step()
      # input_tensors_to_sync = [
      #     args[i] for i, x in enumerate(
      #         torch_xla._XLAC._check_tensor_need_materialization(
      #             [a for a in args if isinstance(a, torch.Tensor)])) if x
      # ]
      # if len(input_tensors_to_sync) > 0:
      #   torch_xla._XLAC._xla_increment_counter('DynamoSyncInputExecuteTime', 1)
      #   torch_xla._XLAC._xla_sync_multi(
      #       input_tensors_to_sync, devices=[], wait=True, sync_xla_data=True)

      arg_input_shapes = []
      # When dynamic=True in torch.compile call, TorchDynamo will directly
      # call optimized_mod without compiling. Hence, xla_model's xla_args will
      # be different from optimized_mod's args. So we manually set them here.
      xla_model.xla_args = args
      for arg in args:
        if isinstance(arg, torch.Tensor):
          arg_input_shapes.append(tuple(arg.shape))
      arg_input_shapes = tuple(arg_input_shapes)
      arg_input_shapes = _get_arg_input_shapes(xla_model.xla_args)
      if arg_input_shapes in input_shape_mappings:
        (xla_args_sharding_spec, args_and_out, graph_hash,
         arg_index_to_need_update_index, none_remover, graph_input_matcher,
         dumb_return_handler,
         xla_args_need_update) = input_shape_mappings[arg_input_shapes]
      else:
        xm.mark_step()
        (xla_args_sharding_spec, args_and_out, graph_hash,
         arg_index_to_need_update_index, none_remover, graph_input_matcher,
         dumb_return_handler,
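For context, the branch guarded by not torch._dynamo.config.assume_static_by_default is the one taken when a model is compiled with dynamic shapes, e.g. torch.compile(..., dynamic=True); in that mode Dynamo may call optimized_mod with differently shaped inputs, which is why xla_model.xla_args is reset from args before the cache lookup. A hedged usage sketch (the backend name and toy model are assumptions, not part of the commit):

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torch.nn.Linear(16, 4).to(device)

# dynamic=True turns off assume_static_by_default, so optimized_mod
# consults input_shape_mappings, keyed by the tuple of input shapes.
compiled = torch.compile(model, backend='openxla', dynamic=True)

out_a = compiled(torch.randn(8, 16, device=device))   # first shape: extract and cache
out_b = compiled(torch.randn(32, 16, device=device))  # new batch size: cache miss, re-extract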
