Skip to content

Commit

Permalink
Make from_dlpack handle cuda synchronization implicitly for input ten…
Browse files Browse the repository at this point in the history
…sors that have __dlpack__ and __dlpack_device__ attributes. (#7125)
  • Loading branch information
vanbasten23 authored May 30, 2024
1 parent aeed61a commit daada22
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 1 deletion.
31 changes: 31 additions & 0 deletions test/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2634,6 +2634,37 @@ def test_dlpack_pytorch_cuda_to_xla(self):
t3_cuda.fill_(6)
self.assertTrue(torch.allclose(xla_t3.cpu(), t3_cuda.cpu()))

@onlyIfTorchSupportsCUDA
@onlyIfPJRTDeviceIsCUDA
def test_dlpack_pytorch_cuda_to_xla_protocol_conversion(self):
  # In contrast to test_dlpack_pytorch_cuda_to_xla, the CUDA tensors are
  # handed to from_dlpack directly, i.e. through their __dlpack__ and
  # __dlpack_device__ attributes. On this CUDA -> XLA path the
  # synchronization is handled implicitly.

  # Case 1: 1-D tensor, mutated through element assignment after the import.
  cuda_vec = torch.arange(5).cuda()
  xla_vec = xdlpack.from_dlpack(cuda_vec)
  self.assertEqual(xla_vec.device.type, 'xla')
  self.assertEqual(xla_vec.device.index, cuda_vec.device.index)
  cuda_vec[0] = cuda_vec[0] + 20
  self.assertTrue(torch.allclose(xla_vec.cpu(), cuda_vec.cpu()))

  # Case 2: 0-D tensor, mutated in place with fill_ after the import.
  cuda_scalar = torch.tensor(5).cuda()
  xla_scalar = xdlpack.from_dlpack(cuda_scalar)
  self.assertEqual(xla_scalar.device.type, 'xla')
  self.assertEqual(xla_scalar.device.index, cuda_scalar.device.index)
  cuda_scalar.fill_(6)
  self.assertTrue(torch.allclose(xla_scalar.cpu(), cuda_scalar.cpu()))

  # Case 3: 0-D tensor created on the second CUDA device; the XLA device
  # index is expected to match the CUDA device index.
  second_gpu = torch.device('cuda:1')
  cuda_scalar_dev1 = torch.tensor(5, device=second_gpu)
  xla_scalar_dev1 = xdlpack.from_dlpack(cuda_scalar_dev1)
  self.assertEqual(xla_scalar_dev1.device.type, 'xla')
  self.assertEqual(
      xla_scalar_dev1.device.index,
      cuda_scalar_dev1.device.index,
      msg='both value should 1. xla_t3.device should be xla:1.')
  cuda_scalar_dev1.fill_(6)
  self.assertTrue(
      torch.allclose(xla_scalar_dev1.cpu(), cuda_scalar_dev1.cpu()))

@onlyIfTorchSupportsCUDA
@onlyIfPJRTDeviceIsCUDA
def test_dlpack_xla_to_pytorch_cuda(self):
Expand Down
3 changes: 3 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1296,6 +1296,9 @@ void InitXlaModuleBindings(py::module m) {
return runtime::GetComputationClient()->GetLocalDevices();
}
});
// Python binding `_get_stream_for_cuda_device(device_id)`: forwards to the
// computation client's GetCudaStreamForDevice and returns its integer result.
m.def("_get_stream_for_cuda_device",
      [](const int device_id) -> std::intptr_t {
        return runtime::GetComputationClient()->GetCudaStreamForDevice(
            device_id);
      });
