make StepMarker's wait block for async device execution (#7794)
JackCaoG committed Aug 2, 2024
1 parent 7fe070a commit dfd7b00
Showing 2 changed files with 15 additions and 0 deletions.
12 changes: 12 additions & 0 deletions test/test_operations.py
@@ -2701,6 +2701,18 @@ def test_as_strided_input_larger(self):

    self.assertEqual(a, former_a)

  @skipOnEagerDebug
  def test_sync_wait(self):
    xm.wait_device_ops()
    met.clear_all()
    device = torch_xla.device()
    input = torch.randn(1024, 1024, device=device)
    res = input @ input @ input

    torch_xla.sync(wait=True)
    # ExecuteTime will show up after the async device execution has finished.
    self.assertIn('ExecuteTime', met.metric_names())

  def _test_move_tensor_cuda_to_xla(self, cpu_tensor):
    # Assumes CPU-XLA data movement works.
    cuda_tensor = cpu_tensor.to("cuda")
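The new test exercises the user-facing guarantee. As a rough timing illustration (not part of this commit; variable names are illustrative), a blocking sync means wall-clock measurements around `torch_xla.sync(wait=True)` cover the real device execution rather than just the async dispatch:

```python
import time

import torch
import torch_xla

device = torch_xla.device()
x = torch.randn(1024, 1024, device=device)

start = time.time()
y = x @ x
# With this commit, sync(wait=True) returns only after the device has
# actually executed the graph, so the measured time includes execution.
torch_xla.sync(wait=True)
print(f"matmul wall time: {time.time() - start:.4f}s")
```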
3 changes: 3 additions & 0 deletions torch_xla/csrc/xla_graph_executor.cpp
@@ -410,6 +410,9 @@ void XLAGraphExecutor::SyncTensorsGraph(std::vector<XLATensorPtr>* tensors,
  SyncTensorsGraphInternal(tensors, devices, config, warm_up_cache_only);
  if (wait && async != nullptr && !warm_up_cache_only) {
    async->mwait.Wait();
    // async->mwait.Wait() blocks until the async computation thread returns,
    // but the real device execution might not have finished yet.
    runtime::GetComputationClient()->WaitDeviceOps(devices);
  }
}

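Conceptually, the added `WaitDeviceOps` call makes the blocking sync behave like issuing the graph and then waiting for device ops from the Python side. A minimal sketch of that equivalence, assuming the public `xm.wait_device_ops()` wraps the same `WaitDeviceOps` runtime call (an assumption based on the test above, not stated in this commit):

```python
import torch_xla
import torch_xla.core.xla_model as xm

# Dispatch the traced graph asynchronously; this may return before the
# device has finished executing.
torch_xla.sync()

# Block until all pending device operations complete.
xm.wait_device_ops()

# After this commit, torch_xla.sync(wait=True) performs both steps in one call.
```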
