Skip to content

Commit

Permalink
Introduce flag and update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
wonjoolee95 committed Jul 18, 2024
1 parent f2ef03f commit e836697
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 45 deletions.
32 changes: 22 additions & 10 deletions test/dynamo/test_dynamo.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ def test_simple_model_with_different_input_shape(self, initialize_on_cuda):
atol=1e-05))

def get_loader(self, device, sample_count, batch_size=4):
batch_size = xu.getenv_as('BATCH_SIZE', int, batch_size)
batch_size = xu.getenv_as('BATCH_SIZE', int, defval=batch_size)
loader = xu.SampleGenerator(
data=(torch.randn(batch_size, 3, 224, 224, device=device),
torch.zeros(batch_size, dtype=torch.int64, device=device)),
Expand All @@ -349,7 +349,7 @@ def get_loader(self, device, sample_count, batch_size=4):
def test_resnet18(self, initialize_on_cuda):
device = self._choose_proper_device(initialize_on_cuda)
sample_count = xu.getenv_as('SAMPLE_COUNT', int, defval=10)
loader = self.get_loader(device, sample_count)
loader = self.get_loader(device, sample_count, batch_size=4)
resnet18 = torchvision.models.resnet18()
resnet18.eval()
device_resnet18 = torchvision.models.resnet18()
Expand Down Expand Up @@ -399,15 +399,27 @@ def test_dynamic_shape_resnet18(self, initialize_on_cuda):
for data, _ in loader:
output = dynamo_resnet18(data)
output_cpu = resnet18(data.cpu())
self.assertTrue(
torch.allclose(output_cpu, output.cpu(), rtol=1e-05, atol=1e-05))
# TPU has some precision issues, skipping allclose check
if not _is_on_tpu():
self.assertTrue(
torch.allclose(output_cpu, output.cpu(), rtol=1e-05, atol=1e-05))

previous_extract_compile_count = met.counter_value(
'DynamoExtractCompiledGraph')

loader_new_shape = self.get_loader(device, sample_count, batch_size=2)
diffs = []
for data, _ in loader_new_shape:
output_new_shape = dynamo_resnet18(data)
output_cpu_new_shape = resnet18(data.cpu())
# # TPU has some precision issues, skipping allclose check
# if not _is_on_tpu():
# self.assertTrue(
# torch.allclose(output_cpu_new_shape, output_new_shape.cpu(), rtol=1e-05, atol=1e-05))

loader_new_shape = self.get_loader(device, sample_count, batch_size=8)
for data, _ in loader:
output = dynamo_resnet18(data)
output_cpu = resnet18(data.cpu())
self.assertTrue(
torch.allclose(output_cpu, output.cpu(), rtol=1e-05, atol=1e-05))
self.assertEqual(
met.counter_value('DynamoExtractCompiledGraph'),
previous_extract_compile_count)

def test_resnet18_lazy_vs_dynamo(self):
sample_count = xu.getenv_as('SAMPLE_COUNT', int, defval=10)
Expand Down
88 changes: 53 additions & 35 deletions torch_xla/core/dynamo_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,8 @@ def extract_graph_helper(xla_model: torch.fx.GraphModule):
dumb_return_handler, xla_args_need_update)


