diff --git a/test/eager/test_eager_with_torch_compile.py b/test/eager/test_eager_with_torch_compile.py
new file mode 100644
index 00000000000..e7604658aa5
--- /dev/null
+++ b/test/eager/test_eager_with_torch_compile.py
@@ -0,0 +1,56 @@
+import unittest
+import sys
+
+import torch
+import torch_xla
+import torch_xla.debug.metrics as met
+import torch_xla.core.xla_model as xm
+
+
+class EagerWithTorchCompileTest(unittest.TestCase):
+
+  @classmethod
+  def setUpClass(cls):
+    torch_xla.experimental.eager_mode(True)
+
+  def dummy_cos_sin(self, tensor):
+    return torch.cos(torch.sin(tensor))
+
+  def test_eager_with_compile_basic(self):
+    met.clear_all()
+    self.assertTrue(torch_xla.experimental.is_eager_mode())
+    device = torch_xla.device()
+
+    # this part happens eagerly
+    t1 = torch.randn(5, 5, device=device)
+    t1 *= 5
+
+    t2 = self.dummy_cos_sin(t1)
+    t2_compiled = torch.compile(self.dummy_cos_sin, backend="openxla")(t1)
+    self.assertTrue(torch.allclose(t2, t2_compiled))
+    xm.wait_device_ops()
+    # We execute one compiled graph
+    self.assertEqual(met.metric_data("ExecuteTime")[0], 1)
+    # and many eager ops
+    self.assertGreater(met.metric_data("EagerOpExecuteTime")[0], 5)
+
+
+  def test_eager_execute_compiled_multiple_times(self):
+    met.clear_all()
+    self.assertTrue(torch_xla.experimental.is_eager_mode())
+    device = torch_xla.device()
+    # this part happens eagerly
+    t1 = torch.randn(10, 5, device=device)
+    t1.add_(0.5)
+    compiled = torch.compile(self.dummy_cos_sin, backend="openxla")
+    res = compiled(compiled(t1))
+    self.assertTrue(
+        torch.allclose(res * 0.3,
+                       self.dummy_cos_sin(self.dummy_cos_sin(t1)) * 0.3))
+    xm.wait_device_ops()
+    self.assertEqual(met.metric_data("ExecuteTime")[0], 2)
+
+
+if __name__ == '__main__':
+  test = unittest.main()
+  sys.exit(0 if test.result.wasSuccessful() else 1)
diff --git a/test/run_tests.sh b/test/run_tests.sh
index b72b00e9c8e..ae5455a1142 100755
--- a/test/run_tests.sh
+++ b/test/run_tests.sh
@@ -199,12 +199,13 @@ function run_xla_op_tests1 {
   run_save_tensor_hlo "$CDIR/spmd/test_spmd_graph_dump.py"
 }
 
-# DO NOT MODIFY
 function run_xla_op_tests2 {
   run_downcast_bf16 "$CDIR/test_data_type.py"
   run_test "$CDIR/pjrt/test_dtypes.py"
   run_test "$CDIR/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py"
-  run_test "$CDIR/test_autocast.py" # TODO(yeounoh) this is expensive on GPU
+  run_test "$CDIR/test_autocast.py"
+  run_test "$CDIR/eager/test_eager_with_xla_compile.py"
+  run_test "$CDIR/eager/test_eager_with_torch_compile.py"
 }
 
 # All the new xla op tests should go to run_xla_op_tests3
diff --git a/torch_xla/core/dynamo_bridge.py b/torch_xla/core/dynamo_bridge.py
index f6d7a3b6e00..1ef3bc97d4f 100644
--- a/torch_xla/core/dynamo_bridge.py
+++ b/torch_xla/core/dynamo_bridge.py
@@ -598,6 +598,12 @@ def allow_cpu_device(self, node: torch.fx.Node):
 
 
 def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args):
+  # graph extraction must happen under tracing mode
+  with torch_xla.experimental.eager_mode_context(False):
+    return extract_compiled_graph_helper(xla_model, xla_args)
+
+
+def extract_compiled_graph_helper(xla_model: torch.fx.GraphModule, xla_args):
   if _args_on_cuda(xla_args):
     xla_args = tuple(_maybe_move_tensors_to_device(xla_args, xm.xla_device()))
 
diff --git a/torch_xla/csrc/debug_util.cpp b/torch_xla/csrc/debug_util.cpp
index c82369045bf..43c9210f6a3 100644
--- a/torch_xla/csrc/debug_util.cpp
+++ b/torch_xla/csrc/debug_util.cpp
@@ -317,7 +317,7 @@ void DebugUtil::analyze_graph_execution_python_frame(
              endsWith(frames[1].file, "profiler.py")) {
     ss << debug_output_prefix
        << " mark_step when exiting a profiler StepTrace region\n";
-  } else if ((frames[1].function == "extract_compiled_graph" ||
+  } else if ((frames[1].function == "extract_compiled_graph_helper" ||
               frames[1].function == "extract_internal") &&
              endsWith(frames[1].file, "dynamo_bridge.py")) {
     ss << debug_output_prefix
diff --git a/torch_xla/experimental/__init__.py b/torch_xla/experimental/__init__.py
index 1676e7d6349..8c0be4922e5 100644
--- a/torch_xla/experimental/__init__.py
+++ b/torch_xla/experimental/__init__.py
@@ -1,7 +1,8 @@
-from .eager import eager_mode, compile, is_eager_mode
+from .eager import eager_mode, compile, is_eager_mode, eager_mode_context
 
 __all__ = [
     "eager_mode",
     "compile",
     "is_eager_mode",
+    "eager_mode_context",
 ]
diff --git a/torch_xla/experimental/eager.py b/torch_xla/experimental/eager.py
index 085df0419c3..bfc8dd69ca6 100644
--- a/torch_xla/experimental/eager.py
+++ b/torch_xla/experimental/eager.py
@@ -1,4 +1,5 @@
 import functools
+from contextlib import contextmanager
 
 import torch_xla
 
@@ -18,6 +19,18 @@ def is_eager_mode() -> bool:
   return torch_xla._XLAC._get_use_eager_mode()
 
 
+@contextmanager
+def eager_mode_context(enable: bool):
+  """Context manager to enable/disable the eager mode.
+  """
+  saved_eager_mode = is_eager_mode()
+  eager_mode(enable)
+  try:
+    yield saved_eager_mode
+  finally:
+    eager_mode(saved_eager_mode)
+
+
 def compile(func):
   """Compile the func with Lazy Tensor.
 
@@ -35,12 +48,12 @@ def wrapper(*args, **kwargs):
     try:
       # Target Function Execution
       result = func(*args, **kwargs)
+      # Sync the graph generated by the target function.
+      torch_xla.sync()
     except Exception as e:
       # Handle exceptions (if needed)
       print(f"Error in target function: {e}")
       raise  # Re-raise the exception
-    # Sync the graph generated by the target function.
-    torch_xla.sync()
     torch_xla._XLAC._set_use_eager_mode(True)
 
     return result
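
Usage sketch (not part of the patch): a minimal example of how the pieces above are intended to compose. Eager mode is enabled globally, individual ops execute eagerly, and a torch.compile region with the "openxla" backend is traced as a single graph because extract_compiled_graph temporarily disables eager mode via eager_mode_context(False). Only APIs that appear in this diff are used; the function and variable names are illustrative.

import torch
import torch_xla
import torch_xla.core.xla_model as xm

# Turn on eager mode; ops below execute one by one (counted by EagerOpExecuteTime).
torch_xla.experimental.eager_mode(True)
device = torch_xla.device()


def cos_sin(t):
  return torch.cos(torch.sin(t))


t = torch.randn(5, 5, device=device)  # runs eagerly
compiled = torch.compile(cos_sin, backend="openxla")
out = compiled(t)  # traced and executed as one compiled graph (counted by ExecuteTime)
xm.wait_device_ops()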