Add support for dynamic shape in dynamo #7676
Changes from 8 commits
First changed file (Dynamo test updates):
@@ -333,8 +333,8 @@ def test_simple_model_with_different_input_shape(self, initialize_on_cuda):
                 rtol=1e-05,
                 atol=1e-05))

-  def get_loader(self, device, sample_count):
-    batch_size = xu.getenv_as('BATCH_SIZE', int, defval=4)
+  def get_loader(self, device, sample_count, batch_size=4):
+    batch_size = xu.getenv_as('BATCH_SIZE', int, defval=batch_size)
     loader = xu.SampleGenerator(
         data=(torch.randn(batch_size, 3, 224, 224, device=device),
               torch.zeros(batch_size, dtype=torch.int64, device=device)),
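Aside (not part of the diff): the BATCH_SIZE environment variable still wins over the new batch_size parameter, because the parameter only feeds the defval fallback of xu.getenv_as. A minimal illustration, assuming the test file's usual import of torch_xla.utils.utils as xu:

import os
import torch_xla.utils.utils as xu

os.environ['BATCH_SIZE'] = '8'
assert xu.getenv_as('BATCH_SIZE', int, defval=4) == 8  # env var overrides defval
del os.environ['BATCH_SIZE']
assert xu.getenv_as('BATCH_SIZE', int, defval=4) == 4  # defval used when unset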
@@ -349,7 +349,7 @@ def get_loader(self, device, sample_count):
   def test_resnet18(self, initialize_on_cuda):
     device = self._choose_proper_device(initialize_on_cuda)
     sample_count = xu.getenv_as('SAMPLE_COUNT', int, defval=10)
-    loader = self.get_loader(device, sample_count)
+    loader = self.get_loader(device, sample_count, batch_size=4)
     resnet18 = torchvision.models.resnet18()
     resnet18.eval()
     device_resnet18 = torchvision.models.resnet18()
@@ -360,8 +360,8 @@ def test_resnet18(self, initialize_on_cuda):
     xm.mark_step()
     xm.wait_device_ops()
     met.clear_all()
+    dynamo_resnet18 = torch.compile(device_resnet18, backend='openxla')
     for data, _ in loader:
-      dynamo_resnet18 = torch.compile(device_resnet18, backend='openxla')
       output = dynamo_resnet18(data)
       output_cpu = resnet18(data.cpu())
       self.assertTrue(
@@ -374,6 +374,54 @@ def test_resnet18(self, initialize_on_cuda):
     self.assertEqual(
         met.metric_data('RunCachedGraphOutputData')[0], sample_count)

+  @parameterized.parameters(
+      True,
+      False,
+  )
+  def test_dynamic_shape_resnet18(self, initialize_on_cuda):
+    device = self._choose_proper_device(initialize_on_cuda)
+    sample_count = xu.getenv_as('SAMPLE_COUNT', int, defval=10)
+    loader = self.get_loader(device, sample_count, batch_size=4)
+    resnet18 = torchvision.models.resnet18()
+    resnet18.eval()
+    device_resnet18 = torchvision.models.resnet18()
+    device_resnet18.load_state_dict(resnet18.state_dict())
+    device_resnet18.to(device)
+    device_resnet18.eval()
+    # materalize the fake data for test purpose
+    xm.mark_step()
+    xm.wait_device_ops()
+    met.clear_all()
+    dynamo_resnet18 = torch.compile(
+        device_resnet18, backend='openxla', dynamic=True)
+    for data, _ in loader:
+      output = dynamo_resnet18(data)
+      output_cpu = resnet18(data.cpu())
+      # TPU has some precision issues, skipping allclose check
+      if not _is_on_tpu():
+        self.assertTrue(
+            torch.allclose(output_cpu, output.cpu(), rtol=1e-05, atol=1e-05))
+
+    previous_extract_compile_count = met.counter_value(
+        'DynamoExtractCompiledGraph')
+
+    loader_new_shape = self.get_loader(device, sample_count, batch_size=2)
+    for data, _ in loader_new_shape:
+      output_new_shape = dynamo_resnet18(data)
+      output_cpu_new_shape = resnet18(data.cpu())
+      # # TPU has some precision issues, skipping allclose check
+      if not _is_on_tpu():
+        self.assertTrue(
+            torch.allclose(
+                output_cpu_new_shape,
+                output_new_shape.cpu(),
+                rtol=1e-05,
+                atol=1e-05))
+
Review comments on this test:
"maybe also check the …"
"also can you make another test to test the case of …? Want to make sure we don't forget the old shapes that are cached." (A sketch of such a follow-up check appears right after this hunk.)
+    self.assertEqual(
+        met.counter_value('DynamoExtractCompiledGraph'),
+        previous_extract_compile_count)
+
   def test_resnet18_lazy_vs_dynamo(self):
     sample_count = xu.getenv_as('SAMPLE_COUNT', int, defval=10)
     device = torch_xla.device()
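Sketch of the extra coverage the reviewer asks for above (not part of this diff): after the batch_size=2 loop, switch back to a batch_size=4 loader and assert that no additional graph extraction happens, i.e. the first cached shape is still remembered. The snippet assumes the same fixtures as test_dynamic_shape_resnet18 (device, sample_count, resnet18, dynamo_resnet18, met, _is_on_tpu):

    # Hypothetical continuation of test_dynamic_shape_resnet18
    saved_extract_compile_count = met.counter_value('DynamoExtractCompiledGraph')
    loader_old_shape = self.get_loader(device, sample_count, batch_size=4)
    for data, _ in loader_old_shape:
      output_old_shape = dynamo_resnet18(data)
      # TPU has some precision issues, skipping allclose check
      if not _is_on_tpu():
        self.assertTrue(
            torch.allclose(
                resnet18(data.cpu()),
                output_old_shape.cpu(),
                rtol=1e-05,
                atol=1e-05))
    # Going back to an already-seen shape must not retrace the graph.
    self.assertEqual(
        met.counter_value('DynamoExtractCompiledGraph'),
        saved_extract_compile_count)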
Second changed file (Dynamo/XLA bridge changes):
@@ -171,6 +171,15 @@ def _maybe_move_tensors_to_device(tensors: tuple,
   return tuple(moved_tensors)


+# Given a list of args, returns a tuple of all the args' shapes.
+def _get_arg_input_shapes(args):
+  arg_input_shapes = []
+  for arg in args:
+    if isinstance(arg, torch.Tensor):
+      arg_input_shapes.append(tuple(arg.shape))
+  return tuple(arg_input_shapes)
+
+
 class Deduper:

   def __init__(self):
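Illustration (not part of the diff) of what the new helper produces, assuming _get_arg_input_shapes from the hunk above is in scope: non-tensor args are skipped, and the remaining tensor shapes become the cache key.

import torch

args = (torch.randn(4, 3, 224, 224), torch.zeros(4, dtype=torch.int64), 224)
assert _get_arg_input_shapes(args) == ((4, 3, 224, 224), (4,))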
@@ -430,9 +439,32 @@ def extract_internal(xla_model: torch.fx.GraphModule):
       print(torch_xla._XLAC._get_xla_tensor_debug_info(xla_arg))
   # Don't reset the scope as we might be under some profiler trace scope.
   xm.mark_step(reset_scope=False)

+  # [Note: Dynamo real-time input-shape cache look-up]
+  # We maintain a mapping of input shapes to outputs of extract_graph_helper.
+  # When dynamic=True in torch.compile call, TorchDynamo will not trigger a
+  # new recompile. Then TorchXLA needs to figure out if these input shapes have
+  # been seen before.
+  # Keys: tuple of input shapes
+  # Values: tuple of (xla_args_sharding_spec, args_and_out, graph_hash,
+  #   arg_index_to_need_update_index, none_remover, graph_input_matcher,
+  #   dumb_return_handler, xla_args_need_update).
+  input_shape_mappings: dict[tuple[int, ...], tuple[object, ...]] = {}
+
   (xla_args_sharding_spec, args_and_out, graph_hash,
    arg_index_to_need_update_index, none_remover, graph_input_matcher,
    dumb_return_handler, xla_args_need_update) = extract_graph_helper(xla_model)

+  # If dynamic=True, we cache the result to avoid recompilation.
+  if not torch._dynamo.config.assume_static_by_default:
+    arg_input_shapes = _get_arg_input_shapes(xla_model.xla_args)
+    input_shape_mappings[arg_input_shapes] = (xla_args_sharding_spec,
+                                              args_and_out, graph_hash,
+                                              arg_index_to_need_update_index,
+                                              none_remover, graph_input_matcher,
+                                              dumb_return_handler,
+                                              xla_args_need_update)
+
   skip_checking_input_sharding_threashold = xu.getenv_as(
       'XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD', int, 5)
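A self-contained toy version of the look-up scheme described in the note above; every name here is hypothetical rather than taken from the diff. The cache key is the tuple of tensor-argument shapes, a hit reuses the previously extracted graph state, and a miss traces once and records the result.

import torch

def arg_shapes(args):
  # Same idea as _get_arg_input_shapes: only tensor shapes form the key.
  return tuple(tuple(a.shape) for a in args if isinstance(a, torch.Tensor))

shape_cache = {}
trace_count = 0

def lookup_or_trace(args):
  global trace_count
  key = arg_shapes(args)
  if key not in shape_cache:
    trace_count += 1  # stand-in for calling extract_graph_helper
    shape_cache[key] = 'graph-state-for-%s' % (key,)
  return shape_cache[key]

lookup_or_trace((torch.randn(4, 3, 224, 224),))  # miss: traces
lookup_or_trace((torch.randn(4, 3, 224, 224),))  # hit: same shapes, no retrace
lookup_or_trace((torch.randn(2, 3, 224, 224),))  # miss: new shape, traces again
assert trace_count == 2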
@@ -447,6 +479,27 @@ def optimized_mod(*args: tuple):
     nonlocal dumb_return_handler
     nonlocal xla_args_need_update
     nonlocal skip_checking_input_sharding_threashold
+    nonlocal input_shape_mappings
+
+    # See [Note: Dynamo real-time input-shape cache look-up] above.
+    if not torch._dynamo.config.assume_static_by_default:
+      xla_model.xla_args = args
+      arg_input_shapes = _get_arg_input_shapes(xla_model.xla_args)
+      if arg_input_shapes in input_shape_mappings:
+        (xla_args_sharding_spec, args_and_out, graph_hash,
+         arg_index_to_need_update_index, none_remover, graph_input_matcher,
+         dumb_return_handler,
+         xla_args_need_update) = input_shape_mappings[arg_input_shapes]
+      else:
+        xm.mark_step()
+        (xla_args_sharding_spec, args_and_out, graph_hash,
+         arg_index_to_need_update_index, none_remover, graph_input_matcher,
+         dumb_return_handler,
+         xla_args_need_update) = extract_graph_helper(xla_model)
+        input_shape_mappings[arg_input_shapes] = (
+            xla_args_sharding_spec, args_and_out, graph_hash,
+            arg_index_to_need_update_index, none_remover, graph_input_matcher,
+            dumb_return_handler, xla_args_need_update)
Review thread on lines +489 to +492:
"I think you don't need this here."
"IIUC, we actually need this here. And we actually don't need this same logic in …"
"ok, then you will run into the same old problem, right? In this case you do the compile, but you do not cache the … when …"
"you should just do the caching (input_shape_mappings[arg_input_shapes] = …) inside the …"
"let me fix this too.."

     original_device: torch.device = _get_input_arg_device(args)
     is_cuda_args: bool = False
@@ -463,6 +516,7 @@ def optimized_mod(*args: tuple):
            torch_xla._XLAC._check_tensor_need_materialization(
                [a for a in args if isinstance(a, torch.Tensor)])) if x
     ]
+
     if len(input_tensors_to_sync) > 0:
       torch_xla._XLAC._xla_increment_counter('DynamoSyncInputExecuteTime', 1)
       torch_xla._XLAC._xla_sync_multi(
@@ -479,7 +533,7 @@ def optimized_mod(*args: tuple):
          args) != xla_args_sharding_spec:
       # update the xla_args with the input with new sharding and retrace
       xla_model.xla_args = args
-      (xla_args_sharding_spec, args_and_ou_copy, graph_hash,
+      (xla_args_sharding_spec, args_and_out_copy, graph_hash,
        arg_index_to_need_update_index, none_remover, graph_input_matcher,
        dumb_return_handler,
        xla_args_need_update) = extract_graph_helper(xla_model)
@@ -529,6 +583,10 @@ def __init__(self, module):
     self._unsupported_nodes = []

   def run_node(self, n: torch.fx.Node):
+    # We need to restore this metric count later, so save it in a separate variable
+    dynamo_extract_graph_helper_metric_count = metrics.counter_value(
+        'DynamoExtractCompiledGraph')
Review thread on lines +577 to +578:
"will …"
"It's hard to see from the documentation. However, when I try comparing metrics before/after …"
"ok, then I am confused what this …"
"This code (…"
"Ah I see, I can fix it later. I think the right thing to do is to define a region where the counter does not get incremented." (A hypothetical sketch of such a counter-preserving region follows the next hunk.)
+
     metrics.clear_counters()
     result = super().run_node(n)
     fallback_ops = get_fallback_ops()
@@ -563,6 +621,10 @@ def all_tensors_on_xla_device(value):
     if not (result_is_supported and args_are_supported):
       self._unsupported_nodes.append(n)

+    # Restore this metric counter
+    torch_xla._XLAC._xla_increment_counter(
+        'DynamoExtractCompiledGraph', dynamo_extract_graph_helper_metric_count)
+
     return result

   def get_unsupported_nodes(self):
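The closing suggestion in the thread above ("define a region where the counter does not get incremented") could be expressed as a context manager, roughly as sketched below. This is a hypothetical illustration, not code from the PR; it only relies on metrics.counter_value and torch_xla._XLAC._xla_increment_counter, which the hunk above already uses.

import contextlib

import torch_xla
import torch_xla.debug.metrics as metrics

@contextlib.contextmanager
def preserve_counter(name):
  # Snapshot the counter on entry and re-apply the difference on exit, so a
  # metrics.clear_counters() inside the region cannot wipe it out.
  saved = metrics.counter_value(name) or 0
  try:
    yield
  finally:
    current = metrics.counter_value(name) or 0
    if current < saved:
      torch_xla._XLAC._xla_increment_counter(name, saved - current)

# run_node could then wrap its body:
#   with preserve_counter('DynamoExtractCompiledGraph'):
#     metrics.clear_counters()
#     result = super().run_node(n)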
@@ -605,7 +667,8 @@ def allow_cpu_device(self, node: torch.fx.Node):


 def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args):
   # graph extraction must happens under tracing mode
+  torch_xla._XLAC._xla_increment_counter('DynamoExtractCompiledGraph', 1)
   with torch_xla.experimental.eager_mode_context(False):
     return extract_compiled_graph_helper(xla_model, xla_args)

@@ -636,10 +699,14 @@ def extract_compiled_graph_helper(xla_model: torch.fx.GraphModule, xla_args):
   for name, buffer in xla_model.named_buffers():
     if "self" in name:
       self_args.append(buffer)
-  all_xla_args = list(xla_args) + self_args
+
+  # When dynamic_shape=True, TorchDynamo will pass us shapes as integers. We want to deal with the tensors only for now, so keep them separately.
+  all_xla_args = [
+      xla_arg for xla_arg in xla_args if isinstance(xla_arg, torch.Tensor)
+  ] + self_args

   for xla_arg in xla_args:
-    if xla_arg.device.type != 'xla':
+    if isinstance(xla_arg, torch.Tensor) and xla_arg.device.type != 'xla':
       warnings.warn(
           "Found tensor with shape " + str(xla_arg.size()) + " on " +
           str(xla_arg.device) +
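Illustration (hypothetical values, not from the diff) of why the tensor filter matters: with dynamic=True, Dynamo can hand the backend plain integers (symbolic sizes) alongside the tensors, and those must not be treated as graph inputs or probed for .device.

import torch

xla_args = (4, torch.randn(4, 3, 224, 224))  # an int batch size plus a tensor
all_xla_args = [a for a in xla_args if isinstance(a, torch.Tensor)]
assert len(all_xla_args) == 1
assert all(isinstance(a, torch.Tensor) for a in all_xla_args)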
Review comment: "remove one #" (the doubled "# #" comment in the new test). Reply: "Updated"