support eager mode with torch.compile (#7256)
JackCaoG committed Jun 13, 2024
1 parent 0548ec3 commit 192dec6
Showing 6 changed files with 83 additions and 6 deletions.
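
Taken together, the change lets graphs captured by torch.compile (with the openxla backend) run side by side with eagerly executed ops. A minimal usage sketch, adapted from the new test below (assumes a torch_xla build that includes this commit):

import torch
import torch_xla

# Turn on eager mode: each op now executes immediately on the XLA device.
torch_xla.experimental.eager_mode(True)

def fn(t):
  return torch.cos(torch.sin(t))

t = torch.randn(5, 5, device=torch_xla.device())  # runs eagerly
compiled_fn = torch.compile(fn, backend="openxla")
out = compiled_fn(t)  # runs as a single compiled graph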
56 changes: 56 additions & 0 deletions test/eager/test_eager_with_torch_compile.py
@@ -0,0 +1,56 @@
import unittest
import sys

import torch
import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.core.xla_model as xm


class EagerWithTorchCompileTest(unittest.TestCase):

  @classmethod
  def setUpClass(cls):
    torch_xla.experimental.eager_mode(True)

  def dummy_cos_sin(self, tensor):
    return torch.cos(torch.sin(tensor))

  def test_eager_with_compile_basic(self):
    met.clear_all()
    self.assertTrue(torch_xla.experimental.is_eager_mode())
    device = torch_xla.device()

    # this part happens eagerly
    t1 = torch.randn(5, 5, device=device)
    t1 *= 5

    t2 = self.dummy_cos_sin(t1)
    t2_compiled = torch.compile(self.dummy_cos_sin, backend="openxla")(t1)
    self.assertTrue(torch.allclose(t2, t2_compiled))
    xm.wait_device_ops()
    # We execute one compiled graph
    self.assertEqual(met.metric_data("ExecuteTime")[0], 1)
    # and many eager ops
    self.assertGreater(met.metric_data("EagerOpExecuteTime")[0], 5)

  def test_eager_execute_compiled_multiple_times(self):
    met.clear_all()
    self.assertTrue(torch_xla.experimental.is_eager_mode())
    device = torch_xla.device()
    # this part happens eagerly
    t1 = torch.randn(10, 5, device=device)
    t1.add_(0.5)
    compiled = torch.compile(self.dummy_cos_sin, backend="openxla")
    res = compiled(compiled(t1))
    self.assertTrue(
        torch.allclose(res * 0.3,
                       self.dummy_cos_sin(self.dummy_cos_sin(t1)) * 0.3))
    xm.wait_device_ops()
    self.assertEqual(met.metric_data("ExecuteTime")[0], 2)


if __name__ == '__main__':
  # exit=False so the result object is available for the explicit exit below.
  test = unittest.main(exit=False)
  sys.exit(0 if test.result.wasSuccessful() else 1)
5 changes: 3 additions & 2 deletions test/run_tests.sh
@@ -199,12 +199,13 @@ function run_xla_op_tests1 {
   run_save_tensor_hlo "$CDIR/spmd/test_spmd_graph_dump.py"
 }

 # DO NOT MODIFY
 function run_xla_op_tests2 {
   run_downcast_bf16 "$CDIR/test_data_type.py"
   run_test "$CDIR/pjrt/test_dtypes.py"
   run_test "$CDIR/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py"
-  run_test "$CDIR/test_autocast.py" # TODO(yeounoh) this is expensive on GPU
+  run_test "$CDIR/test_autocast.py"
   run_test "$CDIR/eager/test_eager_with_xla_compile.py"
+  run_test "$CDIR/eager/test_eager_with_torch_compile.py"
 }

 # All the new xla op tests should go to run_xla_op_tests3
6 changes: 6 additions & 0 deletions torch_xla/core/dynamo_bridge.py
@@ -598,6 +598,12 @@ def allow_cpu_device(self, node: torch.fx.Node):


 def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args):
+  # graph extraction must happen under tracing mode
+  with torch_xla.experimental.eager_mode_context(False):
+    return extract_compiled_graph_helper(xla_model, xla_args)
+
+
+def extract_compiled_graph_helper(xla_model: torch.fx.GraphModule, xla_args):
   if _args_on_cuda(xla_args):
     xla_args = tuple(_maybe_move_tensors_to_device(xla_args, xm.xla_device()))

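The wrapper above exists because dynamo's graph extraction walks a lazy-tensor graph: with eager mode enabled, every op would execute immediately and leave nothing to trace. A small illustration of the context manager's save/restore behavior (a sketch using only the API added in this commit):

import torch_xla

torch_xla.experimental.eager_mode(True)
with torch_xla.experimental.eager_mode_context(False):
  # Inside the context, ops are recorded lazily so a graph can be extracted.
  assert not torch_xla.experimental.is_eager_mode()
# On exit, the previously saved setting is restored.
assert torch_xla.experimental.is_eager_mode()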
2 changes: 1 addition & 1 deletion torch_xla/csrc/debug_util.cpp
@@ -317,7 +317,7 @@ void DebugUtil::analyze_graph_execution_python_frame(
                endsWith(frames[1].file, "profiler.py")) {
       ss << debug_output_prefix
          << " mark_step when exiting a profiler StepTrace region\n";
-    } else if ((frames[1].function == "extract_compiled_graph" ||
+    } else if ((frames[1].function == "extract_compiled_graph_helper" ||
                 frames[1].function == "extract_internal") &&
                endsWith(frames[1].file, "dynamo_bridge.py")) {
       ss << debug_output_prefix
3 changes: 2 additions & 1 deletion torch_xla/experimental/__init__.py
@@ -1,7 +1,8 @@
-from .eager import eager_mode, compile, is_eager_mode
+from .eager import eager_mode, compile, is_eager_mode, eager_mode_context

 __all__ = [
     "eager_mode",
     "compile",
     "is_eager_mode",
+    "eager_mode_context",
 ]
17 changes: 15 additions & 2 deletions torch_xla/experimental/eager.py
@@ -1,4 +1,5 @@
 import functools
+from contextlib import contextmanager

 import torch_xla

@@ -18,6 +19,18 @@ def is_eager_mode() -> bool:
   return torch_xla._XLAC._get_use_eager_mode()


+@contextmanager
+def eager_mode_context(enable: bool):
+  """Context manager to enable/disable the eager mode.
+  """
+  saved_eager_mode = is_eager_mode()
+  eager_mode(enable)
+  try:
+    yield saved_eager_mode
+  finally:
+    eager_mode(saved_eager_mode)
+
+
 def compile(func):
   """Compile the func with Lazy Tensor.
@@ -35,12 +48,12 @@ def wrapper(*args, **kwargs):
     try:
       # Target Function Execution
       result = func(*args, **kwargs)
+      # Sync the graph generated by the target function.
+      torch_xla.sync()
     except Exception as e:
       # Handle exceptions (if needed)
       print(f"Error in target function: {e}")
       raise  # Re-raise the exception
-    # Sync the graph generated by the target function.
-    torch_xla.sync()
     torch_xla._XLAC._set_use_eager_mode(True)

     return result
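
For context, the compile helper shown here (partially) traces the function as one lazy-tensor graph, syncs it, and then re-enables eager mode. A usage sketch (step_fn is a hypothetical example function, not part of the commit):

import torch
import torch_xla

torch_xla.experimental.eager_mode(True)

def step_fn(t):
  return torch.cos(torch.sin(t))

# step_fn executes as a single lazy-tensor graph; eager mode is restored after.
compiled_step = torch_xla.experimental.compile(step_fn)
out = compiled_step(torch.randn(5, 5, device=torch_xla.device()))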