def extract_internal(xla_model: torch.fx.GraphModule):
def extract_internal(xla_model: torch.fx.GraphModule,
enable_dynamic_shape=False):
if dynamo_debug:
print(
'\n=================== OpenXLA Dynamo Compile Debug Begin ==================='
Expand All @@ -435,8 +436,11 @@ def extract_internal(xla_model: torch.fx.GraphModule):
# We maintain a mapping of input shapes to outputs of extract_graph_helper.
# When dynamic=True in torch.compile call, TorchDynamo will not trigger a
# new recompile. Then TorchXLA needs to figure out if these input shapes have
# been seen before.
input_shape_mappings = {}
# been seen before. The keys are tuple of input shapes, and values are a tuple
# of (xla_args_sharding_spec, args_and_out, graph_hash,
# arg_index_to_need_update_index, none_remover, graph_input_matcher,
# dumb_return_handler, xla_args_need_update).
input_shape_mappings: dict[tuple[int, ...], tuple[object, ...]] = {}

(xla_args_sharding_spec, args_and_out, graph_hash,
arg_index_to_need_update_index, none_remover, graph_input_matcher,
Expand All @@ -445,7 +449,6 @@ def extract_internal(xla_model: torch.fx.GraphModule):
'XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD', int, 5)

def optimized_mod(*args: tuple):
nonlocal input_shape_mappings
nonlocal xla_model
nonlocal xla_args_sharding_spec
nonlocal args_and_out
Expand All @@ -456,35 +459,33 @@ def optimized_mod(*args: tuple):
nonlocal dumb_return_handler
nonlocal xla_args_need_update
nonlocal skip_checking_input_sharding_threashold
nonlocal input_shape_mappings

# When dynamic=True in torch.compile call, TorchDynamo will directly
# call optimized_mod without compiling. Hence, xla_model's xla_args will
# be different from optimized_mod's args. So we manually set them here.
xla_model.xla_args = args

# See [Note: Dynamo real-time input-shape cache look-up] above.
arg_input_shapes = []
for arg in args:
if isinstance(arg, torch.Tensor):
arg_input_shapes.append(tuple(arg.shape))
arg_input_shapes = tuple(arg_input_shapes)
if arg_input_shapes in input_shape_mappings:
(xla_args_sharding_spec, args_and_out, graph_hash,
arg_index_to_need_update_index, none_remover, graph_input_matcher,
dumb_return_handler,
xla_args_need_update) = input_shape_mappings[arg_input_shapes]
else:
(xla_args_sharding_spec, args_and_out, graph_hash,
arg_index_to_need_update_index, none_remover, graph_input_matcher,
dumb_return_handler,
xla_args_need_update) = extract_graph_helper(xla_model)
input_shape_mappings[arg_input_shapes] = (xla_args_sharding_spec,
args_and_out, graph_hash,
arg_index_to_need_update_index,
none_remover,
graph_input_matcher,
dumb_return_handler,
xla_args_need_update)
if enable_dynamic_shape:
# See [Note: Dynamo real-time input-shape cache look-up] above.
arg_input_shapes = []
# When dynamic=True in torch.compile call, TorchDynamo will directly
# call optimized_mod without compiling. Hence, xla_model's xla_args will
# be different from optimized_mod's args. So we manually set them here.
xla_model.xla_args = args
for arg in args:
if isinstance(arg, torch.Tensor):
arg_input_shapes.append(tuple(arg.shape))
arg_input_shapes = tuple(arg_input_shapes)
if arg_input_shapes in input_shape_mappings:
(xla_args_sharding_spec, args_and_out, graph_hash,
arg_index_to_need_update_index, none_remover, graph_input_matcher,
dumb_return_handler,
xla_args_need_update) = input_shape_mappings[arg_input_shapes]
else:
(xla_args_sharding_spec, args_and_out, graph_hash,
arg_index_to_need_update_index, none_remover, graph_input_matcher,
dumb_return_handler,
xla_args_need_update) = extract_graph_helper(xla_model)
input_shape_mappings[arg_input_shapes] = (
xla_args_sharding_spec, args_and_out, graph_hash,
arg_index_to_need_update_index, none_remover, graph_input_matcher,
dumb_return_handler, xla_args_need_update)

original_device: torch.device = _get_input_arg_device(args)
is_cuda_args: bool = False
Expand Down Expand Up @@ -517,7 +518,7 @@ def optimized_mod(*args: tuple):
args) != xla_args_sharding_spec:
# update the xla_args with the input with new sharding and retrace
xla_model.xla_args = args
(xla_args_sharding_spec, args_and_ou_copy, graph_hash,
(xla_args_sharding_spec, args_and_out_copy, graph_hash,
arg_index_to_need_update_index, none_remover, graph_input_matcher,
dumb_return_handler,
xla_args_need_update) = extract_graph_helper(xla_model)
Expand Down Expand Up @@ -568,6 +569,11 @@ def __init__(self, module):
self._unsupported_nodes = []

def run_node(self, n: torch.fx.Node):
original_metrics = {
counter_name: metrics.counter_value(counter_name)
for counter_name in metrics.counter_names()
}

metrics.clear_counters()
result = super().run_node(n)
fallback_ops = get_fallback_ops()
Expand Down Expand Up @@ -602,6 +608,9 @@ def all_tensors_on_xla_device(value):
if not (result_is_supported and args_are_supported):
self._unsupported_nodes.append(n)

for name, value in original_metrics.items():
torch_xla._XLAC._xla_increment_counter(name, value)

return result

def get_unsupported_nodes(self):
Expand Down Expand Up @@ -644,7 +653,8 @@ def allow_cpu_device(self, node: torch.fx.Node):


def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args):
# graph extraction must happens under tracing mode
torch_xla._XLAC._xla_increment_counter('DynamoExtractCompiledGraph', 1)

with torch_xla.experimental.eager_mode_context(False):
return extract_compiled_graph_helper(xla_model, xla_args)

Expand All @@ -653,6 +663,12 @@ def extract_compiled_graph_helper(xla_model: torch.fx.GraphModule, xla_args):
if _args_on_cuda(xla_args):
xla_args = tuple(_maybe_move_tensors_to_device(xla_args, xm.xla_device()))

enable_dynamic_shape = False
for a in xla_args:
if isinstance(a, int):
enable_dynamic_shape = True
break

# Synchronize xla_args, so that each FunctionalTensorWrapper argument updates its
# value reference before actually computing it.
for a in xla_args:
Expand Down Expand Up @@ -744,7 +760,9 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
partitioned_graph.delete_submodule(node.target)
with partitioned_graph.graph.inserting_after(node):
new_node = partitioned_graph.graph.call_function(
extract_internal(fused_module), node.args, None)
extract_internal(
fused_module, enable_dynamic_shape=enable_dynamic_shape),
node.args, None)
node.replace_all_uses_with(new_node)
partitioned_graph.graph.erase_node(node)

Expand Down

0 comments on commit e836697

Please sign in to comment.