From bdd00e54043ff9990b63d3e671b28f9be3de5c48 Mon Sep 17 00:00:00 2001
From: JackCaoG <59073027+JackCaoG@users.noreply.github.com>
Date: Mon, 29 Jul 2024 12:48:28 -0700
Subject: [PATCH] Refactor torch_xla.experimental.compile and move it out of
 experimental (#7750)

---
 docs/source/index.rst                     |  3 +--
 examples/train_decoder_only_base.py       |  2 +-
 examples/train_resnet_base.py             |  2 +-
 test/eager/test_eager_with_xla_compile.py | 27 +++++++++++++------
 test/test_devices.py                      |  9 -------
 torch_xla/experimental/eager.py           | 32 +++++-----------------
 torch_xla/torch_xla.py                    | 33 ++++++++++++++++++++++-
 7 files changed, 60 insertions(+), 48 deletions(-)

diff --git a/docs/source/index.rst b/docs/source/index.rst
index 5ebf1af2f80..4652a2115b8 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -28,7 +28,7 @@ torch_xla
 .. autofunction:: devices
 .. autofunction:: device_count
 .. autofunction:: sync
-.. autofunction:: step
+.. autofunction:: compile
 .. autofunction:: manual_seed
 
 runtime
@@ -96,7 +96,6 @@ experimental
 ----------------------------------
 .. automodule:: torch_xla.experimental
 .. autofunction:: eager_mode
-.. autofunction:: compile
 
 debug
 ----------------------------------
diff --git a/examples/train_decoder_only_base.py b/examples/train_decoder_only_base.py
index 9892d48172e..908f67c44ac 100644
--- a/examples/train_decoder_only_base.py
+++ b/examples/train_decoder_only_base.py
@@ -35,7 +35,7 @@ def __init__(self):
     self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.0001)
     self.loss_fn = nn.CrossEntropyLoss()
     # Compile the step fn
-    self.compiled_step_fn = torch_xla.experimental.compile(self.step_fn)
+    self.compiled_step_fn = torch_xla.compile(self.step_fn)
 
   def _train_update(self, step, loss, tracker, epoch):
     print(f'epoch: {epoch}, step: {step}, loss: {loss}, rate: {tracker.rate()}')
diff --git a/examples/train_resnet_base.py b/examples/train_resnet_base.py
index 1f86e6fe17d..b47d821bd5a 100644
--- a/examples/train_resnet_base.py
+++ b/examples/train_resnet_base.py
@@ -33,7 +33,7 @@ def __init__(self):
     self.model = torchvision.models.resnet50().to(self.device)
     self.optimizer = optim.SGD(self.model.parameters(), weight_decay=1e-4)
     self.loss_fn = nn.CrossEntropyLoss()
-    self.compiled_step_fn = torch_xla.experimental.compile(self.step_fn)
+    self.compiled_step_fn = torch_xla.compile(self.step_fn)
 
   def _train_update(self, step, loss, tracker, epoch):
     print(f'epoch: {epoch}, step: {step}, loss: {loss}, rate: {tracker.rate()}')
diff --git a/test/eager/test_eager_with_xla_compile.py b/test/eager/test_eager_with_xla_compile.py
index 0509316a0ef..281da29c155 100644
--- a/test/eager/test_eager_with_xla_compile.py
+++ b/test/eager/test_eager_with_xla_compile.py
@@ -13,6 +13,10 @@ class EagerWithXLACompileTest(unittest.TestCase):
   def setUpClass(cls):
     torch_xla.experimental.eager_mode(True)
 
+  @torch_xla.compile
+  def dummy_cos_sin_decored(self, tensor):
+    return torch.cos(torch.sin(tensor))
+
   def dummy_cos_sin(self, tensor):
     return torch.cos(torch.sin(tensor))
 
@@ -24,15 +28,22 @@ def test_eager_with_compile_basic(self):
     # this part happens eagerly
     t1 = torch.randn(5, 5, device=device)
     t1 *= 5
+    self.assertGreater(met.metric_data("EagerOpExecuteTime")[0], 1)
 
     t2 = self.dummy_cos_sin(t1)
-    t2_compiled = torch_xla.experimental.compile(self.dummy_cos_sin)(t1)
-    self.assertTrue(torch.allclose(t2, t2_compiled))
-    xm.wait_device_ops()
-    # We execute one compiled graph
-    self.assertEqual(met.metric_data("ExecuteTime")[0], 1)
-    # and many eager ops
-    self.assertGreater(met.metric_data("EagerOpExecuteTime")[0], 5)
+    for compiled in [
+        self.dummy_cos_sin_decored,
+        torch_xla.compile(self.dummy_cos_sin)
+    ]:
+      xm.wait_device_ops()
+      met.clear_all()
+      t2_compiled = compiled(t1)
+      self.assertTrue(torch.allclose(t2.cpu(), t2_compiled.cpu()))
+      xm.wait_device_ops()
+      # We execute one compiled graph
+      self.assertEqual(met.metric_data("ExecuteTime")[0], 1)
+      # no eager execution should happen inside this compiled graph
+      self.assertNotIn("EagerOpExecuteTime", met.metric_names())
 
   def test_eager_execute_compiled_multiple_times(self):
 
@@ -42,7 +53,7 @@ def test_eager_execute_compiled_multiple_times(self):
     # this part happens eagerly
     t1 = torch.randn(10, 5, device=device)
     t1.add_(0.5)
-    compiled = torch_xla.experimental.compile(self.dummy_cos_sin)
+    compiled = torch_xla.compile(self.dummy_cos_sin)
     res = compiled(compiled(t1))
     self.assertTrue(
         torch.allclose(res * 0.3,
diff --git a/test/test_devices.py b/test/test_devices.py
index 0b10fc056cb..259f0046623 100644
--- a/test/test_devices.py
+++ b/test/test_devices.py
@@ -57,15 +57,6 @@ def test_step_exception(self):
 
     self.assertEqual(met.counter_value('MarkStep'), 2)
 
-  def test_step_decorator(self):
-
-    @xla.step
-    def f():
-      torch.ones((3, 3), device=xla.device())
-
-    f()
-    self.assertEqual(met.counter_value('MarkStep'), 2)
-
   # Should roughly match example given in README
   def test_trivial_model(self):
diff --git a/torch_xla/experimental/eager.py b/torch_xla/experimental/eager.py
index 7a5d9c88ec6..df714298451 100644
--- a/torch_xla/experimental/eager.py
+++ b/torch_xla/experimental/eager.py
@@ -2,6 +2,7 @@
 from contextlib import contextmanager
 
 import torch_xla
+import logging
 
 
 def eager_mode(enable: bool):
@@ -32,29 +33,8 @@ def eager_mode_context(enable: bool):
 
 
 def compile(func):
-  """Compile the func with Lazy Tensor.
-
-  Return the optimized function that takes exact same input. Compile will
-  run the target func under the tracing mode using Lazy tensor.
-  """
-
-  @functools.wraps(func)  # Keep function's name, docstring, etc.
-  def wrapper(*args, **kwargs):
-    saved_eager_mode_status = torch_xla._XLAC._get_use_eager_mode()
-    torch_xla._XLAC._set_use_eager_mode(False)
-    # clear the pending graph if any
-    torch_xla.sync()
-    try:
-      # Target Function Execution
-      result = func(*args, **kwargs)
-      # Sync the graph generated by the target function.
-      torch_xla.sync()
-    except Exception as e:
-      # Handle exceptions (if needed)
-      print(f"Error in target function: {e}")
-      raise  # Re-raise the exception
-    torch_xla._XLAC._set_use_eager_mode(saved_eager_mode_status)
-
-    return result
-
-  return wrapper
+  # can't use deprecated wrapper at import time due to circular dependency
+  logging.warning(
+      'torch_xla.experimental.compile is deprecated. Use torch_xla.compile instead.'
+  )
+  return torch_xla.compile(func)
diff --git a/torch_xla/torch_xla.py b/torch_xla/torch_xla.py
index c07e051dff7..2aca101a7e6 100644
--- a/torch_xla/torch_xla.py
+++ b/torch_xla/torch_xla.py
@@ -61,13 +61,44 @@ def sync():
   xm.mark_step()
 
 
-def step(f: Optional[Callable] = None):
+def step():
   """Wraps code that should be dispatched to the runtime.
 
   Experimental: `xla.step` is still a work in progress. Some code that currently
   works with `xla.step` but does not follow best practices will become errors in
   future releases. See https://github.com/pytorch/xla/issues/6751 for context.
   """
+  return compile()
+
+
+def compile(f: Optional[Callable] = None):
+  """
+  Optimizes the given model/function using torch_xla's LazyTensor tracing mode.
+
+  PyTorch/XLA will trace the given function with the given inputs and then generate
+  a graph to represent the PyTorch operations that happen within this function. This
+  graph will be compiled by XLA and executed on the accelerator (determined by the
+  tensor's device). Eager mode will be disabled for the compiled region of the function.
+
+  Args:
+      f (Callable): Module/function to optimize. If not passed, this function will
+        act as a context manager.
+
+  Example::
+
+      # usage 1
+      @torch_xla.compile()
+      def foo(x):
+        return torch.sin(x) + torch.cos(x)
+
+      def foo2(x):
+        return torch.sin(x) + torch.cos(x)
+      # usage 2
+      compiled_foo2 = torch_xla.compile(foo2)
+
+      # usage 3
+      with torch_xla.compile():
+        res = foo2(x)
+  """
 
   @contextlib.contextmanager
   def _step():
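As a quick illustration of the API this patch promotes out of experimental, below is a minimal sketch of using torch_xla.compile as a decorator, assuming an XLA device is available; the model, shapes, and function names are illustrative and not taken from the patch.

import torch
import torch_xla

device = torch_xla.device()
# Hypothetical toy model for illustration only.
model = torch.nn.Linear(8, 4).to(device)


@torch_xla.compile  # bare-decorator usage, mirroring the updated test above
def step_fn(x):
  return model(x).sum()


x = torch.randn(16, 8, device=device)
loss = step_fn(x)  # traced lazily, compiled by XLA, then executed on the device
print(loss.cpu())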