
Add support for dynamic shape in dynamo #7676

Merged: 12 commits, Jul 23, 2024
56 changes: 52 additions & 4 deletions test/dynamo/test_dynamo.py
@@ -333,8 +333,8 @@ def test_simple_model_with_different_input_shape(self, initialize_on_cuda):
rtol=1e-05,
atol=1e-05))

def get_loader(self, device, sample_count):
batch_size = xu.getenv_as('BATCH_SIZE', int, defval=4)
def get_loader(self, device, sample_count, batch_size=4):
batch_size = xu.getenv_as('BATCH_SIZE', int, defval=batch_size)
loader = xu.SampleGenerator(
data=(torch.randn(batch_size, 3, 224, 224, device=device),
torch.zeros(batch_size, dtype=torch.int64, device=device)),
@@ -349,7 +349,7 @@ def get_loader(self, device, sample_count):
def test_resnet18(self, initialize_on_cuda):
device = self._choose_proper_device(initialize_on_cuda)
sample_count = xu.getenv_as('SAMPLE_COUNT', int, defval=10)
loader = self.get_loader(device, sample_count)
loader = self.get_loader(device, sample_count, batch_size=4)
resnet18 = torchvision.models.resnet18()
resnet18.eval()
device_resnet18 = torchvision.models.resnet18()
@@ -360,8 +360,8 @@ def test_resnet18(self, initialize_on_cuda):
xm.mark_step()
xm.wait_device_ops()
met.clear_all()
dynamo_resnet18 = torch.compile(device_resnet18, backend='openxla')
for data, _ in loader:
dynamo_resnet18 = torch.compile(device_resnet18, backend='openxla')
output = dynamo_resnet18(data)
output_cpu = resnet18(data.cpu())
self.assertTrue(
@@ -374,6 +374,54 @@ def test_resnet18(self, initialize_on_cuda):
self.assertEqual(
met.metric_data('RunCachedGraphOutputData')[0], sample_count)

@parameterized.parameters(
True,
False,
)
def test_dynamic_shape_resnet18(self, initialize_on_cuda):
device = self._choose_proper_device(initialize_on_cuda)
sample_count = xu.getenv_as('SAMPLE_COUNT', int, defval=10)
loader = self.get_loader(device, sample_count, batch_size=4)
resnet18 = torchvision.models.resnet18()
resnet18.eval()
device_resnet18 = torchvision.models.resnet18()
device_resnet18.load_state_dict(resnet18.state_dict())
device_resnet18.to(device)
device_resnet18.eval()
# materialize the fake data for test purposes
xm.mark_step()
xm.wait_device_ops()
met.clear_all()
dynamo_resnet18 = torch.compile(
device_resnet18, backend='openxla', dynamic=True)
for data, _ in loader:
output = dynamo_resnet18(data)
output_cpu = resnet18(data.cpu())
# TPU has some precision issues, skipping allclose check
if not _is_on_tpu():
self.assertTrue(
torch.allclose(output_cpu, output.cpu(), rtol=1e-05, atol=1e-05))

previous_extract_compile_count = met.counter_value(
'DynamoExtractCompiledGraph')

loader_new_shape = self.get_loader(device, sample_count, batch_size=2)
for data, _ in loader_new_shape:
output_new_shape = dynamo_resnet18(data)
output_cpu_new_shape = resnet18(data.cpu())
# # TPU has some precision issues, skipping allclose check
Collaborator: remove one #

Collaborator Author: Updated

if not _is_on_tpu():
self.assertTrue(
torch.allclose(
output_cpu_new_shape,
output_new_shape.cpu(),
rtol=1e-05,
atol=1e-05))

Collaborator: maybe also check the CompileTime and ExecuteTime here

Collaborator: also, can you make another test for the case of

fn(shape_a)
fn(shape_b)
fn(shape_c)
fn(shape_a)

We want to make sure we don't forget the old shapes that are cached.

self.assertEqual(
met.counter_value('DynamoExtractCompiledGraph'),
previous_extract_compile_count)
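A minimal sketch of the shape-revisit test requested in the review thread above (hypothetical, not part of this diff; the batch sizes stand in for shape_a/shape_b/shape_c, and the assertion mirrors the DynamoExtractCompiledGraph check used in test_dynamic_shape_resnet18):

```python
@parameterized.parameters(
    True,
    False,
)
def test_dynamic_shape_resnet18_revisit_cached_shapes(self, initialize_on_cuda):
  device = self._choose_proper_device(initialize_on_cuda)
  sample_count = xu.getenv_as('SAMPLE_COUNT', int, defval=10)
  resnet18 = torchvision.models.resnet18().to(device)
  resnet18.eval()
  xm.mark_step()
  xm.wait_device_ops()
  met.clear_all()

  dynamo_resnet18 = torch.compile(resnet18, backend='openxla', dynamic=True)

  # shape_a: warm up, then record how many times Dynamo extracted a graph.
  for data, _ in self.get_loader(device, sample_count, batch_size=4):
    dynamo_resnet18(data)
  previous_extract_compile_count = met.counter_value(
      'DynamoExtractCompiledGraph')

  # shape_b, shape_c, then shape_a again; the last pass should be served
  # from the shape cache rather than forgetting the old entry.
  for batch_size in (2, 8, 4):
    for data, _ in self.get_loader(device, sample_count, batch_size=batch_size):
      dynamo_resnet18(data)

  self.assertEqual(
      met.counter_value('DynamoExtractCompiledGraph'),
      previous_extract_compile_count)
```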

def test_resnet18_lazy_vs_dynamo(self):
sample_count = xu.getenv_as('SAMPLE_COUNT', int, defval=10)
device = torch_xla.device()
75 changes: 71 additions & 4 deletions torch_xla/core/dynamo_bridge.py
@@ -171,6 +171,15 @@ def _maybe_move_tensors_to_device(tensors: tuple,
return tuple(moved_tensors)


# Given a list of args, returns a tuple of the shapes of the tensor args.
def _get_arg_input_shapes(args):
arg_input_shapes = []
for arg in args:
if isinstance(arg, torch.Tensor):
arg_input_shapes.append(tuple(arg.shape))
return tuple(arg_input_shapes)
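For reference, a tiny standalone illustration of the cache key this helper computes (the example args here are hypothetical; the inline expression mirrors the helper's tensor-only filtering):

```python
import torch

# Hypothetical mix of tensor and non-tensor args; with dynamic shapes enabled,
# Dynamo may pass plain Python ints alongside the tensors.
args = (torch.randn(2, 3, 224, 224), torch.zeros(2, dtype=torch.int64), 2)

# Only tensor shapes contribute to the key; non-tensor args are skipped.
key = tuple(tuple(a.shape) for a in args if isinstance(a, torch.Tensor))
print(key)  # ((2, 3, 224, 224), (2,))
```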


class Deduper:

def __init__(self):
@@ -430,9 +439,32 @@ def extract_internal(xla_model: torch.fx.GraphModule):
print(torch_xla._XLAC._get_xla_tensor_debug_info(xla_arg))
# Don't reset the scope as we might be under some profiler trace scope.
xm.mark_step(reset_scope=False)

# [Note: Dynamo real-time input-shape cache look-up]
# We maintain a mapping of input shapes to outputs of extract_graph_helper.
# When dynamic=True is passed to torch.compile, TorchDynamo will not trigger a
# new recompile. TorchXLA then needs to figure out whether these input shapes
# have been seen before.
# Keys: tuple of input shapes
# Values: tuple of (xla_args_sharding_spec, args_and_out, graph_hash,
# arg_index_to_need_update_index, none_remover, graph_input_matcher,
# dumb_return_handler, xla_args_need_update).
input_shape_mappings: dict[tuple[int, ...], tuple[object, ...]] = {}
Collaborator: Use typing.Dict and typing.Tuple, otherwise the Python 3.8 CI in upstream will fail.

Collaborator Author: Updated


(xla_args_sharding_spec, args_and_out, graph_hash,
arg_index_to_need_update_index, none_remover, graph_input_matcher,
dumb_return_handler, xla_args_need_update) = extract_graph_helper(xla_model)

# If dynamic=True, we cache the result to avoid recompilation.
if not torch._dynamo.config.assume_static_by_default:
arg_input_shapes = _get_arg_input_shapes(xla_model.xla_args)
input_shape_mappings[arg_input_shapes] = (xla_args_sharding_spec,
args_and_out, graph_hash,
arg_index_to_need_update_index,
none_remover, graph_input_matcher,
dumb_return_handler,
xla_args_need_update)

skip_checking_input_sharding_threashold = xu.getenv_as(
'XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD', int, 5)

@@ -447,6 +479,27 @@ def optimized_mod(*args: tuple):
nonlocal dumb_return_handler
nonlocal xla_args_need_update
nonlocal skip_checking_input_sharding_threashold
nonlocal input_shape_mappings

# See [Note: Dynamo real-time input-shape cache look-up] above.
if not torch._dynamo.config.assume_static_by_default:
xla_model.xla_args = args
arg_input_shapes = _get_arg_input_shapes(xla_model.xla_args)
if arg_input_shapes in input_shape_mappings:
(xla_args_sharding_spec, args_and_out, graph_hash,
arg_index_to_need_update_index, none_remover, graph_input_matcher,
dumb_return_handler,
xla_args_need_update) = input_shape_mappings[arg_input_shapes]
else:
xm.mark_step()
(xla_args_sharding_spec, args_and_out, graph_hash,
arg_index_to_need_update_index, none_remover, graph_input_matcher,
dumb_return_handler,
xla_args_need_update) = extract_graph_helper(xla_model)
input_shape_mappings[arg_input_shapes] = (
xla_args_sharding_spec, args_and_out, graph_hash,
arg_index_to_need_update_index, none_remover, graph_input_matcher,
dumb_return_handler, xla_args_need_update)
Comment on lines +489 to +492

Collaborator: I think you don't need this here.

Collaborator Author: IIUC, we actually need this here, and we don't need the same logic in extract_internal above (removed in the newest commit). The reason is that when dynamic=True, only optimized_mod is called; other functions (including extract_internal) are not called.

Collaborator: ok, then you will run into the same old problem, right? The first time the flow is

extract_graph_helper -> optimized_mod

In this case you do the compile, but you do not cache into input_shape_mappings, so when optimized_mod is called the first time you will need to call extract_graph_helper again, which is wasteful.

Collaborator: you should just do the caching (input_shape_mappings[arg_input_shapes] = ...) inside extract_graph_helper.

Collaborator: let me fix this too..
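To make the suggestion concrete, here is a small self-contained toy model of the flow being debated (the names are stand-ins, not the real dynamo_bridge API): with the cache write inside the helper, the very first compile, triggered from extract_internal, already lands in the cache, so optimized_mod's first call does not retrace.

```python
from typing import Dict, Tuple

ShapeKey = Tuple[Tuple[int, ...], ...]

_cache: Dict[ShapeKey, str] = {}
compile_calls = 0


def extract_graph_helper(shapes: ShapeKey) -> str:
  # Stand-in for tracing + compilation; the cache write lives here, so every
  # caller benefits from it.
  global compile_calls
  compile_calls += 1
  compiled = f"graph for {shapes}"
  _cache[shapes] = compiled
  return compiled


def extract_internal(shapes: ShapeKey) -> str:
  # First compile at trace time; the result is already cached by the helper.
  return extract_graph_helper(shapes)


def optimized_mod(shapes: ShapeKey) -> str:
  # Runtime path: look up by input shapes, recompile only for unseen shapes.
  if shapes in _cache:
    return _cache[shapes]
  return extract_graph_helper(shapes)


extract_internal(((4, 3, 224, 224),))
optimized_mod(((4, 3, 224, 224),))
assert compile_calls == 1  # no wasteful second compile for the same shape
```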


original_device: torch.device = _get_input_arg_device(args)
is_cuda_args: bool = False
@@ -463,6 +516,7 @@ def optimized_mod(*args: tuple):
torch_xla._XLAC._check_tensor_need_materialization(
[a for a in args if isinstance(a, torch.Tensor)])) if x
]

if len(input_tensors_to_sync) > 0:
torch_xla._XLAC._xla_increment_counter('DynamoSyncInputExecuteTime', 1)
torch_xla._XLAC._xla_sync_multi(
@@ -479,7 +533,7 @@ def optimized_mod(*args: tuple):
args) != xla_args_sharding_spec:
# update the xla_args with the input with new sharding and retrace
xla_model.xla_args = args
(xla_args_sharding_spec, args_and_ou_copy, graph_hash,
(xla_args_sharding_spec, args_and_out_copy, graph_hash,
arg_index_to_need_update_index, none_remover, graph_input_matcher,
dumb_return_handler,
xla_args_need_update) = extract_graph_helper(xla_model)
@@ -529,6 +583,10 @@ def __init__(self, module):
self._unsupported_nodes = []

def run_node(self, n: torch.fx.Node):
# We need to restore this metric count later, so save it in a separate variable
dynamo_extract_graph_helper_metric_count = metrics.counter_value(
'DynamoExtractCompiledGraph')
Comment on lines +577 to +578

Collaborator: will run_node call extract_compiled_graph too?

Collaborator Author: It's hard to tell from the documentation. However, when I compare metrics before/after run_node, from what I can see it's not calling extract_compiled_graph.

Collaborator: ok, then I am confused about what this dynamo_extract_graph_helper_metric_count is doing here.

Collaborator Author: This code (run_node) is executed when we're fetching the fallback ops, and in the code below we clear our metric counters via metrics.clear_counters(). So we need a way to restore this counter, so we can verify extract_compiled_graph only gets called once in our unit tests.

Collaborator: Ah I see, I can fix it later. I think the right thing to do is to define a region where the counter is not incremented.
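A minimal sketch of the "region where the counter is not incremented" idea from this thread (hypothetical helper; it snapshots and restores the counter using the same metrics and _xla_increment_counter calls the diff already uses):

```python
import contextlib

import torch_xla
import torch_xla.debug.metrics as metrics


@contextlib.contextmanager
def preserve_counter(name):
  # Snapshot the counter before the region runs.
  saved = metrics.counter_value(name) or 0
  try:
    yield
  finally:
    # If the region cleared or lowered the counter (e.g. via
    # metrics.clear_counters()), bump it back up to the saved value.
    current = metrics.counter_value(name) or 0
    if current < saved:
      torch_xla._XLAC._xla_increment_counter(name, saved - current)


# Rough usage inside run_node:
#   with preserve_counter('DynamoExtractCompiledGraph'):
#     metrics.clear_counters()
#     result = super().run_node(n)
```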


metrics.clear_counters()
result = super().run_node(n)
fallback_ops = get_fallback_ops()
@@ -563,6 +621,10 @@ def all_tensors_on_xla_device(value):
if not (result_is_supported and args_are_supported):
self._unsupported_nodes.append(n)

# Restore this metric counter
torch_xla._XLAC._xla_increment_counter(
'DynamoExtractCompiledGraph', dynamo_extract_graph_helper_metric_count)

return result

def get_unsupported_nodes(self):
@@ -605,7 +667,8 @@ def allow_cpu_device(self, node: torch.fx.Node):


def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args):
# graph extraction must happen under tracing mode
torch_xla._XLAC._xla_increment_counter('DynamoExtractCompiledGraph', 1)

with torch_xla.experimental.eager_mode_context(False):
return extract_compiled_graph_helper(xla_model, xla_args)

@@ -636,10 +699,14 @@ def extract_compiled_graph_helper(xla_model: torch.fx.GraphModule, xla_args):
for name, buffer in xla_model.named_buffers():
if "self" in name:
self_args.append(buffer)
all_xla_args = list(xla_args) + self_args

# When dynamic_shape=True, TorchDynamo will pass us shapes as integers. We want to deal with the tensors only for now, so keep them separately.
all_xla_args = [
xla_arg for xla_arg in xla_args if isinstance(xla_arg, torch.Tensor)
] + self_args

for xla_arg in xla_args:
if xla_arg.device.type != 'xla':
if isinstance(xla_arg, torch.Tensor) and xla_arg.device.type != 'xla':
warnings.warn(
"Found tensor with shape " + str(xla_arg.size()) + " on " +
str(xla_arg.device) +