Skip to content

Commit

Permalink
use the eager_mode_context for some of the tests (#7276)
Browse files Browse the repository at this point in the history
  • Loading branch information
JackCaoG committed Jun 18, 2024
1 parent e8601ea commit 1492eba
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 27 deletions.
19 changes: 9 additions & 10 deletions test/debug_tool/test_pt_xla_debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,15 @@ def setUpClass(cls):
open(cls.debug_file_name, 'w').close()

def test_eager_mark_step(self):
torch_xla.experimental.eager_mode(True)
device = xm.xla_device()
t1 = torch.randn(5, 9, device=device)
xm.mark_step()
with open(self.debug_file_name, 'rb') as f:
lines = f.readlines()
# We expect PT_XLA_BUDEG not to output anything under the eager mode
self.assertEqual(len(lines), 0)
torch_xla.experimental.eager_mode(False)
open(self.debug_file_name, 'w').close()
with torch_xla.experimental.eager_mode_context(True):
device = xm.xla_device()
t1 = torch.randn(5, 9, device=device)
xm.mark_step()
with open(self.debug_file_name, 'rb') as f:
lines = f.readlines()
# We expect PT_XLA_BUDEG not to output anything under the eager mode
self.assertEqual(len(lines), 0)
open(self.debug_file_name, 'w').close()

def test_user_mark_step(self):
device = xm.xla_device()
Expand Down
33 changes: 16 additions & 17 deletions test/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,23 +53,22 @@ def test_tracing_time_metrics(self):
self.assertGreater(met.metric_data('LazyTracing')[0], 1)

def test_eager_metrics(self):
torch_xla.experimental.eager_mode(True)
xla_device = xm.xla_device()
met.clear_all()
t1 = torch.tensor(156, device=xla_device)
t2 = t1 + 100
xm.wait_device_ops()
self.assertIn('EagerOpCompileTime', met.metric_names())
# one for cosntant, one for add
self.assertEqual(met.metric_data('EagerOpCompileTime')[0], 2)
self.assertIn('EagerOpExecuteTime', met.metric_names())
# one for add
self.assertEqual(met.metric_data('EagerOpExecuteTime')[0], 2)
# mark_step should be a no-op
xm.mark_step()
self.assertNotIn('CompileTime', met.metric_names())
self.assertNotIn('ExecuteTime', met.metric_names())
torch_xla.experimental.eager_mode(False)
with torch_xla.experimental.eager_mode_context(True):
xla_device = xm.xla_device()
met.clear_all()
t1 = torch.tensor(156, device=xla_device)
t2 = t1 + 100
xm.wait_device_ops()
self.assertIn('EagerOpCompileTime', met.metric_names())
# one for cosntant, one for add
self.assertEqual(met.metric_data('EagerOpCompileTime')[0], 2)
self.assertIn('EagerOpExecuteTime', met.metric_names())
# one for add
self.assertEqual(met.metric_data('EagerOpExecuteTime')[0], 2)
# mark_step should be a no-op
xm.mark_step()
self.assertNotIn('CompileTime', met.metric_names())
self.assertNotIn('ExecuteTime', met.metric_names())

def test_short_metrics_report_default_list(self):
xla_device = xm.xla_device()
Expand Down

0 comments on commit 1492eba

Please sign in to comment.