
Refactor torch_xla.experimental.compile and move it out of experimental (#7750)
JackCaoG committed Jul 29, 2024
1 parent 009e31a commit bdd00e5
Showing 7 changed files with 60 additions and 48 deletions.
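In short, the public entry point moves from `torch_xla.experimental.compile` to `torch_xla.compile`, with the old name kept as a deprecated shim that forwards to the new one. A minimal before/after sketch (`step_fn` is a placeholder, not taken from this diff)::

    import torch_xla
    import torch_xla.experimental

    def step_fn(x):
      return x * 2

    # before this commit (still works, but now logs a deprecation warning)
    compiled = torch_xla.experimental.compile(step_fn)

    # after this commit
    compiled = torch_xla.compile(step_fn)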
3 changes: 1 addition & 2 deletions docs/source/index.rst
@@ -28,7 +28,7 @@ torch_xla
 .. autofunction:: devices
 .. autofunction:: device_count
 .. autofunction:: sync
-.. autofunction:: step
+.. autofunction:: compile
 .. autofunction:: manual_seed

 runtime
@@ -96,7 +96,6 @@ experimental
 ----------------------------------
 .. automodule:: torch_xla.experimental
 .. autofunction:: eager_mode
-.. autofunction:: compile

 debug
 ----------------------------------
2 changes: 1 addition & 1 deletion examples/train_decoder_only_base.py
@@ -35,7 +35,7 @@ def __init__(self):
     self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.0001)
     self.loss_fn = nn.CrossEntropyLoss()
     # Compile the step fn
-    self.compiled_step_fn = torch_xla.experimental.compile(self.step_fn)
+    self.compiled_step_fn = torch_xla.compile(self.step_fn)

   def _train_update(self, step, loss, tracker, epoch):
     print(f'epoch: {epoch}, step: {step}, loss: {loss}, rate: {tracker.rate()}')
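This trainer and the ResNet one below change identically. For context, the compiled step function is invoked like any other callable; a sketch of a driver loop (names such as `loader`, `tracker`, and `epoch` are assumptions based on the surrounding trainer code, not part of this diff)::

    # Each call traces step_fn lazily, then compiles and executes one XLA graph.
    for step, (data, target) in enumerate(loader):
      loss = self.compiled_step_fn(data, target)
      if step % 10 == 0:
        self._train_update(step, loss, tracker, epoch)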
2 changes: 1 addition & 1 deletion examples/train_resnet_base.py
@@ -33,7 +33,7 @@ def __init__(self):
     self.model = torchvision.models.resnet50().to(self.device)
     self.optimizer = optim.SGD(self.model.parameters(), weight_decay=1e-4)
     self.loss_fn = nn.CrossEntropyLoss()
-    self.compiled_step_fn = torch_xla.experimental.compile(self.step_fn)
+    self.compiled_step_fn = torch_xla.compile(self.step_fn)

   def _train_update(self, step, loss, tracker, epoch):
     print(f'epoch: {epoch}, step: {step}, loss: {loss}, rate: {tracker.rate()}')
27 changes: 19 additions & 8 deletions test/eager/test_eager_with_xla_compile.py
@@ -13,6 +13,10 @@ class EagerWithXLACompileTest(unittest.TestCase):
   def setUpClass(cls):
     torch_xla.experimental.eager_mode(True)

+  @torch_xla.compile
+  def dummy_cos_sin_decored(self, tensor):
+    return torch.cos(torch.sin(tensor))
+
   def dummy_cos_sin(self, tensor):
     return torch.cos(torch.sin(tensor))

@@ -24,15 +28,22 @@ def test_eager_with_compile_basic(self):
     # this part happens eagerly
     t1 = torch.randn(5, 5, device=device)
     t1 *= 5
+    self.assertGreater(met.metric_data("EagerOpExecuteTime")[0], 1)

     t2 = self.dummy_cos_sin(t1)
-    t2_compiled = torch_xla.experimental.compile(self.dummy_cos_sin)(t1)
-    self.assertTrue(torch.allclose(t2, t2_compiled))
-    xm.wait_device_ops()
-    # We execute one compiled graph
-    self.assertEqual(met.metric_data("ExecuteTime")[0], 1)
-    # and many eager ops
-    self.assertGreater(met.metric_data("EagerOpExecuteTime")[0], 5)
+    for compiled in [
+        self.dummy_cos_sin_decored,
+        torch_xla.compile(self.dummy_cos_sin)
+    ]:
+      xm.wait_device_ops()
+      met.clear_all()
+      t2_compiled = compiled(t1)
+      self.assertTrue(torch.allclose(t2.cpu(), t2_compiled.cpu()))
+      xm.wait_device_ops()
+      # We execute one compiled graph
+      self.assertEqual(met.metric_data("ExecuteTime")[0], 1)
+      # no eager execution should happen inside this compiled graph
+      self.assertNotIn("EagerOpExecuteTime", met.metric_names())


   def test_eager_execute_compiled_multiple_times(self):
@@ -42,7 +53,7 @@ def test_eager_execute_compiled_multiple_times(self):
     # this part happens eagerly
     t1 = torch.randn(10, 5, device=device)
     t1.add_(0.5)
-    compiled = torch_xla.experimental.compile(self.dummy_cos_sin)
+    compiled = torch_xla.compile(self.dummy_cos_sin)
     res = compiled(compiled(t1))
     self.assertTrue(
         torch.allclose(res * 0.3,
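The assertions above lean on PyTorch/XLA's debug metrics to separate compiled execution from eager dispatch; the general pattern, using only APIs that appear in this test, looks like::

    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met

    met.clear_all()       # reset metrics before the region under test
    out = compiled(t1)    # run the function wrapped by torch_xla.compile
    xm.wait_device_ops()  # block until all device work has finished
    # exactly one XLA graph execution, and no eager op dispatch, took place
    assert met.metric_data("ExecuteTime")[0] == 1
    assert "EagerOpExecuteTime" not in met.metric_names()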
9 changes: 0 additions & 9 deletions test/test_devices.py
@@ -57,15 +57,6 @@ def test_step_exception(self):

     self.assertEqual(met.counter_value('MarkStep'), 2)

-  def test_step_decorator(self):
-
-    @xla.step
-    def f():
-      torch.ones((3, 3), device=xla.device())
-
-    f()
-    self.assertEqual(met.counter_value('MarkStep'), 2)
-
   # Should roughly match example given in README
   def test_trivial_model(self):

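The deleted test exercised `xla.step` as a decorator, a form this commit removes (`step()` no longer takes a function argument). The equivalent after this commit is the `torch_xla.compile` decorator, as the eager test above shows; a minimal sketch, not taken from the diff::

    import torch
    import torch_xla as xla

    @xla.compile
    def f():
      torch.ones((3, 3), device=xla.device())

    f()  # the body is traced and dispatched as one compiled graph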
32 changes: 6 additions & 26 deletions torch_xla/experimental/eager.py
@@ -2,6 +2,7 @@
 from contextlib import contextmanager

 import torch_xla
+import logging


 def eager_mode(enable: bool):
@@ -32,29 +33,8 @@ def eager_mode_context(enable: bool):


 def compile(func):
-  """Compile the func with Lazy Tensor.
-
-  Return the optimized function that takes exact same input. Compile will
-  run the target func under the tracing mode using Lazy tensor.
-  """
-
-  @functools.wraps(func)  # Keep function's name, docstring, etc.
-  def wrapper(*args, **kwargs):
-    saved_eager_mode_status = torch_xla._XLAC._get_use_eager_mode()
-    torch_xla._XLAC._set_use_eager_mode(False)
-    # clear the pending graph if any
-    torch_xla.sync()
-    try:
-      # Target Function Execution
-      result = func(*args, **kwargs)
-      # Sync the graph generated by the target function.
-      torch_xla.sync()
-    except Exception as e:
-      # Handle exceptions (if needed)
-      print(f"Error in target function: {e}")
-      raise  # Re-raise the exception
-    torch_xla._XLAC._set_use_eager_mode(saved_eager_mode_status)
-
-    return result
-
-  return wrapper
+  # can't use the deprecated wrapper at import time due to a circular dependency
+  logging.warning(
+      'torch_xla.experimental.compile is deprecated. Use torch_xla.compile instead.'
+  )
+  return torch_xla.compile(func)
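The shim keeps old call sites working while steering users to the new location: calling it just warns and forwards. A sketch of the observable behavior (`my_fn` is a placeholder)::

    import torch_xla
    import torch_xla.experimental

    compiled = torch_xla.experimental.compile(my_fn)
    # logs: WARNING: torch_xla.experimental.compile is deprecated. Use torch_xla.compile instead.
    # `compiled` is exactly the wrapper torch_xla.compile(my_fn) would return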
33 changes: 32 additions & 1 deletion torch_xla/torch_xla.py
@@ -61,13 +61,44 @@ def sync():
   xm.mark_step()


-def step(f: Optional[Callable] = None):
+def step():
   """Wraps code that should be dispatched to the runtime.

   Experimental: `xla.step` is still a work in progress. Some code that currently
   works with `xla.step` but does not follow best practices will become errors in
   future releases. See https://github.com/pytorch/xla/issues/6751 for context.
   """
+  return compile()
+
+
+def compile(f: Optional[Callable] = None):
+  """
+  Optimizes the given model/function using torch_xla's LazyTensor tracing mode.
+  PyTorch/XLA will trace the given function with the given inputs and then
+  generate graphs to represent the PyTorch operations that happen within this
+  function. This graph will be compiled by XLA and executed on the accelerator
+  (decided by the tensor's device). Eager mode will be disabled for the
+  compiled region of the function.
+
+  Args:
+    model (Callable): Module/function to optimize; if not passed, this function
+      will act as a context manager.
+
+  Example::
+
+      # usage 1
+      @torch_xla.compile()
+      def foo(x):
+        return torch.sin(x) + torch.cos(x)
+
+      def foo2(x):
+        return torch.sin(x) + torch.cos(x)
+
+      # usage 2
+      compiled_foo2 = torch_xla.compile(foo2)
+
+      # usage 3
+      with torch_xla.compile():
+        res = foo2(x)
+  """

   @contextlib.contextmanager
   def _step():
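Putting the pieces together, `torch_xla.compile` marks the boundary between eager execution and traced, compiled execution. A sketch combining it with `eager_mode`, as the tests in this commit do (the function body and shapes are illustrative only)::

    import torch
    import torch_xla
    import torch_xla.experimental

    torch_xla.experimental.eager_mode(True)  # ops outside compiled regions run eagerly

    def step_fn(x):
      return torch.sin(x) + torch.cos(x)

    compiled = torch_xla.compile(step_fn)  # eager mode is disabled inside the wrapper
    x = torch.randn(5, 5, device=torch_xla.device())  # executes eagerly
    y = compiled(x)  # traced, compiled by XLA, run on the device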
