Skip to content

Commit

Permalink
Add on_ready_event (#7984)
Browse files Browse the repository at this point in the history
  • Loading branch information
JackCaoG committed Sep 10, 2024
1 parent 6ff03e6 commit 8888217
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 0 deletions.
7 changes: 7 additions & 0 deletions test/test_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@ def cb(tensor):
callback.on_ready_callback(c, cb)
event.wait(3)

def test_callback_event(self):
c = self.executable()
c_ready_event = callback.on_ready_event(c)
c_ready_event.wait(3)
self.assertNotIn("Data Handle: None",
torch_xla._XLAC._get_xla_tensor_debug_info(c))


if __name__ == "__main__":
absltest.main()
17 changes: 17 additions & 0 deletions torch_xla/experimental/callback.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Callable
import torch
import torch_xla
import threading


def on_ready_callback(tensor, callback: Callable[[torch.Tensor], None]):
Expand All @@ -15,3 +16,19 @@ def _callback_wrapper():
callback(tensor)

torch_xla._XLAC._on_ready_callback(tensor, _callback_wrapper)


def on_ready_event(tensor: torch.Tensor) -> threading.Event:
"""Return a python threading.event that will be set once underlying
tensor buffer is ready.
Args:
tensor: tensor that the event will be blocked on
"""
ready_event = threading.Event()

def _callback_wrapper():
ready_event.set()

torch_xla._XLAC._on_ready_callback(tensor, _callback_wrapper)
return ready_event

0 comments on commit 8888217

Please sign in to comment.