Attach callback to a Tensor's underlying buffer (#7793)
will-cromar committed Aug 5, 2024
1 parent 5bbc4b3 commit 02d8322
Showing 8 changed files with 113 additions and 7 deletions.
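In short, this commit adds a user-facing `torch_xla.experimental.callback.on_ready_callback(tensor, callback)` API, a `ComputationClient::OnReadyCallback` hook with a PjRt implementation (the IFRT client is left unimplemented), and a test that exercises the new path with and without SPMD. A minimal usage sketch, assuming a working torch_xla installation; the shapes and the `notify` body below are illustrative only:

import torch
import torch_xla
import torch_xla.core.xla_model as xm
from torch_xla.experimental import callback


def notify(tensor: torch.Tensor):
  # Runs on a runtime thread once the buffer backing `tensor` is ready.
  print("buffer ready:", tuple(tensor.shape))


t = torch.randn((4, 4), device=torch_xla.device()) @ torch.randn(
    (4, 4), device=torch_xla.device())
xm.mark_step()  # materialize `t` so it is backed by a device buffer
callback.on_ready_callback(t, notify)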
2 changes: 2 additions & 0 deletions test/run_tests.sh
@@ -200,6 +200,8 @@ function run_xla_op_tests2 {
run_test "$CDIR/eager/test_eager_with_torch_compile.py"
run_test "$CDIR/eager/test_eager_all_reduce_in_place.py"
run_test "$CDIR/eager/test_eager_spmd.py"
run_test "$CDIR/test_callback.py"
XLA_USE_SPMD=1 run_test "$CDIR/test_callback.py"
}

# All the new xla op tests should go to run_xla_op_tests3
34 changes: 34 additions & 0 deletions test/test_callback.py
@@ -0,0 +1,34 @@
import threading

from absl.testing import absltest
import torch
import torch_xla
from torch_xla.experimental import callback


class TestExperimentalCallback(absltest.TestCase):

  @staticmethod
  @torch_xla.compile
  def executable():
    a, b = torch.randn((100, 100), device=torch_xla.device()), torch.randn(
        (100, 100), device=torch_xla.device())
    return a @ b

  def test_callback(self):
    event = threading.Event()
    c = self.executable()

    def cb(tensor):
      self.assertIs(c, tensor)
      # TODO: check that result is both assigned and completed
      self.assertNotIn("Data Handle: None",
                       torch_xla._XLAC._get_xla_tensor_debug_info(tensor))
      event.set()

    callback.on_ready_callback(c, cb)
    event.wait(3)


if __name__ == "__main__":
  absltest.main()
38 changes: 31 additions & 7 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -27,6 +27,7 @@
#include "pybind11/attr.h"
#include "pybind11/cast.h"
#include "pybind11/detail/common.h"
#include "pybind11/functional.h"
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
@@ -133,7 +134,7 @@ void PrepareToExit() {
if (client != nullptr) {
auto xla_device = GetDeviceOrCurrent("");
SetAllReduceToken(xla_device, nullptr);
XLAGraphExecutor::Get()->WaitDeviceOps({});
WaitDeviceOps();
}
}

@@ -2619,6 +2620,29 @@ void InitXlaModuleBindings(py::module m) {
return false;
});

m.def("_on_ready_callback",
[](const at::Tensor& tensor, const std::function<void()>& callback) {
XLATensorPtr xtensor = bridge::GetXlaTensor(tensor);
XLA_CHECK(xtensor) << "The input is not an XLA tensor.";
// Wait for placeholder `Data`s to be assigned
XLAGraphExecutor::Get()->WaitDeviceOps({});
std::shared_ptr<runtime::ComputationClient::Data> data;
if (xtensor->CurrentDataHandle() != nullptr) {
data = UnwrapXlaData(xtensor->CurrentDataHandle());
} else if (xtensor->CurrentIrValue().node != nullptr) {
DeviceData* device_data =
DeviceData::Cast(xtensor->CurrentIrValue().node.get());
if (device_data != nullptr) {
data = UnwrapXlaData(device_data->data());
} else {
XLA_ERROR() << "Could not get the buffer pointer for XLATensor "
"with IR that's not DeviceData";
}
XLA_ERROR() << "Could not get buffer for tensor";
}
runtime::GetComputationClient()->OnReadyCallback(data, callback);
});

m.def("_unsafe_buffer_pointer",
[](const at::Tensor& input) -> std::uintptr_t {
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
@@ -2646,9 +2670,9 @@ void InitXlaModuleBindings(py::module m) {

// from an XLA tensor to a PyCapsule.
// When consuming the PyCapsule, we should synchronize
// (waits for all kernels in all streams on a CUDA device to complete) if
// the current stream is different from the ext_data's stream. Otherwise, we
// may risk of getting incorrect results.
m.def("_to_dlpack", [](const at::Tensor& input) -> py::handle {
DLManagedTensor* dlMTensor;
{
@@ -2660,9 +2684,9 @@ void InitXlaModuleBindings(py::module m) {

// from a dlpack PyCapsule to an XLA tensor
// If ext_data is the result of an CUDA computation, we should synchronize
// (waits for all kernels in all streams on a CUDA device to complete) if
// the current stream is different from the ext_data's stream. Otherwise, we
// may risk of getting incorrect results. Or you can use torch_xla's
// from_dlpack(cuda_tensor) and it will handle the synchronization for you.
m.def("_from_dlpack", [](py::handle ext_data) -> at::Tensor {
return tensor_fromDLPack(ext_data.ptr());
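For context, the `_on_ready_callback` binding added above resolves the buffer from one of two places: a tensor that has already been materialized carries a `CurrentDataHandle()`, while a tensor whose IR is a plain `DeviceData` node is resolved through that node; anything else is an error. A hedged Python illustration of both situations (shapes and prints are illustrative, and which branch a given tensor hits is an implementation detail):

import torch
import torch_xla
import torch_xla.core.xla_model as xm
from torch_xla.experimental import callback

# A tensor produced by a computation: after mark_step it has a data handle
# (possibly still a placeholder until the async execution finishes).
y = (torch.randn((8, 8), device=torch_xla.device()) ** 2).sum()
xm.mark_step()
callback.on_ready_callback(y, lambda t: print("computed tensor ready"))

# A tensor uploaded from the host: it is backed by device data directly,
# so no graph execution is needed before its buffer can be tracked.
x = torch.ones((8, 8)).to(torch_xla.device())
callback.on_ready_callback(x, lambda t: print("uploaded tensor ready"))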
4 changes: 4 additions & 0 deletions torch_xla/csrc/runtime/computation_client.h
@@ -409,6 +409,10 @@ class ComputationClient {
void* function_ptr,
const std::string& platform) = 0;

// Installs a callback to be called when the buffer backing `data` is ready.
virtual void OnReadyCallback(DataPtr data,
const std::function<void()>& callback) = 0;

// Utility API around the vector based Compile() API to compile a single
// computation.
ComputationPtr Compile(xla::XlaComputation computation,
5 changes: 5 additions & 0 deletions torch_xla/csrc/runtime/ifrt_computation_client.h
@@ -155,6 +155,11 @@ class IfrtComputationClient : public ComputationClient {
XLA_ERROR() << __FUNCTION__ << " not implemented";
};

void OnReadyCallback(DataPtr data,
const std::function<void()>& callback) override {
XLA_ERROR() << __FUNCTION__ << " not implemented";
}

private:
std::shared_ptr<xla::ifrt::PjRtClient> client_;
std::unique_ptr<XlaCoordinator> coordinator_;
17 changes: 17 additions & 0 deletions torch_xla/csrc/runtime/pjrt_computation_client.cc
@@ -1033,5 +1033,22 @@ void PjRtComputationClient::RegisterCustomCall(const std::string& fn_name,
}
}

void PjRtComputationClient::OnReadyCallback(
ComputationClient::DataPtr data, const std::function<void()>& callback) {
std::shared_ptr<xla::PjRtBuffer> buffer;
if (auto pjrt_data = std::dynamic_pointer_cast<PjRtData>(data)) {
buffer = pjrt_data->buffer;
} else if (auto sharded_data =
std::dynamic_pointer_cast<PjRtShardedData>(data)) {
XLA_CHECK(sharded_data->shards.size()) << "sharded data has no shards";
buffer = sharded_data->shards[0]->buffer;
} else {
XLA_ERROR() << "received invalid data pointer";
}
XLA_CHECK(buffer) << "received placeholder data as argument";
buffer->GetReadyFuture().OnReady(
[callback](absl::Status unused) { callback(); });
}

} // namespace runtime
} // namespace torch_xla
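Note that for `PjRtShardedData` the callback above is attached to the first shard's buffer only. A hedged sketch of the SPMD case that the `XLA_USE_SPMD=1` test run exercises; the mesh and sharding setup below are illustrative and not part of this diff:

import threading

import numpy as np
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.experimental import callback

xr.use_spmd()
ready = threading.Event()

num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1))

t = torch.randn((128, 128), device=torch_xla.device())
xs.mark_sharding(t, mesh, (0, None))  # shard rows across devices
xm.mark_step()

callback.on_ready_callback(t, lambda _: ready.set())
assert ready.wait(timeout=5)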
3 changes: 3 additions & 0 deletions torch_xla/csrc/runtime/pjrt_computation_client.h
@@ -156,6 +156,9 @@ class PjRtComputationClient : public ComputationClient {
void RegisterCustomCall(const std::string& fn_name, void* function_ptr,
const std::string& platform) override;

void OnReadyCallback(DataPtr data,
const std::function<void()>& callback) override;

private:
std::unique_ptr<xla::PjRtClient> client_;
std::unique_ptr<XlaCoordinator> coordinator_;
17 changes: 17 additions & 0 deletions torch_xla/experimental/callback.py
@@ -0,0 +1,17 @@
from typing import Callable
import torch
import torch_xla


def on_ready_callback(tensor, callback: Callable[[torch.Tensor], None]):
  """Installs a callback on `tensor` to be called when its underlying buffer is ready.

  Note: `callback` will need to re-acquire the GIL, since it is a Python
  callable. If the main thread is blocking on `callback` while holding the GIL,
  this will result in a deadlock.
  """

  def _callback_wrapper():
    callback(tensor)

  torch_xla._XLAC._on_ready_callback(tensor, _callback_wrapper)

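As a follow-up to the deadlock note in the docstring, one pattern that keeps the callback safe is to have it only signal or enqueue, and do the real work on the main thread; a hedged sketch (the queue-based consumer is illustrative and not part of this commit):

import queue

import torch
import torch_xla
import torch_xla.core.xla_model as xm
from torch_xla.experimental import callback

ready_tensors: "queue.Queue[torch.Tensor]" = queue.Queue()

a = torch.randn((64, 64), device=torch_xla.device())
b = torch.randn((64, 64), device=torch_xla.device())
c = a @ b
xm.mark_step()

# The callback only enqueues the tensor; it never blocks or does heavy
# Python work, so it cannot deadlock against the main thread.
callback.on_ready_callback(c, ready_tensors.put)

result = ready_tensors.get(timeout=5)  # Queue.get releases the GIL while waiting
print("result ready:", tuple(result.shape))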