
Use torch_xla.experimental.compile for all examples #7642

Merged Jul 25, 2024 (2 commits, JackCaoG/update_example_explicy_sync into master)

Conversation

@JackCaoG (Collaborator) commented Jul 9, 2024

Here is my plan, @will-cromar, let me know what you think:

  1. Ask users to use torch_xla.experimental.compile (and torch_xla.step()) to wrap their step fn, which will mark_step the outside region for them. (See the usage sketch below.)
  2. Ask users to turn on eager mode later.

Pretty much we can stage the effort for this UX migration. For this to work I need to make torch_xla.step handle eager mode; I will do that in a follow-up PR.
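
For illustration, a minimal sketch of what step 1 could look like from a user's perspective. The model, optimizer, and step_fn below are hypothetical stand-ins, not code from this PR; only torch_xla.experimental.compile is the API under discussion.

import torch
import torch.nn as nn
import torch.optim as optim
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = nn.Linear(784, 10).to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

def step_fn(data, target):
  optimizer.zero_grad()
  loss = loss_fn(model(data), target)
  loss.backward()
  optimizer.step()
  return loss

# compile wraps step_fn and issues the surrounding mark_step/sync
# for the user, so the training loop never calls mark_step directly.
compiled_step_fn = torch_xla.experimental.compile(step_fn)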

@JackCaoG (Collaborator, Author) commented Jul 9, 2024

@will-cromar I think this is ready for review

@@ -33,23 +33,27 @@ def __init__(self):
     self.model = torchvision.models.resnet50().to(self.device)
     self.optimizer = optim.SGD(self.model.parameters(), weight_decay=1e-4)
     self.loss_fn = nn.CrossEntropyLoss()
+    self.compiled_step_fn = torch_xla.experimental.compile(self.step_fn)
will-cromar (Collaborator) commented on self.compiled_step_fn:

I'm still not sure I understand why compile is separate from torch_xla.step. The main difference that I can see is that you flip eager off and then back on, which I think you can just add to torch_xla.step. contextlib.contextmanager already handles the rest of the plumbing you have there, such as wrapping and exception handling.

@functools.wraps(func)  # Keep function's name, docstring, etc.
def wrapper(*args, **kwargs):
  # compile should only be called when eager mode is enabled
  assert torch_xla._XLAC._get_use_eager_mode() == True
  torch_xla._XLAC._set_use_eager_mode(False)
  # clear the pending graph if any
  torch_xla.sync()
  try:
    # Target Function Execution
    result = func(*args, **kwargs)
    # Sync the graph generated by the target function.
    torch_xla.sync()
  except Exception as e:
    # Handle exceptions (if needed)
    print(f"Error in target function: {e}")
    raise  # Re-raise the exception
  torch_xla._XLAC._set_use_eager_mode(True)
  return result

vs

@contextlib.contextmanager
def step():
  """Wraps code that should be dispatched to the runtime.

  Experimental: `xla.step` is still a work in progress. Some code that currently
  works with `xla.step` but does not follow best practices will become errors in
  future releases. See https://github.com/pytorch/xla/issues/6751 for context.
  """
  # Clear pending operations
  xm.mark_step()
  try:
    yield
  finally:
    xm.mark_step()

https://docs.python.org/3/library/contextlib.html#contextlib.contextmanager

https://docs.python.org/3/library/contextlib.html#contextlib.ContextDecorator

step already has a cautionary note in the docstring that we will be changing the public API over time.
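
To make the suggestion concrete, here is a hedged sketch of what folding the eager flip into step might look like, reusing the internals quoted above. The exact ordering of the toggles is my assumption, not code from this PR.

import contextlib

import torch_xla
import torch_xla.core.xla_model as xm

@contextlib.contextmanager
def step():
  # Flip eager off for the duration of the step, as compile does.
  torch_xla._XLAC._set_use_eager_mode(False)
  # Clear pending operations
  xm.mark_step()
  try:
    yield
  finally:
    # Sync the graph traced inside the step, then restore eager
    # mode even if the body raised.
    xm.mark_step()
    torch_xla._XLAC._set_use_eager_mode(True)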

JackCaoG (Collaborator, Author) replied:

I guess I don't know how you turn a context manager into a wrapper around a function. All I wanted was an API that looks like

torch_xla.experimental.compile()

I think it can share the same implementation as step.

JackCaoG (Collaborator, Author) replied:

oh ok, I guess you meant to use @functools.wraps(func) on top of step.. let me try that..

JackCaoG (Collaborator, Author) replied:

ok, they are actually different: step is a context manager that doesn't take a fn and arguments as input. It also doesn't return the fn's output.

will-cromar (Collaborator) replied:

> ok, they are actually different: step is a context manager that doesn't take a fn and arguments as input. It also doesn't return the fn's output.

Context managers and decorators are ~interchangeable in Python thanks to ContextDecorator:

@torch_xla.experimental.compile
def f(x):
  return 2 * x

@torch_xla.step()
def g(x):
  return 2 * x

The only semantic difference is that you have to instantiate the context manager first (step()). It appropriately wraps the function's inputs, outputs, docstring, signature, types, etc. too. If we want a "no arguments" form of step, then I can make that work. Even something silly like this works:

stepwithnoargs = torch_xla.step() # not recommended for readability lol

@stepwithnoargs
def g(x):
  """docstring"""
  return 2 * x

Python has almost exactly what we want in the standard library already, so let's use it.
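
To make the interchangeability concrete, a self-contained sketch (pure stdlib, with print standing in for mark_step) showing that a @contextlib.contextmanager function used as a decorator preserves the wrapped function's return value and metadata:

import contextlib

@contextlib.contextmanager
def step():
  print("enter")  # stands in for the leading mark_step()
  try:
    yield
  finally:
    print("exit")  # stands in for the trailing mark_step()

@step()
def g(x):
  """docstring"""
  return 2 * x

print(g(3))       # prints enter, exit, then 6: the return value survives
print(g.__doc__)  # prints "docstring": functools.wraps metadata survives too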

JackCaoG (Collaborator, Author) replied:

so many new Python hacks haha, let me read into this a bit

JackCaoG (Collaborator, Author) replied:

I understand how decorators and context managers are interchangeable now. I am still not sure how to implement the API I want, new_fn = torch_xla.experimental.compile(func), with the decorator; this is a bit mind-twisting. Are you going to make step take an optional argument and branch from there? I will update step to handle eager mode and you can probably refactor from there.
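
For what it's worth, one possible answer, sketched under the assumption that step() returns a ContextDecorator as described above; this is an illustration, not code from this PR. A decorator is just a function applied to a function, so the call form falls out directly:

import torch_xla

def compile(func):
  # Applying the instantiated decorator by hand gives the
  # new_fn = compile(func) form; it is equivalent to writing
  # @torch_xla.step() above func's definition.
  return torch_xla.step()(func)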

@JackCaoG force-pushed the JackCaoG/update_example_explicy_sync branch from 953f95d to a007c69 on July 17, 2024
@JackCaoG merged commit 735dd43 into master on Jul 25, 2024
23 checks passed