
Add support for dynamic shape in dynamo #7676

Merged: 12 commits into master on Jul 23, 2024

Conversation

wonjoolee95
Collaborator

Fixes #7614


TODO

  • Remove debugging code and add comments
  • Add unit tests
  • Handle error case when TorchDynamo passes us int types

@wonjoolee95
Collaborator Author

With the current changes, the following code generates correct results without recompiling the graph:

    ###
    # torch.compile dynamic shape ON
    torch._dynamo.config.automatic_dynamic_shapes = True
    compiled_fn = torch.compile(fn, backend='openxla', dynamic=True)
    a = torch.randn(3, 4, device=device)
    b = torch.ones(4, device=device)
    ret = compiled_fn(a, b)
    xm.mark_step()
    print(f'[Testing] {ret=}')
    print(f'--------------------')

    c = torch.randn(4, 5, device=device)
    d = torch.ones(5, device=device)
    ret2 = compiled_fn(c, d)
    xm.mark_step()
    print(f'[Testing] {ret2=}')
    print(f'--------------------')

As for next steps, I'll clean up some code and add some unit tests.
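
Continuing the snippet above, one way to check the "no recompile" claim is the DynamoExtractCompiledGraph counter discussed later in this thread. A minimal sketch (the expected count of 1 is my reading of the claim, not a measured result):

    import torch_xla.debug.metrics as met

    # After running compiled_fn with both input shapes, the Dynamo bridge
    # should have extracted a compiled graph only once.
    count = met.counter_value('DynamoExtractCompiledGraph')
    print(f'[Testing] DynamoExtractCompiledGraph={count}')  # expect 1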

@wonjoolee95 wonjoolee95 changed the title [WIP] Add support for dynamic shape in dynamo Add support for dynamic shape in dynamo Jul 15, 2024
@wonjoolee95 wonjoolee95 marked this pull request as ready for review July 15, 2024 21:23
@JackCaoG
Collaborator

Seems like a bunch of tests failed, and a lot of them are real failures. @wonjoolee95, let me know if you need help debugging them.

Comment on lines 415 to 416
# self.assertTrue(
# torch.allclose(output_cpu_new_shape, output_new_shape.cpu(), rtol=1e-05, atol=1e-05))
Collaborator Author

This part is odd. When I run these tests, the allclose fails because in some iterations of the data loader with this new_shape, the differences are as large as 0.2.

Collaborator Author

This is fixed with the explicit mark_step call within the else statement under torch._dynamo.config.assume_static_by_default.

for data, _ in loader_new_shape:
  output_new_shape = dynamo_resnet18(data)
  output_cpu_new_shape = resnet18(data.cpu())
  # # TPU has some precision issues, skipping allclose check
Collaborator

remove one #

Collaborator Author

Updated

output_new_shape.cpu(),
rtol=1e-05,
atol=1e-05))

Collaborator

Maybe also check the CompileTime and ExecuteTime metrics here.
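
A rough sketch of what such a check could look like inside the test (the expected counts are illustrative, not values from this PR):

    import torch_xla.debug.metrics as met

    # met.metric_data(name) returns a tuple whose first element is the
    # number of recorded samples for that metric.
    self.assertIsNotNone(met.metric_data('CompileTime'))
    self.assertIsNotNone(met.metric_data('ExecuteTime'))
    self.assertEqual(met.metric_data('CompileTime')[0], 1)  # illustrative
    self.assertGreaterEqual(met.metric_data('ExecuteTime')[0], 1)  # illustrative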

Collaborator

Also, can you make another test covering the case of

fn(shape_a)
fn(shape_b)
fn(shape_c)
fn(shape_a)

We want to make sure we don't forget the old shapes that are cached (see the sketch below).
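
A hedged sketch of such a test; fn, the shapes, and the metric check are illustrative and assume the openxla backend with dynamic=True as in the example earlier in this thread:

    import torch
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met

    def fn(a, b):
      return a + b

    device = xm.xla_device()
    compiled_fn = torch.compile(fn, backend='openxla', dynamic=True)

    met.clear_all()
    # shape_a, shape_b, shape_c
    for dim in (4, 5, 6):
      compiled_fn(torch.randn(3, dim, device=device),
                  torch.ones(dim, device=device))
      xm.mark_step()
    compiles_after_new_shapes = met.metric_data('CompileTime')[0]

    # shape_a again: the cached graph should be reused, so the XLA compile
    # count must not grow.
    compiled_fn(torch.randn(3, 4, device=device),
                torch.ones(4, device=device))
    xm.mark_step()
    assert met.metric_data('CompileTime')[0] == compiles_after_new_shapes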

# Values: tuple of (xla_args_sharding_spec, args_and_out, graph_hash,
# arg_index_to_need_update_index, none_remover, graph_input_matcher,
# dumb_return_handler, xla_args_need_update).
input_shape_mappings: dict[tuple[int, ...], tuple[object, ...]] = {}
Collaborator

Use typing.Dict and typing.Tuple, otherwise the Python 3.8 CI upstream will fail.
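
For reference, a sketch of the Python 3.8-compatible annotation:

    from typing import Dict, Tuple

    input_shape_mappings: Dict[Tuple[int, ...], Tuple[object, ...]] = {}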

Collaborator Author

Updated

Comment on lines +499 to +502
input_shape_mappings[arg_input_shapes] = (
    xla_args_sharding_spec, args_and_out, graph_hash,
    arg_index_to_need_update_index, none_remover, graph_input_matcher,
    dumb_return_handler, xla_args_need_update)
Collaborator

I think you don't need this here

Collaborator Author

IIUC, we actually need this here, and we actually don't need this same logic in extract_internal above (removed in the newest commit). The reason is that when dynamic=True, only optimized_mod is called; other functions (including extract_internal) are not called.

Collaborator

OK, then you will run into the same old problem, right? The first time:

extract_graph_helper -> optimized_mod

In this case you do the compile, but you do not cache into input_shape_mappings.

When optimized_mod is called the first time, you will need to call extract_graph_helper again, which is wasteful.

Collaborator

You should just do the caching (input_shape_mappings[arg_input_shapes] = ...) inside extract_graph_helper.
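
A simplified sketch of the suggested structure; extract_graph_helper, optimized_mod, input_shape_mappings, and arg_input_shapes come from the discussion above, while _trace_and_compile and _run_compiled are hypothetical stand-ins for the real compilation and execution paths:

    import torch

    def _trace_and_compile(xla_model, xla_args):  # hypothetical stand-in
      # The real helper returns graph_hash, input matchers, return handlers, etc.
      return xla_model

    def _run_compiled(compiled, xla_args):  # hypothetical stand-in
      return compiled(*xla_args)

    def extract_graph_helper(xla_model, xla_args, input_shape_mappings):
      arg_input_shapes = tuple(
          tuple(arg.shape) for arg in xla_args if isinstance(arg, torch.Tensor))
      compiled = _trace_and_compile(xla_model, xla_args)
      # Cache here, so the compile done on the very first call is reused the
      # next time this input shape shows up.
      input_shape_mappings[arg_input_shapes] = compiled
      return compiled

    def optimized_mod(xla_args, xla_model, input_shape_mappings):
      arg_input_shapes = tuple(
          tuple(arg.shape) for arg in xla_args if isinstance(arg, torch.Tensor))
      compiled = input_shape_mappings.get(arg_input_shapes)
      if compiled is None:
        # New input shape: extract_graph_helper both compiles and caches,
        # so no second compile of the same shape is needed.
        compiled = extract_graph_helper(xla_model, xla_args, input_shape_mappings)
      return _run_compiled(compiled, xla_args)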

Collaborator

Let me fix this too.

Comment on lines +587 to +588
dynamo_extract_graph_helper_metric_count = metrics.counter_value(
    'DynamoExtractCompiledGraph')
Collaborator

will run_node call extract_compiled_graph too?

Collaborator Author

It's hard to tell from the documentation. However, when I compare metrics before/after run_node, from what I can see it's not calling extract_compiled_graph.

Collaborator

OK, then I am confused about what this dynamo_extract_graph_helper_metric_count is doing here.

Collaborator Author

This code (run_node) is executed when we're fetching the fallback ops, and in the code below we clear our metric counters via metrics.clear_counters(). So we need a way to restore this counter so that we can verify extract_compiled_graph only gets called once in our unit tests.

Collaborator

Ah, I see. I can fix it later. I think the right thing to do is to define a region where the counter is not incremented.
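
A minimal sketch of that idea (not the actual fix in this PR): a helper a test could wrap around the fallback-op discovery to assert a given counter is untouched; only the public counter_value API from torch_xla.debug.metrics is assumed:

    import contextlib
    import torch_xla.debug.metrics as met

    @contextlib.contextmanager
    def counter_unchanged(name):
      # Snapshot the counter, run the region, then verify the region neither
      # incremented nor cleared it.
      before = met.counter_value(name) or 0
      yield
      after = met.counter_value(name) or 0
      assert after == before, (
          f'{name} changed from {before} to {after} inside the region')

    # Usage (illustrative):
    # with counter_unchanged('DynamoExtractCompiledGraph'):
    #   ...run the fallback-op discovery...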

@JackCaoG
Collaborator

@ysiraichi FYI

@wonjoolee95
Collaborator Author

The PR should be in a reasonable state; now I'm just seeing 2 failures in the GPU tests that require torch CUDA:

#1: DynamoInferenceBasicTest.test_dynamic_shape_resnet180 (True):
Input tensor is not an XLA tensor: CUDAFloatType

#2: DynamoInferenceBasicTest.test_resnet180 (True)
  File "/__w/xla/xla/pytorch/xla/test/dynamo/test_dynamo.py", line 370, in test_resnet18
    self.assertEqual(met.metric_data('CompileTime')[0], 1)
TypeError: 'NoneType' object is not subscriptable

For the first error, the stack trace points to:

pytree.tree_map_only(
    torch.Tensor,
    lambda xla_arg: torch_xla._XLAC._xla_get_tensor_id(xla_arg),
    xla_args))

It seems like we may want to do an additional isinstance(arg, torch.Tensor) check here.
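
A hedged sketch of one possible guard (illustrative, not necessarily the fix that landed): only ask _XLAC for a tensor id when the argument is a tensor that actually lives on an XLA device, so a stray CUDA tensor is skipped instead of raising:

    import torch
    import torch_xla
    from torch.utils import _pytree as pytree

    def _xla_tensor_id_or_none(arg):
      # CUDA (or CPU) tensors would trigger "Input tensor is not an XLA
      # tensor", so only query ids for tensors on an XLA device.
      if isinstance(arg, torch.Tensor) and arg.device.type == 'xla':
        return torch_xla._XLAC._xla_get_tensor_id(arg)
      return None

    arg_ids = pytree.tree_map_only(torch.Tensor, _xla_tensor_id_or_none,
                                   xla_args)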

@JackCaoG
Collaborator

I will pick this up and try to fix the errors today.

@JackCaoG JackCaoG added the tpuci label Jul 22, 2024
@JackCaoG
Collaborator

@alanwaketan There are a few places I want to fix, but maybe we should just merge this PR to unblock Woosuk now. I am also running some benchmarks.

Collaborator

@alanwaketan alanwaketan left a comment

Approved to unblock.

@JackCaoG JackCaoG merged commit 2b6b461 into master Jul 23, 2024
23 checks passed
@JackCaoG JackCaoG deleted the wonjoo/dynamo-dynamic-shape branch July 23, 2024 01:18
Labels
dynamism (Dynamic Shape Features), tpuci
Development

Successfully merging this pull request may close these issues.

Dynamo persistent cache real-time look-up
4 participants