diff --git a/test/run_tests.sh b/test/run_tests.sh
index 29abac1bb2f..a082f64ca22 100755
--- a/test/run_tests.sh
+++ b/test/run_tests.sh
@@ -200,6 +200,8 @@ function run_xla_op_tests2 {
   run_test "$CDIR/eager/test_eager_with_torch_compile.py"
   run_test "$CDIR/eager/test_eager_all_reduce_in_place.py"
   run_test "$CDIR/eager/test_eager_spmd.py"
+  run_test "$CDIR/test_callback.py"
+  XLA_USE_SPMD=1 run_test "$CDIR/test_callback.py"
 }
 
 # All the new xla op tests should go to run_xla_op_tests3
diff --git a/test/test_callback.py b/test/test_callback.py
new file mode 100644
index 00000000000..09c443c504b
--- /dev/null
+++ b/test/test_callback.py
@@ -0,0 +1,34 @@
+import threading
+
+from absl.testing import absltest
+import torch
+import torch_xla
+from torch_xla.experimental import callback
+
+
+class TestExperimentalCallback(absltest.TestCase):
+
+  @staticmethod
+  @torch_xla.compile
+  def executable():
+    a, b = torch.randn((100, 100), device=torch_xla.device()), torch.randn(
+        (100, 100), device=torch_xla.device())
+    return a @ b
+
+  def test_callback(self):
+    event = threading.Event()
+    c = self.executable()
+
+    def cb(tensor):
+      self.assertIs(c, tensor)
+      # TODO: check that result is both assigned and completed
+      self.assertNotIn("Data Handle: None",
+                       torch_xla._XLAC._get_xla_tensor_debug_info(tensor))
+      event.set()
+
+    callback.on_ready_callback(c, cb)
+    event.wait(3)
+
+
+if __name__ == "__main__":
+  absltest.main()
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index 6dc798ba4d3..4768ebfd2bc 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -27,6 +27,7 @@
 #include "pybind11/attr.h"
 #include "pybind11/cast.h"
 #include "pybind11/detail/common.h"
+#include "pybind11/functional.h"
 #include "pybind11/numpy.h"
 #include "pybind11/pybind11.h"
 #include "pybind11/pytypes.h"
@@ -133,7 +134,7 @@
 void PrepareToExit() {
   if (client != nullptr) {
     auto xla_device = GetDeviceOrCurrent("");
     SetAllReduceToken(xla_device, nullptr);
-    XLAGraphExecutor::Get()->WaitDeviceOps({});
+    WaitDeviceOps();
   }
 }
@@ -2619,6 +2620,29 @@ void InitXlaModuleBindings(py::module m) {
         return false;
       });
 
+  m.def("_on_ready_callback",
+        [](const at::Tensor& tensor, const std::function<void()>& callback) {
+          XLATensorPtr xtensor = bridge::GetXlaTensor(tensor);
+          XLA_CHECK(xtensor) << "The input is not an XLA tensor.";
+          // Wait for placeholder `Data`s to be assigned
+          XLAGraphExecutor::Get()->WaitDeviceOps({});
+          std::shared_ptr<runtime::ComputationClient::Data> data;
+          if (xtensor->CurrentDataHandle() != nullptr) {
+            data = UnwrapXlaData(xtensor->CurrentDataHandle());
+          } else if (xtensor->CurrentIrValue().node != nullptr) {
+            DeviceData* device_data =
+                DeviceData::Cast(xtensor->CurrentIrValue().node.get());
+            if (device_data != nullptr) {
+              data = UnwrapXlaData(device_data->data());
+            } else {
+              XLA_ERROR() << "Could not get the buffer pointer for XLATensor "
+                             "with IR that's not DeviceData";
+            }
+          } else {
+            XLA_ERROR() << "Could not get buffer for tensor";
+          }
+          runtime::GetComputationClient()->OnReadyCallback(data, callback);
+        });
+
   m.def("_unsafe_buffer_pointer",
         [](const at::Tensor& input) -> std::uintptr_t {
           XLATensorPtr xtensor = bridge::GetXlaTensor(input);
@@ -2646,9 +2670,9 @@ void InitXlaModuleBindings(py::module m) {
 
   // from an XLA tensor to a PyCapsule.
   // When consuming the PyCapsule, we should synchronize
-  // (waits for all kernels in all streams on a CUDA device to complete) if the
-  // current stream is different from the ext_data's stream. Otherwise, we may
-  // risk of getting incorrect results.
+  // (waits for all kernels in all streams on a CUDA device to complete) if
+  // the current stream is different from the ext_data's stream. Otherwise, we
+  // may risk getting incorrect results.
   m.def("_to_dlpack", [](const at::Tensor& input) -> py::handle {
     DLManagedTensor* dlMTensor;
     {
@@ -2660,9 +2684,9 @@ void InitXlaModuleBindings(py::module m) {
 
   // from a dlpack PyCapsule to an XLA tensor
   // If ext_data is the result of an CUDA computation, we should synchronize
-  // (waits for all kernels in all streams on a CUDA device to complete) if the
-  // current stream is different from the ext_data's stream. Otherwise, we may
-  // risk of getting incorrect results. Or you can use torch_xla's
+  // (waits for all kernels in all streams on a CUDA device to complete) if
+  // the current stream is different from the ext_data's stream. Otherwise, we
+  // may risk getting incorrect results. Or you can use torch_xla's
   // from_dlpack(cuda_tensor) and it will handle the synchronization for you.
   m.def("_from_dlpack", [](py::handle ext_data) -> at::Tensor {
     return tensor_fromDLPack(ext_data.ptr());
diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h
index 0c84ea12bc6..1b2819656cf 100644
--- a/torch_xla/csrc/runtime/computation_client.h
+++ b/torch_xla/csrc/runtime/computation_client.h
@@ -409,6 +409,10 @@ class ComputationClient {
                                   void* function_ptr,
                                   const std::string& platform) = 0;
 
+  // Installs a callback to be called when the buffer backing `data` is ready.
+  virtual void OnReadyCallback(DataPtr data,
+                               const std::function<void()>& callback) = 0;
+
   // Utility API around the vector based Compile() API to compile a single
   // computation.
   ComputationPtr Compile(xla::XlaComputation computation,
diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.h b/torch_xla/csrc/runtime/ifrt_computation_client.h
index a0492e0a728..ec8c314bc3d 100644
--- a/torch_xla/csrc/runtime/ifrt_computation_client.h
+++ b/torch_xla/csrc/runtime/ifrt_computation_client.h
@@ -155,6 +155,11 @@ class IfrtComputationClient : public ComputationClient {
     XLA_ERROR() << __FUNCTION__ << " not implemented";
   };
 
+  void OnReadyCallback(DataPtr data,
+                       const std::function<void()>& callback) override {
+    XLA_ERROR() << __FUNCTION__ << " not implemented";
+  }
+
  private:
   std::shared_ptr<xla::ifrt::PjRtClient> client_;
   std::unique_ptr<XlaCoordinator> coordinator_;
diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc
index f3c2873ff37..280a733bebe 100644
--- a/torch_xla/csrc/runtime/pjrt_computation_client.cc
+++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc
@@ -1033,5 +1033,22 @@ void PjRtComputationClient::RegisterCustomCall(const std::string& fn_name,
   }
 }
 
+void PjRtComputationClient::OnReadyCallback(
+    ComputationClient::DataPtr data, const std::function<void()>& callback) {
+  std::shared_ptr<xla::PjRtBuffer> buffer;
+  if (auto pjrt_data = std::dynamic_pointer_cast<PjRtData>(data)) {
+    buffer = pjrt_data->buffer;
+  } else if (auto sharded_data =
+                 std::dynamic_pointer_cast<PjRtShardedData>(data)) {
+    XLA_CHECK(sharded_data->shards.size()) << "sharded data has no shards";
+    buffer = sharded_data->shards[0]->buffer;
+  } else {
+    XLA_ERROR() << "received invalid data pointer";
+  }
+  XLA_CHECK(buffer) << "received placeholder data as argument";
+  buffer->GetReadyFuture().OnReady(
+      [callback](absl::Status unused) { callback(); });
+}
+
 }  // namespace runtime
 }  // namespace torch_xla
diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h
index 5eb4ff8fdd3..ca2257f8295 100644
--- a/torch_xla/csrc/runtime/pjrt_computation_client.h
+++ b/torch_xla/csrc/runtime/pjrt_computation_client.h
@@ -156,6 +156,9 @@ class PjRtComputationClient : public ComputationClient {
   void RegisterCustomCall(const std::string& fn_name, void* function_ptr,
                           const std::string& platform) override;
 
+  void OnReadyCallback(DataPtr data,
+                       const std::function<void()>& callback) override;
+
  private:
   std::unique_ptr<xla::PjRtClient> client_;
   std::unique_ptr<XlaCoordinator> coordinator_;
diff --git a/torch_xla/experimental/callback.py b/torch_xla/experimental/callback.py
new file mode 100644
index 00000000000..363620f7867
--- /dev/null
+++ b/torch_xla/experimental/callback.py
@@ -0,0 +1,17 @@
+from typing import Callable
+import torch
+import torch_xla
+
+
+def on_ready_callback(tensor, callback: Callable[[torch.Tensor], None]):
+  """Installs callback on `tensor` to be called when underlying buffer is ready.
+
+  Note: `callback` will need to re-acquire the GIL since it is a Python
+  callable. If the main thread is blocking on `callback` and holding the GIL,
+  this will result in a deadlock.
+  """
+
+  def _callback_wrapper():
+    callback(tensor)
+
+  torch_xla._XLAC._on_ready_callback(tensor, _callback_wrapper)
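
Usage note: below is a minimal sketch of how the new `torch_xla.experimental.callback.on_ready_callback` API can be driven, mirroring the pattern in `test/test_callback.py`; the function name, tensor shapes, and timeout are illustrative and not part of this change. The callback fires on a runtime-managed thread once the buffer backing the tensor is ready, so the main thread should wait on a primitive that releases the GIL (such as `threading.Event.wait`) rather than busy-waiting while holding it.

```python
import threading

import torch
import torch_xla
from torch_xla.experimental import callback


@torch_xla.compile
def matmul():
  # Hypothetical workload; shapes are arbitrary.
  a = torch.randn((128, 128), device=torch_xla.device())
  b = torch.randn((128, 128), device=torch_xla.device())
  return a @ b


done = threading.Event()
result = matmul()


def on_done(tensor):
  # Runs on a runtime thread once `tensor`'s device buffer is ready.
  # Keep this cheap; it must re-acquire the GIL before executing.
  print("result ready:", tensor.shape)
  done.set()


callback.on_ready_callback(result, on_done)
# Event.wait releases the GIL while blocking, so the callback can run.
done.wait(timeout=5)
```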