diff --git a/test/test_operations.py b/test/test_operations.py
index e4ee4c3e954..e1dad566536 100644
--- a/test/test_operations.py
+++ b/test/test_operations.py
@@ -2634,6 +2634,37 @@ def test_dlpack_pytorch_cuda_to_xla(self):
     t3_cuda.fill_(6)
     self.assertTrue(torch.allclose(xla_t3.cpu(), t3_cuda.cpu()))
 
+  @onlyIfTorchSupportsCUDA
+  @onlyIfPJRTDeviceIsCUDA
+  def test_dlpack_pytorch_cuda_to_xla_protocol_conversion(self):
+    # Unlike test_dlpack_pytorch_cuda_to_xla, the source CUDA tensor here
+    # is consumed through its __dlpack__ and __dlpack_device__ attributes.
+    # When converting CUDA tensors to XLA tensors, synchronization is handled implicitly.
+    t1_cuda = torch.arange(5).cuda()
+    xla_t1 = xdlpack.from_dlpack(t1_cuda)
+    self.assertEqual(xla_t1.device.type, 'xla')
+    self.assertEqual(xla_t1.device.index, t1_cuda.device.index)
+    t1_cuda[0] = t1_cuda[0] + 20
+    self.assertTrue(torch.allclose(xla_t1.cpu(), t1_cuda.cpu()))
+
+    t2_cuda = torch.tensor(5).cuda()
+    xla_t2 = xdlpack.from_dlpack(t2_cuda)
+    self.assertEqual(xla_t2.device.type, 'xla')
+    self.assertEqual(xla_t2.device.index, t2_cuda.device.index)
+    t2_cuda.fill_(6)
+    self.assertTrue(torch.allclose(xla_t2.cpu(), t2_cuda.cpu()))
+
+    cuda1 = torch.device('cuda:1')
+    t3_cuda = torch.tensor(5, device=cuda1)
+    xla_t3 = xdlpack.from_dlpack(t3_cuda)
+    self.assertEqual(xla_t3.device.type, 'xla')
+    self.assertEqual(
+        xla_t3.device.index,
+        t3_cuda.device.index,
+        msg='Both device indices should be 1; xla_t3.device should be xla:1.')
+    t3_cuda.fill_(6)
+    self.assertTrue(torch.allclose(xla_t3.cpu(), t3_cuda.cpu()))
+
   @onlyIfTorchSupportsCUDA
   @onlyIfPJRTDeviceIsCUDA
   def test_dlpack_xla_to_pytorch_cuda(self):
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index 5d7707d5699..12e955adf38 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -1296,6 +1296,9 @@ void InitXlaModuleBindings(py::module m) {
       return runtime::GetComputationClient()->GetLocalDevices();
     }
   });
+  m.def("_get_stream_for_cuda_device", [](const int device_id) {
+    return runtime::GetComputationClient()->GetCudaStreamForDevice(device_id);
+  });
   m.def("_xla_num_devices", []() -> int64_t {
     if (UseVirtualDevice()) {
       return 1;
diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h
index 93664900ebd..a66ae2a7fa4 100644
--- a/torch_xla/csrc/runtime/computation_client.h
+++ b/torch_xla/csrc/runtime/computation_client.h
@@ -361,6 +361,8 @@ class ComputationClient {
   virtual absl::StatusOr<xla::PjRtDevice*> LookupAddressableDevice(
       int local_device_id) const = 0;
 
+  virtual std::intptr_t GetCudaStreamForDevice(int local_device_id) const = 0;
+
   virtual size_t GetNumDevices() const = 0;
 
   virtual std::vector<std::string> GetLocalDevices() const = 0;
diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.h b/torch_xla/csrc/runtime/ifrt_computation_client.h
index 6f71da8422f..59664d045e8 100644
--- a/torch_xla/csrc/runtime/ifrt_computation_client.h
+++ b/torch_xla/csrc/runtime/ifrt_computation_client.h
@@ -96,6 +96,10 @@ class IfrtComputationClient : public ComputationClient {
     XLA_ERROR() << __FUNCTION__ << " not implemented";
   }
 
+  std::intptr_t GetCudaStreamForDevice(int local_device_id) const override {
+    XLA_ERROR() << __FUNCTION__ << " not implemented";
+  }
+
   std::vector<std::string> GetLocalDevices() const override;
 
   std::vector<std::string> GetAllDevices() const override;
diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h
index 5ed4326d283..1d31107e6b9 100644
--- a/torch_xla/csrc/runtime/pjrt_computation_client.h
+++ b/torch_xla/csrc/runtime/pjrt_computation_client.h
@@ -104,6 +104,17 @@ class PjRtComputationClient : public ComputationClient {
                                  xla::PjRtLocalDeviceId(local_device_id));
   }
 
+  std::intptr_t GetCudaStreamForDevice(int local_device_id) const override {
+    absl::StatusOr<xla::PjRtDevice*> pjrt_device =
+        client_->LookupAddressableDevice(
+            xla::PjRtLocalDeviceId(local_device_id));
+    XLA_CHECK(pjrt_device.ok()) << "Failed to get a PjRt device.";
+    absl::StatusOr<std::intptr_t> stream =
+        pjrt_device.value()->GetStreamForExternalReadyEvents();
+    XLA_CHECK(stream.ok()) << "Failed to get a stream.";
+    return stream.value();
+  }
+
   std::vector<std::string> GetLocalDevices() const override;
 
   std::vector<std::string> GetAllDevices() const override;
diff --git a/torch_xla/utils/dlpack.py b/torch_xla/utils/dlpack.py
index 9f93d532b27..c49083e4403 100644
--- a/torch_xla/utils/dlpack.py
+++ b/torch_xla/utils/dlpack.py
@@ -1,4 +1,5 @@
 from typing import Any
+import enum
 
 import torch_xla
 
@@ -6,5 +7,30 @@ def to_dlpack(xla_tensor: Any):
   return torch_xla._XLAC._to_dlpack(xla_tensor)
 
 
+class DLDeviceType(enum.IntEnum):
+  # Device type enums as in the DLPack specification (aten/src/ATen/dlpack.h).
+  kDLCPU = 1
+  kDLGPU = 2
+  kDLCPUPinned = 3
+  kDLOpenCL = 4
+  kDLVulkan = 7
+  kDLMetal = 8
+  kDLVPI = 9
+  kDLROCM = 10
+  kDLExtDev = 12
+  kDLOneAPI = 14
+
+
 def from_dlpack(ext_tensor: Any):
-  return torch_xla._XLAC._from_dlpack(ext_tensor)
+  if hasattr(ext_tensor, '__dlpack_device__') and hasattr(
+      ext_tensor, '__dlpack__'):
+    device_type, device_id = ext_tensor.__dlpack_device__()
+    if device_type == DLDeviceType.kDLGPU:
+      stream = torch_xla._XLAC._get_stream_for_cuda_device(device_id)
+      dlpack = ext_tensor.__dlpack__(stream=stream)
+    else:
+      dlpack = ext_tensor.__dlpack__()
+  else:
+    dlpack = ext_tensor
+
+  return torch_xla._XLAC._from_dlpack(dlpack)
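Usage sketch: assuming a CUDA-enabled torch_xla build running with PJRT_DEVICE=CUDA (the same setup the new test requires), the protocol-based path added above can be exercised as follows. from_dlpack reads __dlpack_device__, fetches the device's stream through the new _get_stream_for_cuda_device binding, and passes it to __dlpack__ so the producer synchronizes implicitly:

    import torch
    import torch_xla.utils.dlpack as xdlpack

    t_cuda = torch.arange(5, device='cuda')  # producer exposes __dlpack__/__dlpack_device__
    xla_t = xdlpack.from_dlpack(t_cuda)      # stream handed to __dlpack__; no manual sync needed
    assert xla_t.device.type == 'xla'
    t_cuda[0] += 20                          # mutation is visible through the shared buffer
    assert torch.allclose(xla_t.cpu(), t_cuda.cpu())

Design note: per the DLPack protocol, the consumer passes its stream to __dlpack__ and the producer inserts the required synchronization on that stream, which is why the tests need no explicit torch.cuda.synchronize().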