diff --git a/test/test_callback.py b/test/test_callback.py index 09c443c504b..242fef6443e 100644 --- a/test/test_callback.py +++ b/test/test_callback.py @@ -29,6 +29,13 @@ def cb(tensor): callback.on_ready_callback(c, cb) event.wait(3) + def test_callback_event(self): + c = self.executable() + c_ready_event = callback.on_ready_event(c) + c_ready_event.wait(3) + self.assertNotIn("Data Handle: None", + torch_xla._XLAC._get_xla_tensor_debug_info(c)) + if __name__ == "__main__": absltest.main() diff --git a/torch_xla/experimental/callback.py b/torch_xla/experimental/callback.py index 363620f7867..93152390acc 100644 --- a/torch_xla/experimental/callback.py +++ b/torch_xla/experimental/callback.py @@ -1,6 +1,7 @@ from typing import Callable import torch import torch_xla +import threading def on_ready_callback(tensor, callback: Callable[[torch.Tensor], None]): @@ -15,3 +16,19 @@ def _callback_wrapper(): callback(tensor) torch_xla._XLAC._on_ready_callback(tensor, _callback_wrapper) + + +def on_ready_event(tensor: torch.Tensor) -> threading.Event: + """Return a python threading.event that will be set once underlying + tensor buffer is ready. + + Args: + tensor: tensor that the event will be blocked on + """ + ready_event = threading.Event() + + def _callback_wrapper(): + ready_event.set() + + torch_xla._XLAC._on_ready_callback(tensor, _callback_wrapper) + return ready_event