
Commit ae0e393
Update way to get dynamic_shape flag
wonjoolee95 committed Jul 19, 2024
1 parent 07ad306 commit ae0e393
Showing 1 changed file with 3 additions and 12 deletions.
torch_xla/core/dynamo_bridge.py
@@ -419,8 +419,7 @@ def extract_graph_helper(xla_model: torch.fx.GraphModule):
           dumb_return_handler, xla_args_need_update)
 
 
-def extract_internal(xla_model: torch.fx.GraphModule,
-                     enable_dynamic_shape=False):
+def extract_internal(xla_model: torch.fx.GraphModule):
   if dynamo_debug:
     print(
         '\n=================== OpenXLA Dynamo Compile Debug Begin ==================='
@@ -461,7 +460,7 @@ def optimized_mod(*args: tuple):
     nonlocal skip_checking_input_sharding_threashold
     nonlocal input_shape_mappings
 
-    if enable_dynamic_shape:
+    if not torch._dynamo.config.assume_static_by_default:
       # See [Note: Dynamo real-time input-shape cache look-up] above.
       arg_input_shapes = []
       # When dynamic=True in torch.compile call, TorchDynamo will directly
@@ -663,12 +662,6 @@ def extract_compiled_graph_helper(xla_model: torch.fx.GraphModule, xla_args):
   if _args_on_cuda(xla_args):
     xla_args = tuple(_maybe_move_tensors_to_device(xla_args, xm.xla_device()))
 
-  enable_dynamic_shape = False
-  for a in xla_args:
-    if isinstance(a, int):
-      enable_dynamic_shape = True
-      break
-
   # Synchronize xla_args, so that each FunctionalTensorWrapper argument updates its
   # value reference before actually computing it.
   for a in xla_args:
@@ -760,9 +753,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
       partitioned_graph.delete_submodule(node.target)
       with partitioned_graph.graph.inserting_after(node):
         new_node = partitioned_graph.graph.call_function(
-            extract_internal(
-                fused_module, enable_dynamic_shape=enable_dynamic_shape),
-            node.args, None)
+            extract_internal(fused_module), node.args, None)
       node.replace_all_uses_with(new_node)
       partitioned_graph.graph.erase_node(node)
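For context on the new condition: instead of inferring a per-graph enable_dynamic_shape flag by scanning xla_args for ints, optimized_mod now reads TorchDynamo's configuration directly. Below is a minimal, illustrative sketch (standalone Python, not part of this diff; it assumes torch._dynamo.config.assume_static_by_default and torch._dynamo.config.patch behave as in current PyTorch) of the check in isolation:

    import torch
    import torch._dynamo

    def wants_dynamic_shapes() -> bool:
      # Mirrors the replaced condition in optimized_mod: dynamic-shape
      # handling kicks in whenever Dynamo is not assuming static shapes
      # by default.
      return not torch._dynamo.config.assume_static_by_default

    print(wants_dynamic_shapes())  # False: stock config assumes static shapes

    # torch.compile(..., dynamic=True) makes Dynamo trace with
    # assume_static_by_default=False; patching the config directly
    # exercises the same condition this check reads.
    with torch._dynamo.config.patch(assume_static_by_default=False):
      print(wants_dynamic_shapes())  # True: the dynamic-shape branch is taken

This keys the bridge's dynamic-shape path off the same signal Dynamo itself uses, rather than a heuristic over the argument types.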
