From 160a4d2f9c0c4f2dbb5ee0a0d0894693570dbac9 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Wed, 17 Jul 2024 18:52:56 +0000 Subject: [PATCH 1/2] [WIP] Use `privateuseone` dispatch key --- experimental/torch_xla2/pyproject.toml | 2 + .../torch_xla2/cpp/registration.cpp | 280 ++++++++++++++++++ .../torch_xla2/torch_xla2/custom_device.py | 26 ++ .../torch_xla2/torch_xla2/ops/jaten.py | 22 ++ 4 files changed, 330 insertions(+) create mode 100644 experimental/torch_xla2/torch_xla2/cpp/registration.cpp create mode 100644 experimental/torch_xla2/torch_xla2/custom_device.py diff --git a/experimental/torch_xla2/pyproject.toml b/experimental/torch_xla2/pyproject.toml index d20dc135d23..eb86d75dcea 100644 --- a/experimental/torch_xla2/pyproject.toml +++ b/experimental/torch_xla2/pyproject.toml @@ -12,6 +12,8 @@ dependencies = [ "tensorflow-cpu", # Developers should install `dev-requirements.txt` first "torch>=2.3.0", + # TODO: needed for JIT compiling C++ module + "ninja", ] requires-python = ">=3.10" license = {file = "LICENSE"} diff --git a/experimental/torch_xla2/torch_xla2/cpp/registration.cpp b/experimental/torch_xla2/torch_xla2/cpp/registration.cpp new file mode 100644 index 00000000000..0648d77af5d --- /dev/null +++ b/experimental/torch_xla2/torch_xla2/cpp/registration.cpp @@ -0,0 +1,280 @@ +#include +#include +#include +#include + +#include +#include + +#include +#include +#include + +#include + +// This file contains the heavy lifting to add a new C++ backend +// and integrate it directly into the PyTorch backend. It mainly involves: +// +// (1) Writing a custom allocator and registering it to pytorch +// (see DummyCustomAllocator) +// (2) Writing a custom device guard, registering it to pytorch, +// and using the device guard in kernels +// (see DummyDeviceGuard) +// (3) Writing a custom aten::empty.memory_format function + + +// basic dummy add function +// at::Tensor custom_add_Tensor(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) { +// const at::OptionalDeviceGuard device_guard(at::device_of(self)); +// std::cout << "Custom aten::add.Tensor() called!" << std::endl; +// // Since this custom device is just for testing, not bothering to implement kernels. +// return at::empty(self.sizes(), self.options()); +// } + +// ===================================== +// ========= Custom Allocators ========= +// ===================================== + +// PyTorch provides an API for registering custom allocators for your device. +// You can create one by inheriting from the at::Allocator class, +// and registering your allocator for the particular device type +// (PrivateUse1 for open registration devices) + +// A dummy allocator for our custom device, that secretly uses the CPU +// struct DummyCustomAllocator final : at::Allocator { +// DummyCustomAllocator() = default; +// at::DataPtr allocate(size_t nbytes) const override { +// std::cout << "Custom allocator's allocate() called!" << std::endl; +// void* data = c10::alloc_cpu(nbytes); +// return {data, data, &ReportAndDelete, at::Device(at::DeviceType::PrivateUse1, 0)}; +// } + +// static void ReportAndDelete(void* ptr) { +// if (!ptr) { +// return; +// } +// std::cout << "Custom allocator's delete() called!" 
<< std::endl; +// c10::free_cpu(ptr); +// } + +// at::DeleterFnPtr raw_deleter() const override { +// return &ReportAndDelete; +// } +// }; + +// // Register our dummy allocator +// static DummyCustomAllocator global_custom_alloc; +// REGISTER_ALLOCATOR(c10::DeviceType::PrivateUse1, &global_custom_alloc); + +// ===================================== +// ============= Device Guards ========= +// ===================================== + +// PyTorch has an API for registering device guards. +// Device guards can be used to set the current "active" device, +// and e.g. error if the user provides an invalid device index. +// +// If your device doesn't support indices (e.g. foo:0 vs. foo:1), +// then the guards probably aren't needed. +// +// You can use it by creating a DeviceGuard class, registering it +// in PyTorch, and invoking the device guard before any kernels are called. +// For a more full-featured example of a device guard, +// check out the code at c10/cuda/CUDAGuard.h + +// Represents the current "active" device. +// The dummy device guard registered below is meant to show how a backend +// can integrate custom device guard with pytorch. +// For something like cuda this represents the current active cuda device, +// which is directly set using the cuda API calls cudaGetDevice/cudaSetDevice. +// static uint16_t CURR_DEVICE = -1; + +// Create and register a dummy device guard. +// struct DummyDeviceGuardImpl final : public c10::impl::DeviceGuardImplInterface { +// static constexpr c10::DeviceType static_type = c10::DeviceType::PrivateUse1; +// DummyDeviceGuardImpl() {} +// explicit DummyDeviceGuardImpl(c10::DeviceType t) { +// TORCH_INTERNAL_ASSERT(t == c10::DeviceType::PrivateUse1); +// } +// at::DeviceType type() const override { +// return at::DeviceType::PrivateUse1; +// } +// at::Device exchangeDevice(at::Device d) const override { +// TORCH_INTERNAL_ASSERT(d.type() == at::DeviceType::PrivateUse1); +// TORCH_INTERNAL_ASSERT(d.index() < deviceCount(), "Error: device index ", d.index(), " does not exist."); +// at::Device old_device = getDevice(); +// if (old_device.index() != d.index()) { +// // "set the active device" +// CURR_DEVICE = d.index(); +// } +// return old_device; +// } +// at::Device getDevice() const override { +// return at::Device(at::DeviceType::PrivateUse1, CURR_DEVICE); +// } +// void setDevice(at::Device d) const override { +// TORCH_INTERNAL_ASSERT(d.type() == at::DeviceType::PrivateUse1); +// TORCH_INTERNAL_ASSERT(d.index() < deviceCount(), "Error: device index ", d.index(), " does not exist."); +// at::Device current_device = getDevice(); +// if (current_device != d) { +// CURR_DEVICE = d.index(); +// } +// } +// void uncheckedSetDevice(at::Device d) const noexcept override { +// auto current_device = getDevice(); +// if (current_device != d) { +// CURR_DEVICE = d.index(); +// } +// } +// at::Stream getStream(at::Device d) const noexcept override { +// // no-op +// return at::Stream(at::Stream::DEFAULT, d); +// } +// // NB: These do NOT set the current device +// at::Stream exchangeStream(at::Stream) const noexcept override { +// // no-op +// return at::Stream(at::Stream::DEFAULT, at::Device(at::DeviceType::PrivateUse1, CURR_DEVICE)); +// } +// at::DeviceIndex deviceCount() const noexcept override { +// // Hardcoding the number of "valid" devices here at 2. 
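+// //     (A real backend would query its runtime for this instead; for a JAX-backed
+// //     device the count could eventually come from jax.device_count().)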
+// return 2; +// } + +// // Event-related functions +// void record( +// void** /*event*/, +// const at::Stream& /*stream*/, +// const at::DeviceIndex /*device_index*/, +// const c10::EventFlag /*flag*/) const override { +// TORCH_CHECK(false, at::DeviceType::PrivateUse1, " backend doesn't support events."); +// } +// void block(void* /*event*/, const at::Stream& /*stream*/) const override { +// TORCH_CHECK(false, at::DeviceType::PrivateUse1, " backend doesn't support events.") +// } +// bool queryEvent(void* /*event*/) const override { +// TORCH_CHECK(false, at::DeviceType::PrivateUse1, " backend doesn't support events.") +// } +// void destroyEvent(void* /*event*/, const at::DeviceIndex /*device_index*/) +// const noexcept override {} + +// // Stream-related functions +// bool queryStream(const at::Stream& /*stream*/) const override { +// return true; +// } +// void synchronizeStream(const at::Stream& /*stream*/) const override { +// // Don't wait for anything. +// } +// }; + +// struct DummyGuard { +// explicit DummyGuard() = delete; +// explicit DummyGuard(at::DeviceIndex device_index) : guard_(device_index) {} +// explicit DummyGuard(at::Device device) : guard_(device) {} +// DummyGuard(const DummyGuard&) = delete; +// DummyGuard& operator=(const DummyGuard&) = delete; +// DummyGuard(DummyGuard&& other) = delete; +// DummyGuard& operator=(DummyGuard&& other) = delete; + +// void set_device(at::Device device) { +// guard_.set_device(device); +// } + +// void reset_device(at::Device device) { +// guard_.reset_device(device); +// } + +// void set_index(at::DeviceIndex device_index) { +// guard_.set_index(device_index); +// } + +// at::Device original_device() const { +// return guard_.original_device(); +// } + +// at::Device current_device() const { +// return guard_.current_device(); +// } + +// private: +// c10::impl::InlineDeviceGuard guard_; +// }; + +// C10_REGISTER_GUARD_IMPL(PrivateUse1, DummyDeviceGuardImpl); + + +// ===================================== +// ============= KERNELS =============== +// ===================================== + +// basic dummy empty function, so we can directly construct tensors on the custom device +// This dummy test device will just use the CPU allocator, and ignores pinned memory. +// +// Note: this kernel is very simple because our "custom device" just uses the normal TensorImpl object +// to store data under the hood. +// In PyTorch core today, both cpu and cuda are implemented with an ordinary TensorImpl class. +// Sometimes, backends prefer to subclass TensorImpl in order to store extra information. +// If this is the case, then this kernel is where you'll be responsible for creating and returning +// a fresh at::Tensor object, that properly stores a TensorImpl of your subclass. +// at::Tensor custom_empty_memory_format(at::IntArrayRef size, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory, c10::optional memory_format) { +// const at::OptionalDeviceGuard device_guard(device); +// std::cout << "Custom aten::empty.memory_format() called!" << std::endl; +// constexpr c10::DispatchKeySet private_use_ks(c10::DispatchKey::PrivateUse1); +// return at::detail::empty_generic(size, &global_custom_alloc, private_use_ks, c10::dtype_or_default(dtype), memory_format); +// } + +// at::Tensor & custom_fill__scalar(at::Tensor & self, const at::Scalar & value) { +// const at::OptionalDeviceGuard device_guard(at::device_of(self)); +// // Not bothering to implement. +// // Should fill the tensor's data with "value". 
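+//   // A minimal sketch, assuming a contiguous float tensor backed by the CPU
+//   // allocator above:
+//   //   auto* data = static_cast<float*>(self.storage().data_ptr().get());
+//   //   std::fill(data, data + self.numel(), value.toFloat());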
+// return self; +// } + +// // basic dummy copy_() function, so we can copy from the custom device to/from CPU +// at::Tensor custom__copy_from(const at::Tensor& self, const at::Tensor& dst, bool non_blocking) { +// const at::OptionalDeviceGuard device_guard(at::device_of(self)); +// std::cout << "Custom aten::_copy_from() called!" << std::endl; +// TORCH_CHECK(self.is_cpu() || self.device().type() == c10::DeviceType::PrivateUse1, "Dummy test only allows copy from cpu -> dummy device."); +// TORCH_CHECK(dst.is_cpu() || dst.device().type() == c10::DeviceType::PrivateUse1, "Dummy test only allows copy from cpu -> dummy device."); + +// // Some dummy asserts for the basic use case: inputs are the same size / dtype, all contiguous. +// TORCH_CHECK(self.sizes() == dst.sizes()); +// TORCH_CHECK(self.scalar_type() == dst.scalar_type()); +// TORCH_CHECK(self.is_contiguous() && dst.is_contiguous()); + +// std::memcpy(dst.storage().data_ptr().get(), self.storage().data_ptr().get(), self.storage().nbytes()); +// return dst; +// } + + +// This macro does the heavy lifting. +// With TORCH_LIBRARY_IMPL, you can register custom kernels for your backend. +// For open registration, we're registering all of our kernels to the PrivateUse1 dispatch key. +// Later in this file, we map a custom device to the PrivateUse1 device type, +// which allows user code that puts a tensor on your custom_device to eventually get plumbed +// into the kernels registered here. +// +// This macro registers your kernels to the PyTorch Dispatcher. +// More details on the dispatcher can be found at http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/. +// TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) { +// m.impl("add.Tensor", &custom_add_Tensor); +// m.impl("empty.memory_format", &custom_empty_memory_format); +// m.impl("fill_.Scalar", &custom_fill__scalar); +// m.impl("_copy_from", &custom__copy_from); +// } + +// This basic implementation doesn't bother dealing with different device indices +// (e.g. custom_device:0 vs. custom_device:1). +// We could do that by letting the user pass in a device index in our exposed device function. +// Note that if you do that, you'll also need to register a device guard to core. +// See `c10/core/impl/DeviceGuardImplInterface.h:C10_REGISTER_GUARD_IMPL`. +c10::Device get_custom_device(int idx) { + return c10::Device(c10::DeviceType::PrivateUse1, idx); +} + +// Here, we're exposing a custom device object that corresponds to our custom backend. +// We do this using pybind: exposing an "extension_name.custom_device()" function in python, +// that's implemented in C++. +// The implementation in this file maps directly to the `PrivateUse1` device type. +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("custom_device", &get_custom_device, "get custom device object"); +} diff --git a/experimental/torch_xla2/torch_xla2/custom_device.py b/experimental/torch_xla2/torch_xla2/custom_device.py new file mode 100644 index 00000000000..e0713be7720 --- /dev/null +++ b/experimental/torch_xla2/torch_xla2/custom_device.py @@ -0,0 +1,26 @@ +import os +import torch +import torch.utils.cpp_extension + + +import torch_xla2.ops.jaten + + +# TODO: Do I even need a C++ module at all? +# Load the C++ extension containing your custom kernels. 
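+# (If this extension is kept, torch.utils.cpp_extension.load JIT-compiles it with
+# ninja at import time, which is what the new `ninja` dependency in pyproject.toml
+# is for.)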
+# foo_module = torch.utils.cpp_extension.load( +# name="custom_device_extension", +# sources=[ +# os.path.dirname(__file__) + "/cpp/registration.cpp", +# ], +# #extra_include_paths=["cpp_extensions"], +# extra_cflags=["-g"], +# verbose=True, +# ) + +# torch.register_privateuse1_backend('foo') +torch.utils.rename_privateuse1_backend('jax') + + +# print(foo_module.Tensor) +print('Create a tensor with `jax` device:', torch.tensor([0], device='jax:0')) diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index 09e3c8d419c..2250e324792 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -51,6 +51,17 @@ def op(*aten, **kwargs): def inner(func): for a in aten: ops_registry.register_torch_dispatch_op(a, func, **kwargs) + + match type(a): + case torch._ops.OpOverloadPacket: + opname = a._qualified_op_name + case torch._ops.OpOverload: + # opname = a.name() + continue # prevent multiple funcs from being registered? + case _: + raise RuntimeError(f'oops {a}') + + torch.library.impl(opname, 'privateuseone')(func) return func return inner @@ -409,11 +420,22 @@ def _aten_dot(x, y): @op(torch.ops.aten._to_copy) def _aten__to_copy(self, **kwargs): + # HACK: should we wrap every function to do this? + import torch_xla2 + if type(self) is torch.Tensor: + return torch_xla2.default_env().j2t_iso(mappings.t2j(self)) + # return torch_xla2.tensor.XLATensor2(mappings.t2j(self)) dtype = mappings.t2j_dtype(kwargs["dtype"]) if dtype != self.dtype: return self.astype(dtype) return jnp.copy(self) +# TODO: not clear what this function should actually do +# https://github.com/pytorch/pytorch/blob/d96c80649f301129219469d8b4353e52edab3b78/aten/src/ATen/native/native_functions.yaml#L7933-L7940 +@op(torch.ops.aten.lift_fresh) +def _aten_lift_fresh(self): + return self + @op(torch.ops.aten.empty) @op_base.convert_dtype() From 146787beab9b34350ce6fec327923bb6ca62e8d9 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Wed, 17 Jul 2024 21:45:41 +0000 Subject: [PATCH 2/2] Rely on PyTorch dispatcher --- .../torch_xla2/test/test_jax_device.py | 15 ++++++++++++ .../torch_xla2/torch_xla2/custom_device.py | 5 ---- .../torch_xla2/torch_xla2/ops/jaten.py | 24 +++++++------------ experimental/torch_xla2/torch_xla2/tensor.py | 5 ++-- 4 files changed, 26 insertions(+), 23 deletions(-) create mode 100644 experimental/torch_xla2/test/test_jax_device.py diff --git a/experimental/torch_xla2/test/test_jax_device.py b/experimental/torch_xla2/test/test_jax_device.py new file mode 100644 index 00000000000..8479d4aad19 --- /dev/null +++ b/experimental/torch_xla2/test/test_jax_device.py @@ -0,0 +1,15 @@ +import torch +import torch_xla2.custom_device + +def test_tensor_creation(): + t = torch.tensor([0], device="jax:0") + + assert t.numpy() == [0] + +def test_basic_op(): + a = torch.tensor([0], device="jax:0") + b = torch.tensor([2], device="jax:0") + + c = a + b + assert c.numpy() == [2] + diff --git a/experimental/torch_xla2/torch_xla2/custom_device.py b/experimental/torch_xla2/torch_xla2/custom_device.py index e0713be7720..3e5eab17c53 100644 --- a/experimental/torch_xla2/torch_xla2/custom_device.py +++ b/experimental/torch_xla2/torch_xla2/custom_device.py @@ -18,9 +18,4 @@ # verbose=True, # ) -# torch.register_privateuse1_backend('foo') torch.utils.rename_privateuse1_backend('jax') - - -# print(foo_module.Tensor) -print('Create a tensor with `jax` device:', torch.tensor([0], device='jax:0')) diff --git 
a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index 2250e324792..38e3edb23ba 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -1,5 +1,6 @@ """Torch ops implemented using jax.""" +import functools import sys from typing import Optional, Sequence @@ -9,6 +10,7 @@ import numpy as np import torch import torch.distributed._functional_collectives +from torch_xla2 import interop from torch_xla2.ops import ops_registry from torch_xla2.ops import op_base, mappings @@ -54,14 +56,15 @@ def inner(func): match type(a): case torch._ops.OpOverloadPacket: - opname = a._qualified_op_name + opname = a.default.name() if 'default' in a.overloads() else a._qualified_op_name case torch._ops.OpOverload: - # opname = a.name() - continue # prevent multiple funcs from being registered? + opname = a.name() case _: raise RuntimeError(f'oops {a}') - torch.library.impl(opname, 'privateuseone')(func) + torchfunc = functools.partial(interop.call_jax, func) + # HACK: to_copy is where we make the initial conversion from CPU tensor to JAX tensor + torch.library.impl(opname, 'privateuseone')(torchfunc if a != torch.ops.aten._to_copy else func) return func return inner @@ -108,14 +111,13 @@ def _aten_add(x, y, *, alpha=1): return x + y * alpha -@op(torch.ops.aten.copy_, torch.ops.aten.copy_.default, is_jax_function=False) +@op(torch.ops.aten.copy_, is_jax_function=False) def _aten_copy(x, y, memory_format=None): x._elem = y._elem return x @op(torch.ops.aten.clone) -@op(torch.ops.aten.clone.default) def _aten_clone(x, memory_format=None): return x @@ -469,7 +471,6 @@ def _full(size: Sequence[int], fill_value, *, dtype=None, **kwargs): @op(torch.ops.aten.empty_permuted) -@op(torch.ops.aten.empty_permuted.default) @op_base.convert_dtype() def _aten_empty_permuted(sizes, physical_layout, dtype=None, **kwargs): # Ignore the physical layout, @@ -478,7 +479,6 @@ def _aten_empty_permuted(sizes, physical_layout, dtype=None, **kwargs): @op(torch.ops.aten.empty_strided) -@op(torch.ops.aten.empty_strided.default) @op_base.convert_dtype() def _aten_empty_strided(sizes, stride, dtype=None, **kwargs): # Ignore stride, since JAX and torch tensor doesn't share the same memory. 
@@ -544,7 +544,6 @@ def permute(t, dims): @op(torch.ops.aten.unsqueeze) @op(torch.ops.aten.unsqueeze_copy) -@op(torch.ops.aten.unsqueeze.default) def _aten_unsqueeze(self, dim): if dim < 0: dim += self.ndim + 1 @@ -1784,7 +1783,6 @@ def _aten_ge(self, other): @op(torch.ops.aten.glu) -@op(torch.ops.aten.glu.default) def _aten_glu(x, dim=-1): return jax.nn.glu(x, dim) @@ -2206,12 +2204,6 @@ def _rand( return res -@op(torch.ops.aten.scalar_tensor.default) -def _aten_scalar_tensor(val, **kwargs): - p = torch.ops.aten.scalar_tensor(val) - return mappings.t2j(p) - - @op(torch.ops.aten.outer) def _aten_outer(a, b): return jnp.outer(a, b) diff --git a/experimental/torch_xla2/torch_xla2/tensor.py b/experimental/torch_xla2/torch_xla2/tensor.py index 3143cda8759..bceca208274 100644 --- a/experimental/torch_xla2/torch_xla2/tensor.py +++ b/experimental/torch_xla2/torch_xla2/tensor.py @@ -71,7 +71,7 @@ def __new__(cls, elem, env): cls, shape, dtype=dtype, - device='meta', + device='jax:0', requires_grad=False, ) @@ -121,7 +121,8 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None): env = arg._env break - with env: + with mode_utils.no_dispatch(), log_nested(env, f'DISPATCH: {_name_of_func(func)}'): # env._function_mode: + print(func) return func(*args, **(kwargs or {})) def detach(self):
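For reference, a minimal end-to-end sketch of the behavior this series targets,
mirroring the new test/test_jax_device.py (it assumes that importing
torch_xla2.custom_device is what renames the PrivateUse1 backend to `jax`, as done
above):

    import torch
    import torch_xla2.custom_device  # renames privateuse1 -> 'jax'

    # aten ops registered via torch.library.impl(..., 'privateuseone') in jaten.py
    # handle tensors placed on the custom device.
    a = torch.tensor([0], device="jax:0")
    b = torch.tensor([2], device="jax:0")
    assert (a + b).numpy() == [2]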