m.def("_xla_num_devices", []() -> int64_t {
if (UseVirtualDevice()) {
return 1;
Expand Down
2 changes: 2 additions & 0 deletions torch_xla/csrc/runtime/computation_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,8 @@ class ComputationClient {
virtual absl::StatusOr<xla::PjRtDevice*> LookupAddressableDevice(
int local_device_id) const = 0;

// Returns an integer-typed stream handle for the device identified by
// `local_device_id`. Pure virtual: each backend client supplies its own
// implementation.
// NOTE(review): presumably the handle is a CUDA stream pointer stored in an
// integer - confirm against the backend implementations.
virtual std::intptr_t GetCudaStreamForDevice(int local_device_id) const = 0;

virtual size_t GetNumDevices() const = 0;

virtual std::vector<std::string> GetLocalDevices() const = 0;
Expand Down
4 changes: 4 additions & 0 deletions torch_xla/csrc/runtime/ifrt_computation_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@ class IfrtComputationClient : public ComputationClient {
XLA_ERROR() << __FUNCTION__ << " not implemented";
}

// Not supported by this client: the body only reports
// "<function name> not implemented" through XLA_ERROR() and never produces a
// stream handle.
std::intptr_t GetCudaStreamForDevice(int local_device_id) const override {
XLA_ERROR() << __FUNCTION__ << " not implemented";
}

std::vector<std::string> GetLocalDevices() const override;

std::vector<std::string> GetAllDevices() const override;
Expand Down
11 changes: 11 additions & 0 deletions torch_xla/csrc/runtime/pjrt_computation_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,17 @@ class PjRtComputationClient : public ComputationClient {
xla::PjRtLocalDeviceId(local_device_id));
}

// Returns the integer stream handle that the PjRt device reports through
// GetStreamForExternalReadyEvents() for the addressable device whose local id
// is `local_device_id`.
//
// Fails an XLA_CHECK when either the device lookup or the stream query does
// not return an OK status; the failure message carries the device id and the
// underlying status so the root cause is not lost.
std::intptr_t GetCudaStreamForDevice(int local_device_id) const override {
  absl::StatusOr<xla::PjRtDevice*> pjrt_device =
      client_->LookupAddressableDevice(
          xla::PjRtLocalDeviceId(local_device_id));
  XLA_CHECK(pjrt_device.ok())
      << "Failed to get a PjRt device. local_device_id=" << local_device_id
      << ", status: " << pjrt_device.status().ToString();
  absl::StatusOr<std::intptr_t> stream =
      pjrt_device.value()->GetStreamForExternalReadyEvents();
  XLA_CHECK(stream.ok())
      << "Failed to get a stream. local_device_id=" << local_device_id
      << ", status: " << stream.status().ToString();
  return stream.value();
}

std::vector<std::string> GetLocalDevices() const override;

std::vector<std::string> GetAllDevices() const override;
Expand Down
28 changes: 27 additions & 1 deletion torch_xla/utils/dlpack.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,36 @@
from typing import Any
import enum
import torch_xla


def to_dlpack(xla_tensor: Any):
  """Exports `xla_tensor` by delegating to `torch_xla._XLAC._to_dlpack`.

  Args:
    xla_tensor: the object forwarded unchanged to the native binding.

  Returns:
    Whatever `torch_xla._XLAC._to_dlpack` returns for `xla_tensor`.
  """
  exported = torch_xla._XLAC._to_dlpack(xla_tensor)
  return exported


class DLDeviceType(enum.IntEnum):
  """Integer device-type codes used in DLPack device tuples.

  Members compare equal to plain ints, so they can be checked directly against
  the first element returned by an object's `__dlpack_device__()`.
  """
  # Enums as in DLPack specification (aten/src/ATen/dlpack.h)
  # Values are plain ints: a trailing comma would turn each value into a
  # 1-tuple that only works because Enum unpacks tuples into int(...).
  kDLCPU = 1
  kDLGPU = 2
  kDLCPUPinned = 3
  kDLOpenCL = 4
  kDLVulkan = 7
  kDLMetal = 8
  kDLVPI = 9
  kDLROCM = 10
  kDLExtDev = 12
  kDLOneAPI = 14


def from_dlpack(ext_tensor: Any):
  """Imports an external tensor or DLPack capsule via `torch_xla._XLAC._from_dlpack`.

  If `ext_tensor` exposes both `__dlpack_device__` and `__dlpack__`, the
  capsule is obtained by calling `__dlpack__` on it. When the reported device
  type is `DLDeviceType.kDLGPU`, the stream returned by
  `torch_xla._XLAC._get_stream_for_cuda_device(device_id)` is passed as the
  `stream` argument (presumably so the producer can synchronize with it;
  confirm against the producer's `__dlpack__` contract). For any other device
  type `__dlpack__` is called without arguments.

  If `ext_tensor` lacks either attribute it is forwarded as-is.

  Args:
    ext_tensor: an object implementing the `__dlpack__` protocol, or a value
      that can be handed directly to `torch_xla._XLAC._from_dlpack`.

  Returns:
    The result of `torch_xla._XLAC._from_dlpack` for the resolved capsule.
  """
  # No unconditional early return here: the protocol handling below must run
  # before the capsule is handed to the native binding.
  if hasattr(ext_tensor, '__dlpack_device__') and hasattr(
      ext_tensor, '__dlpack__'):
    device_type, device_id = ext_tensor.__dlpack_device__()
    if device_type == DLDeviceType.kDLGPU:
      stream = torch_xla._XLAC._get_stream_for_cuda_device(device_id)
      dlpack = ext_tensor.__dlpack__(stream=stream)
    else:
      dlpack = ext_tensor.__dlpack__()
  else:
    dlpack = ext_tensor

  return torch_xla._XLAC._from_dlpack(dlpack)

0 comments on commit daada22

Please sign in to comment.