
Commit

reenable dynamo dynamic shape test (#7775)
JackCaoG committed Jul 30, 2024
1 parent 28de4eb commit b09c969
Showing 4 changed files with 18 additions and 14 deletions.
3 changes: 3 additions & 0 deletions test/dynamo/test_dynamo_dynamic_shape.py
@@ -226,6 +226,9 @@ def test_dynamic_shape_no_retracing(self):
# means we retrace the same fx multiple times.
self.assertNotIn('CachedCompile', met.counter_names())

+@unittest.skip(
+    "Skip right now because with torch._dynamo.config.inline_inbuilt_nn_modules = True, dynamic compiles takes minutes for resnet18."
+)
def test_dynamic_shape_resnet18(self):
device = torch_xla.device()

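The rest of this file's tests still run; they exercise dynamo's dynamic-shape path on XLA. A minimal sketch of that pattern, not taken from the test file (the function, sizes, and backend choice below are assumptions for illustration):

    import torch
    import torch_xla

    def fn(x):
        return (x * 2).sum(dim=-1)

    # Under dynamic=True the varying batch dimension is traced symbolically,
    # so one compiled graph can serve several input sizes.
    device = torch_xla.device()
    compiled = torch.compile(fn, backend="openxla", dynamic=True)

    for batch in (4, 8, 16):
        x = torch.randn(batch, 10, device=device)
        print(compiled(x))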
2 changes: 1 addition & 1 deletion test/run_tests.sh
@@ -178,7 +178,7 @@ function run_xla_op_tests1 {
run_test "$CDIR/dynamo/test_dynamo_integrations_util.py"
run_test "$CDIR/dynamo/test_dynamo_aliasing.py"
run_test "$CDIR/dynamo/test_dynamo.py"
-# run_test "$CDIR/dynamo/test_dynamo_dynamic_shape.py"
+run_test "$CDIR/dynamo/test_dynamo_dynamic_shape.py"
run_test "$CDIR/dynamo/test_bridge.py"
run_test "$CDIR/dynamo/test_num_output.py"
run_test "$CDIR/dynamo/test_graph_input_matcher.py"
2 changes: 1 addition & 1 deletion test/tpu/run_tests.sh
@@ -17,7 +17,7 @@ XLA_EXPERIMENTAL=nonzero:masked_select:nms python3 test/ds/test_dynamic_shapes.p
python3 test/test_autocast.py
python3 test/test_grad_checkpoint.py
python3 test/dynamo/test_dynamo.py
-# python3 test/dynamo/test_dynamo_dynamic_shape.py
+python3 test/dynamo/test_dynamo_dynamic_shape.py
python3 test/spmd/test_spmd_debugging.py
XLA_PARAMETER_WRAPPING_THREADSHOLD=1 python test/spmd/test_spmd_parameter_wrapping.py
python3 test/pjrt/test_dtypes.py
25 changes: 13 additions & 12 deletions torch_xla/core/dynamo_bridge.py
@@ -8,6 +8,7 @@
import os
import time
from typing import Any, Dict, List, Set, Tuple
+from numbers import Number
from contextlib import contextmanager

import torch
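The newly imported numbers.Number is the common base class for Python ints and floats, which is what lets the bridge treat both symint- and symfloat-derived outputs the same way. A quick illustration (not part of the change itself):

    from numbers import Number

    import torch

    print(isinstance(3, Number))                # True  -- an int output passes the new assert
    print(isinstance(3.5, Number))              # True  -- a float output now passes too
    print(isinstance(torch.tensor(3), Number))  # False -- tensors stay on the device path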
@@ -224,17 +225,17 @@ class SpecialReturnHandler:
XLA will dedup those duplicate items, but we need recover the duplications to maintain
the contract with the caller.
-2. Int output from dynamic compile
-In the case of the `torch.compile(Dynamic=True)` there might be some int outputs related
-to the dynmiac dimension of the tensor. These ints are static for a given input shape
+2. constant output from dynamic compile
+In the case of the `torch.compile(Dynamic=True)` there might be some int or float outputs related
+to the dynmiac dimension of the tensor. These numbers are static for a given input shape
combinations so we can cache them and inject to the final result directly.
"""

def __init__(self, trace_inputs, trace_outputs,
-trace_inputs_inplace_update_bool, int_outputs_and_indexes):
+trace_inputs_inplace_update_bool, constant_outputs_and_indexes):
self.trace_inputs = trace_inputs
self.trace_outputs = trace_outputs
-self.int_outputs_and_indexes = int_outputs_and_indexes
+self.constant_outputs_and_indexes = constant_outputs_and_indexes

# dedup the traced outputs first
self.deduper = Deduper()
@@ -260,9 +261,9 @@ def addDumbReturn(self, real_inputs, real_outputs):

ret = self.deduper.recover(real_outputs)

-if len(self.int_outputs_and_indexes) != 0:
+if len(self.constant_outputs_and_indexes) != 0:
# insert the int outputs back to the res
-for index, value in self.int_outputs_and_indexes:
+for index, value in self.constant_outputs_and_indexes:
ret.insert(index, value)

return ret
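In other words, the handler records each scalar output together with its position when the graph is traced, drops it from the device computation, and splices the cached value back into the results on every call. A simplified stand-in (not the real SpecialReturnHandler) showing just that re-insertion step:

    def reinsert_constants(tensor_results, constant_outputs_and_indexes):
        # Indexes were recorded in ascending order at trace time, so inserting
        # them back sequentially restores the original output positions.
        ret = list(tensor_results)
        for index, value in constant_outputs_and_indexes:
            ret.insert(index, value)
        return ret

    # e.g. the traced function returned (tensor_a, 5, tensor_b); index 1 was cached.
    print(reinsert_constants(["tensor_a", "tensor_b"], [(1, 5)]))  # ['tensor_a', 5, 'tensor_b']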
@@ -386,22 +387,22 @@ def extract_graph_helper(xla_model: torch.fx.GraphModule,

args_and_out = tuple(xla_args_need_update) + tuple(xla_out)
# args_and_out should be tensor only, in the dynamic cases there might be
-# symint return as the result. In that case we want to extract them and separate
+# symint or symfloat return as the result. In that case we want to extract them and separate
# them from the device computation.
-int_outputs_and_indexes = []
+constant_outputs_and_indexes = []
args_and_out_tensor_only = []
for i in range(len(args_and_out)):
arg = args_and_out[i]
if not isinstance(arg, torch.Tensor):
-assert type(arg) == int
-int_outputs_and_indexes.append((i, arg))
+assert isinstance(arg, Number)
+constant_outputs_and_indexes.append((i, arg))
else:
args_and_out_tensor_only.append(arg)

special_return_handler = SpecialReturnHandler(xla_args,
args_and_out_tensor_only,
xla_args_need_update_bool,
-int_outputs_and_indexes)
+constant_outputs_and_indexes)

# There is a `mark_step` in the beginning of this function call, we need to wait
# for that to finish before retriving the device data nodes.
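The loop above is where the split happens: tensor outputs go on to the XLA computation, while plain numbers are recorded as (index, value) pairs for the handler. A compact, self-contained restatement of that separation (simplified from the hunk above; the sample inputs are made up):

    from numbers import Number

    import torch

    def split_constants(args_and_out):
        constant_outputs_and_indexes = []
        args_and_out_tensor_only = []
        for i, arg in enumerate(args_and_out):
            if isinstance(arg, torch.Tensor):
                args_and_out_tensor_only.append(arg)
            else:
                # With the Number check, float results from symfloat dims are
                # accepted alongside ints instead of tripping the old int-only assert.
                assert isinstance(arg, Number)
                constant_outputs_and_indexes.append((i, arg))
        return args_and_out_tensor_only, constant_outputs_and_indexes

    tensors, constants = split_constants((torch.ones(2), 7, 2.5))
    print(constants)  # [(1, 7), (2, 2.5)]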
