From 1492eba82dd6f1799d73902443272069b06ff732 Mon Sep 17 00:00:00 2001
From: JackCaoG <59073027+JackCaoG@users.noreply.github.com>
Date: Mon, 17 Jun 2024 20:15:31 -0700
Subject: [PATCH] use the eager_mode_context for some of the tests (#7276)

---
 test/debug_tool/test_pt_xla_debug.py | 19 ++++++++--------
 test/test_metrics.py                 | 33 ++++++++++++++--------------
 2 files changed, 25 insertions(+), 27 deletions(-)

diff --git a/test/debug_tool/test_pt_xla_debug.py b/test/debug_tool/test_pt_xla_debug.py
index ecc0f2c0520..becc2c321e8 100644
--- a/test/debug_tool/test_pt_xla_debug.py
+++ b/test/debug_tool/test_pt_xla_debug.py
@@ -30,16 +30,15 @@ def setUpClass(cls):
     open(cls.debug_file_name, 'w').close()
 
   def test_eager_mark_step(self):
-    torch_xla.experimental.eager_mode(True)
-    device = xm.xla_device()
-    t1 = torch.randn(5, 9, device=device)
-    xm.mark_step()
-    with open(self.debug_file_name, 'rb') as f:
-      lines = f.readlines()
-      # We expect PT_XLA_BUDEG not to output anything under the eager mode
-      self.assertEqual(len(lines), 0)
-    torch_xla.experimental.eager_mode(False)
-    open(self.debug_file_name, 'w').close()
+    with torch_xla.experimental.eager_mode_context(True):
+      device = xm.xla_device()
+      t1 = torch.randn(5, 9, device=device)
+      xm.mark_step()
+      with open(self.debug_file_name, 'rb') as f:
+        lines = f.readlines()
+        # We expect PT_XLA_DEBUG not to output anything under eager mode
+        self.assertEqual(len(lines), 0)
+    open(self.debug_file_name, 'w').close()
 
   def test_user_mark_step(self):
     device = xm.xla_device()
diff --git a/test/test_metrics.py b/test/test_metrics.py
index cabd2e768b8..e4322e2ad01 100644
--- a/test/test_metrics.py
+++ b/test/test_metrics.py
@@ -53,23 +53,22 @@ def test_tracing_time_metrics(self):
     self.assertGreater(met.metric_data('LazyTracing')[0], 1)
 
   def test_eager_metrics(self):
-    torch_xla.experimental.eager_mode(True)
-    xla_device = xm.xla_device()
-    met.clear_all()
-    t1 = torch.tensor(156, device=xla_device)
-    t2 = t1 + 100
-    xm.wait_device_ops()
-    self.assertIn('EagerOpCompileTime', met.metric_names())
-    # one for cosntant, one for add
-    self.assertEqual(met.metric_data('EagerOpCompileTime')[0], 2)
-    self.assertIn('EagerOpExecuteTime', met.metric_names())
-    # one for add
-    self.assertEqual(met.metric_data('EagerOpExecuteTime')[0], 2)
-    # mark_step should be a no-op
-    xm.mark_step()
-    self.assertNotIn('CompileTime', met.metric_names())
-    self.assertNotIn('ExecuteTime', met.metric_names())
-    torch_xla.experimental.eager_mode(False)
+    with torch_xla.experimental.eager_mode_context(True):
+      xla_device = xm.xla_device()
+      met.clear_all()
+      t1 = torch.tensor(156, device=xla_device)
+      t2 = t1 + 100
+      xm.wait_device_ops()
+      self.assertIn('EagerOpCompileTime', met.metric_names())
+      # one for constant, one for add
+      self.assertEqual(met.metric_data('EagerOpCompileTime')[0], 2)
+      self.assertIn('EagerOpExecuteTime', met.metric_names())
+      # one for constant, one for add
+      self.assertEqual(met.metric_data('EagerOpExecuteTime')[0], 2)
+      # mark_step should be a no-op
+      xm.mark_step()
+      self.assertNotIn('CompileTime', met.metric_names())
+      self.assertNotIn('ExecuteTime', met.metric_names())
 
   def test_short_metrics_report_default_list(self):
     xla_device = xm.xla_device()
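
For context: the change above replaces paired torch_xla.experimental.eager_mode(True)
/ eager_mode(False) calls with the eager_mode_context context manager, so the previous
mode is switched back even when an assertion fails mid-test and later tests are not
left running in eager mode. Below is a minimal sketch of how such a context manager
can be built on top of the eager_mode() toggle the old code used; it is illustrative
only, assumes an eager_mode_enabled() getter, and is not necessarily the implementation
that ships in torch_xla:

    import contextlib

    import torch_xla

    @contextlib.contextmanager
    def eager_mode_context(compile: bool):
      # Remember whichever mode was active before entering the block.
      # (eager_mode_enabled() is assumed here as the getter for the flag.)
      saved = torch_xla.experimental.eager_mode_enabled()
      torch_xla.experimental.eager_mode(compile)
      try:
        yield saved
      finally:
        # Restore the previous mode even if the body raised, which is what
        # lets the tests above drop the trailing eager_mode(False) calls.
        torch_xla.experimental.eager_mode(saved)

Usage then mirrors the updated tests:

    with eager_mode_context(True):
      ...  # ops in this block run eagerly
    # the previous mode is restored here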