Commit e8334d4

Fix regressing unit tests
wonjoolee95 committed Jul 18, 2024
1 parent bc32aa9 commit e8334d4
Showing 2 changed files with 9 additions and 11 deletions.
test/dynamo/test_dynamo.py (2 changes: 0 additions & 2 deletions)

@@ -362,7 +362,6 @@ def test_resnet18(self, initialize_on_cuda):
     met.clear_all()
     dynamo_resnet18 = torch.compile(device_resnet18, backend='openxla')
     for data, _ in loader:
-      # dynamo_resnet18 = torch.compile(device_resnet18, backend='openxla')
       output = dynamo_resnet18(data)
       output_cpu = resnet18(data.cpu())
       self.assertTrue(
@@ -408,7 +407,6 @@ def test_dynamic_shape_resnet18(self, initialize_on_cuda):
                      'DynamoExtractCompiledGraph')
 
     loader_new_shape = self.get_loader(device, sample_count, batch_size=2)
-    diffs = []
     for data, _ in loader_new_shape:
       output_new_shape = dynamo_resnet18(data)
       output_cpu_new_shape = resnet18(data.cpu())
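Note on the test changes: both deletions are dead-code cleanup. The commented-out torch.compile call inside the loop duplicated the compile step that already runs once before the loop, and `diffs` was assigned but never read. The compile-once, reuse-per-batch pattern the test relies on looks roughly like this (a minimal sketch; the dummy loader and input shape are illustrative, not taken from the test):

    import torch
    import torchvision

    # Compile once, outside the data loop, so the backend traces the model a
    # single time and every iteration reuses the cached compiled graph.
    # backend='openxla' assumes torch_xla is installed.
    resnet18 = torchvision.models.resnet18().eval()
    dynamo_resnet18 = torch.compile(resnet18, backend='openxla')

    loader = [torch.randn(4, 3, 224, 224) for _ in range(2)]  # illustrative batches
    for data in loader:
      output = dynamo_resnet18(data)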
torch_xla/core/dynamo_bridge.py (18 changes: 9 additions & 9 deletions)

@@ -435,8 +435,9 @@ def extract_internal(xla_model: torch.fx.GraphModule):
   # We maintain a mapping of input shapes to outputs of extract_graph_helper.
   # When dynamic=True in torch.compile call, TorchDynamo will not trigger a
   # new recompile. Then TorchXLA needs to figure out if these input shapes have
-  # been seen before. The keys are tuple of input shapes, and values are a tuple
-  # of (xla_args_sharding_spec, args_and_out, graph_hash,
+  # been seen before.
+  # Keys: tuple of input shapes
+  # Values: tuple of (xla_args_sharding_spec, args_and_out, graph_hash,
   # arg_index_to_need_update_index, none_remover, graph_input_matcher,
   # dumb_return_handler, xla_args_need_update).
   input_shape_mappings: dict[tuple[int, ...], tuple[object, ...]] = {}
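The reworked comment spells out the cache layout: keys are tuples of input shapes, values are the full result tuple produced by extract_graph_helper for those shapes. A minimal sketch of the lookup pattern the comment describes (the key construction and the stub helper are assumptions for illustration, not the bridge's exact code):

    import torch

    input_shape_mappings: dict[tuple[int, ...], tuple[object, ...]] = {}

    def _compile_stub(xla_args):
      # Stand-in for the real extract_graph_helper, which traces the graph and
      # returns the (xla_args_sharding_spec, args_and_out, graph_hash, ...) tuple.
      return ('sharding_spec', 'args_and_out', 'graph_hash')

    def lookup_or_compile(xla_args):
      # Key: every dimension of every argument, flattened into one tuple.
      key = tuple(dim for arg in xla_args for dim in arg.shape)
      if key not in input_shape_mappings:
        # New shapes: compile once and cache the result for future calls.
        input_shape_mappings[key] = _compile_stub(xla_args)
      return input_shape_mappings[key]

    cached = lookup_or_compile([torch.randn(2, 3)])  # a second (2, 3) call hits the cache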
@@ -557,7 +558,6 @@ def optimized_mod(*args: tuple):
     print(
         '=================== OpenXLA Dynamo Compile Debug End =====================\n'
     )
-
   return optimized_mod


@@ -568,10 +568,9 @@ def __init__(self, module):
     self._unsupported_nodes = []
 
   def run_node(self, n: torch.fx.Node):
-    original_metrics = {
-        counter_name: metrics.counter_value(counter_name)
-        for counter_name in metrics.counter_names()
-    }
+    # We need to restore this metric count later, so save it in a separate variable
+    dynamo_extract_graph_helper_metric_count = metrics.counter_value(
+        'DynamoExtractCompiledGraph')
 
     metrics.clear_counters()
     result = super().run_node(n)
@@ -607,8 +606,9 @@ def all_tensors_on_xla_device(value):
     if not (result_is_supported and args_are_supported):
       self._unsupported_nodes.append(n)
 
-    for name, value in original_metrics.items():
-      torch_xla._XLAC._xla_increment_counter(name, value)
+    # Restore this metric counter
+    torch_xla._XLAC._xla_increment_counter(
+        'DynamoExtractCompiledGraph', dynamo_extract_graph_helper_metric_count)
 
     return result

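Taken together, the two run_node hunks replace the save-all/restore-all counter snapshot with a targeted save and restore of the single DynamoExtractCompiledGraph counter. The shape of the pattern, using the same torch_xla metrics calls the diff itself uses (a sketch, not the bridge's actual method; the None guard is an added assumption for the case where the counter has never been incremented):

    import torch_xla
    import torch_xla.debug.metrics as metrics

    def run_with_counter_preserved(run_node, n):
      # Save the one counter this pass must not perturb.
      saved = metrics.counter_value('DynamoExtractCompiledGraph') or 0
      # Clear everything so the per-node fallback check starts from zero.
      metrics.clear_counters()
      result = run_node(n)
      # Re-increment the cleared counter back to its saved value.
      torch_xla._XLAC._xla_increment_counter('DynamoExtractCompiledGraph', saved)
      return result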
