From 146787beab9b34350ce6fec327923bb6ca62e8d9 Mon Sep 17 00:00:00 2001
From: Will Cromar
Date: Wed, 17 Jul 2024 21:45:41 +0000
Subject: [PATCH] Rely on PyTorch dispatcher

---
 .../torch_xla2/test/test_jax_device.py        | 15 ++++++++++++
 .../torch_xla2/torch_xla2/custom_device.py    |  5 ----
 .../torch_xla2/torch_xla2/ops/jaten.py        | 24 +++++++------------
 experimental/torch_xla2/torch_xla2/tensor.py  |  5 ++--
 4 files changed, 26 insertions(+), 23 deletions(-)
 create mode 100644 experimental/torch_xla2/test/test_jax_device.py

diff --git a/experimental/torch_xla2/test/test_jax_device.py b/experimental/torch_xla2/test/test_jax_device.py
new file mode 100644
index 00000000000..8479d4aad19
--- /dev/null
+++ b/experimental/torch_xla2/test/test_jax_device.py
@@ -0,0 +1,15 @@
+import torch
+import torch_xla2.custom_device
+
+def test_tensor_creation():
+  t = torch.tensor([0], device="jax:0")
+
+  assert t.numpy() == [0]
+
+def test_basic_op():
+  a = torch.tensor([0], device="jax:0")
+  b = torch.tensor([2], device="jax:0")
+
+  c = a + b
+  assert c.numpy() == [2]
+
diff --git a/experimental/torch_xla2/torch_xla2/custom_device.py b/experimental/torch_xla2/torch_xla2/custom_device.py
index e0713be7720..3e5eab17c53 100644
--- a/experimental/torch_xla2/torch_xla2/custom_device.py
+++ b/experimental/torch_xla2/torch_xla2/custom_device.py
@@ -18,9 +18,4 @@
 #     verbose=True,
 # )
 
-# torch.register_privateuse1_backend('foo')
 torch.utils.rename_privateuse1_backend('jax')
-
-
-# print(foo_module.Tensor)
-print('Create a tensor with `jax` device:', torch.tensor([0], device='jax:0'))
diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py
index 2250e324792..38e3edb23ba 100644
--- a/experimental/torch_xla2/torch_xla2/ops/jaten.py
+++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py
@@ -1,5 +1,6 @@
 """Torch ops implemented using jax."""
 
+import functools
 import sys
 from typing import Optional, Sequence
 
@@ -9,6 +10,7 @@
 import numpy as np
 import torch
 import torch.distributed._functional_collectives
+from torch_xla2 import interop
 from torch_xla2.ops import ops_registry
 from torch_xla2.ops import op_base, mappings
 
@@ -54,14 +56,15 @@ def inner(func):
 
       match type(a):
         case torch._ops.OpOverloadPacket:
-          opname = a._qualified_op_name
+          opname = a.default.name() if 'default' in a.overloads() else a._qualified_op_name
         case torch._ops.OpOverload:
-          # opname = a.name()
-          continue  # prevent multiple funcs from being registered?
+          opname = a.name()
         case _:
           raise RuntimeError(f'oops {a}')
 
-      torch.library.impl(opname, 'privateuseone')(func)
+      torchfunc = functools.partial(interop.call_jax, func)
+      # HACK: to_copy is where we make the initial conversion from CPU tensor to JAX tensor
+      torch.library.impl(opname, 'privateuseone')(torchfunc if a != torch.ops.aten._to_copy else func)
     return func
 
   return inner
@@ -108,14 +111,13 @@ def _aten_add(x, y, *, alpha=1):
   return x + y * alpha
 
 
-@op(torch.ops.aten.copy_, torch.ops.aten.copy_.default, is_jax_function=False)
+@op(torch.ops.aten.copy_, is_jax_function=False)
 def _aten_copy(x, y, memory_format=None):
   x._elem = y._elem
   return x
 
 
 @op(torch.ops.aten.clone)
-@op(torch.ops.aten.clone.default)
 def _aten_clone(x, memory_format=None):
   return x
 
@@ -469,7 +471,6 @@ def _full(size: Sequence[int], fill_value, *, dtype=None, **kwargs):
 
 
 @op(torch.ops.aten.empty_permuted)
-@op(torch.ops.aten.empty_permuted.default)
 @op_base.convert_dtype()
 def _aten_empty_permuted(sizes, physical_layout, dtype=None, **kwargs):
   # Ignore the physical layout,
@@ -478,7 +479,6 @@ def _aten_empty_permuted(sizes, physical_layout, dtype=None, **kwargs):
 
 
 @op(torch.ops.aten.empty_strided)
-@op(torch.ops.aten.empty_strided.default)
 @op_base.convert_dtype()
 def _aten_empty_strided(sizes, stride, dtype=None, **kwargs):
   # Ignore stride, since JAX and torch tensor doesn't share the same memory.
@@ -544,7 +544,6 @@ def permute(t, dims):
 
 @op(torch.ops.aten.unsqueeze)
 @op(torch.ops.aten.unsqueeze_copy)
-@op(torch.ops.aten.unsqueeze.default)
 def _aten_unsqueeze(self, dim):
   if dim < 0:
     dim += self.ndim + 1
@@ -1784,7 +1783,6 @@ def _aten_ge(self, other):
 
 
 @op(torch.ops.aten.glu)
-@op(torch.ops.aten.glu.default)
 def _aten_glu(x, dim=-1):
   return jax.nn.glu(x, dim)
 
@@ -2206,12 +2204,6 @@ def _rand(
   return res
 
 
-@op(torch.ops.aten.scalar_tensor.default)
-def _aten_scalar_tensor(val, **kwargs):
-  p = torch.ops.aten.scalar_tensor(val)
-  return mappings.t2j(p)
-
-
 @op(torch.ops.aten.outer)
 def _aten_outer(a, b):
   return jnp.outer(a, b)
diff --git a/experimental/torch_xla2/torch_xla2/tensor.py b/experimental/torch_xla2/torch_xla2/tensor.py
index 3143cda8759..bceca208274 100644
--- a/experimental/torch_xla2/torch_xla2/tensor.py
+++ b/experimental/torch_xla2/torch_xla2/tensor.py
@@ -71,7 +71,7 @@ def __new__(cls, elem, env):
         cls,
         shape,
         dtype=dtype,
-        device='meta',
+        device='jax:0',
         requires_grad=False,
     )
 
@@ -121,7 +121,8 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
         env = arg._env
         break
 
-    with env:
+    with mode_utils.no_dispatch(), log_nested(env, f'DISPATCH: {_name_of_func(func)}'): # env._function_mode:
+      print(func)
       return func(*args, **(kwargs or {}))
 
   def detach(self):
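A minimal usage sketch, adapted from the new test_jax_device.py above and assuming a torch_xla2 build that includes this patch: importing torch_xla2.custom_device renames the privateuse1 backend to `jax`, and the aten ops registered via torch.library.impl (aten._to_copy directly for the initial CPU-to-JAX conversion, interop.call_jax wrappers for the rest) let the PyTorch dispatcher route eager ops on `jax:0` tensors to their JAX implementations.

    import torch
    import torch_xla2.custom_device  # renames the privateuse1 backend to `jax`

    # Tensor creation routes through aten._to_copy, which this patch registers
    # directly; other aten ops go through the interop.call_jax wrappers.
    a = torch.tensor([0], device="jax:0")
    b = torch.tensor([2], device="jax:0")

    c = a + b         # dispatched to the JAX add implementation
    print(c.numpy())  # expected: [2], as asserted in test_basic_op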