Rely on PyTorch dispatcher
will-cromar committed Jul 17, 2024
1 parent 160a4d2 commit 146787b
Showing 4 changed files with 26 additions and 23 deletions.
15 changes: 15 additions & 0 deletions in experimental/torch_xla2/test/test_jax_device.py
@@ -0,0 +1,15 @@
+import torch
+import torch_xla2.custom_device
+
+def test_tensor_creation():
+  t = torch.tensor([0], device="jax:0")
+
+  assert t.numpy() == [0]
+
+def test_basic_op():
+  a = torch.tensor([0], device="jax:0")
+  b = torch.tensor([2], device="jax:0")
+
+  c = a + b
+  assert c.numpy() == [2]
+
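These tests exercise the new path end to end: importing torch_xla2.custom_device renames the PrivateUse1 backend to `jax`, torch.tensor(..., device="jax:0") goes through the _to_copy conversion noted in jaten.py, and the addition dispatches to the JAX-backed kernel. One hedged aside: comparing a NumPy result against a list with `==` is only unambiguous for single-element arrays, so a multi-element variant of the same check could look like the sketch below (test name and values are illustrative, not part of this commit).

import numpy.testing
import torch
import torch_xla2.custom_device  # registers the "jax" device name, as in the tests above


def test_basic_op_multi_element():
  a = torch.tensor([1, 2], device="jax:0")
  b = torch.tensor([3, 4], device="jax:0")
  # assert_array_equal sidesteps the ambiguous truth value of a multi-element comparison.
  numpy.testing.assert_array_equal((a + b).numpy(), [4, 6])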
5 changes: 0 additions & 5 deletions in experimental/torch_xla2/torch_xla2/custom_device.py
@@ -18,9 +18,4 @@
 # verbose=True,
 # )
 
-# torch.register_privateuse1_backend('foo')
 torch.utils.rename_privateuse1_backend('jax')
-
-
-# print(foo_module.Tensor)
-print('Create a tensor with `jax` device:', torch.tensor([0], device='jax:0'))
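After this cleanup, the load-bearing line left in custom_device.py is the PrivateUse1 rename, which is what makes the `jax:0` device string resolve at all. A minimal sketch of that pattern follows; the stub kernel is illustrative only (the real registrations happen in jaten.py), and it assumes no other PrivateUse1 kernel for that op is registered in the process.

import torch

# Expose PyTorch's PrivateUse1 dispatch key under the name "jax", so that
# torch.tensor(..., device="jax:0") resolves to this backend.
torch.utils.rename_privateuse1_backend('jax')


# Kernels are still registered against the underlying key, exactly as jaten.py does
# with torch.library.impl(opname, 'privateuseone')(...).  Illustrative stub:
@torch.library.impl("aten::floor_divide", "privateuseone")
def _floor_divide_stub(x, y):
  raise NotImplementedError("stub for illustration; torch_xla2 routes these calls into JAX")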
24 changes: 8 additions & 16 deletions in experimental/torch_xla2/torch_xla2/ops/jaten.py
@@ -1,5 +1,6 @@
 """Torch ops implemented using jax."""
 
+import functools
 import sys
 from typing import Optional, Sequence
 
@@ -9,6 +10,7 @@
 import numpy as np
 import torch
 import torch.distributed._functional_collectives
+from torch_xla2 import interop
 from torch_xla2.ops import ops_registry
 from torch_xla2.ops import op_base, mappings
 
@@ -54,14 +56,15 @@ def inner(func):
 
       match type(a):
         case torch._ops.OpOverloadPacket:
-          opname = a._qualified_op_name
+          opname = a.default.name() if 'default' in a.overloads() else a._qualified_op_name
         case torch._ops.OpOverload:
-          # opname = a.name()
-          continue # prevent multiple funcs from being registered?
+          opname = a.name()
         case _:
           raise RuntimeError(f'oops {a}')
 
-      torch.library.impl(opname, 'privateuseone')(func)
+      torchfunc = functools.partial(interop.call_jax, func)
+      # HACK: to_copy is where we make the initial conversion from CPU tensor to JAX tensor
+      torch.library.impl(opname, 'privateuseone')(torchfunc if a != torch.ops.aten._to_copy else func)
     return func
 
   return inner
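The substantive change in this hunk: each JAX-backed implementation is wrapped with interop.call_jax and registered with torch.library.impl under the PrivateUse1 key, so the stock PyTorch dispatcher routes `jax`-device tensors to it; only aten._to_copy is registered unwrapped, since that is where the initial CPU-to-JAX conversion happens. The sketch below shows roughly what such a wrapper has to do; it is a stand-in built on stated assumptions, not the actual interop.call_jax.

import jax
import jax.numpy as jnp
import numpy as np
import torch


def call_jax_sketch(jax_func, *args, **kwargs):
  """Illustrative stand-in for interop.call_jax; the real implementation differs."""

  def to_jax(v):
    # Assumption: convert via NumPy here; torch_xla2's own tensors already hold a JAX array in _elem.
    return jnp.asarray(v.detach().cpu().numpy()) if isinstance(v, torch.Tensor) else v

  def to_torch(v):
    return torch.from_numpy(np.asarray(v)) if isinstance(v, jax.Array) else v

  out = jax_func(*[to_jax(a) for a in args],
                 **{k: to_jax(v) for k, v in kwargs.items()})
  if isinstance(out, (tuple, list)):
    return type(out)(to_torch(o) for o in out)
  return to_torch(out)


# The decorator above then does, per op:
#   torchfunc = functools.partial(interop.call_jax, func)
#   torch.library.impl(opname, 'privateuseone')(torchfunc)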
@@ -108,14 +111,13 @@ def _aten_add(x, y, *, alpha=1):
   return x + y * alpha
 
 
-@op(torch.ops.aten.copy_, torch.ops.aten.copy_.default, is_jax_function=False)
+@op(torch.ops.aten.copy_, is_jax_function=False)
 def _aten_copy(x, y, memory_format=None):
   x._elem = y._elem
   return x
 
 
 @op(torch.ops.aten.clone)
-@op(torch.ops.aten.clone.default)
 def _aten_clone(x, memory_format=None):
   return x
 
Expand Down Expand Up @@ -469,7 +471,6 @@ def _full(size: Sequence[int], fill_value, *, dtype=None, **kwargs):


@op(torch.ops.aten.empty_permuted)
@op(torch.ops.aten.empty_permuted.default)
@op_base.convert_dtype()
def _aten_empty_permuted(sizes, physical_layout, dtype=None, **kwargs):
# Ignore the physical layout,
@@ -478,7 +479,6 @@ def _aten_empty_permuted(sizes, physical_layout, dtype=None, **kwargs):
 
 
 @op(torch.ops.aten.empty_strided)
-@op(torch.ops.aten.empty_strided.default)
 @op_base.convert_dtype()
 def _aten_empty_strided(sizes, stride, dtype=None, **kwargs):
   # Ignore stride, since JAX and torch tensor doesn't share the same memory.
@@ -544,7 +544,6 @@ def permute(t, dims):
 
 @op(torch.ops.aten.unsqueeze)
 @op(torch.ops.aten.unsqueeze_copy)
-@op(torch.ops.aten.unsqueeze.default)
 def _aten_unsqueeze(self, dim):
   if dim < 0:
     dim += self.ndim + 1
@@ -1784,7 +1783,6 @@ def _aten_ge(self, other):
 
 
 @op(torch.ops.aten.glu)
-@op(torch.ops.aten.glu.default)
 def _aten_glu(x, dim=-1):
   return jax.nn.glu(x, dim)
 
@@ -2206,12 +2204,6 @@ def _rand(
   return res
 
 
-@op(torch.ops.aten.scalar_tensor.default)
-def _aten_scalar_tensor(val, **kwargs):
-  p = torch.ops.aten.scalar_tensor(val)
-  return mappings.t2j(p)
-
-
 @op(torch.ops.aten.outer)
 def _aten_outer(a, b):
   return jnp.outer(a, b)
5 changes: 3 additions & 2 deletions in experimental/torch_xla2/torch_xla2/tensor.py
@@ -71,7 +71,7 @@ def __new__(cls, elem, env):
         cls,
         shape,
         dtype=dtype,
-        device='meta',
+        device='jax:0',
         requires_grad=False,
     )
 
@@ -121,7 +121,8 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
         env = arg._env
         break
 
-    with env:
+    with mode_utils.no_dispatch(), log_nested(env, f'DISPATCH: {_name_of_func(func)}'): # env._function_mode:
+      print(func)
       return func(*args, **(kwargs or {}))
 
   def detach(self):
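For context, the __torch_dispatch__ hook modified here is PyTorch's standard wrapper-subclass mechanism: every aten call on the wrapper tensor funnels through this one classmethod, which is what lets torch_xla2 execute it with JAX instead. A generic, simplified sketch of the pattern (not torch_xla2's actual Tensor class, which stores a JAX array in _elem and an environment in _env):

import torch
from torch.utils._pytree import tree_map


class LoggingTensor(torch.Tensor):

  @staticmethod
  def __new__(cls, elem):
    # The wrapper carries only metadata; the data stays on the wrapped tensor.
    return torch.Tensor._make_wrapper_subclass(
        cls, elem.shape, dtype=elem.dtype, device=elem.device, requires_grad=False)

  def __init__(self, elem):
    self._elem = elem

  @classmethod
  def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
    print('DISPATCH:', func)  # same idea as the print(func) added in this diff
    unwrap = lambda x: x._elem if isinstance(x, LoggingTensor) else x
    wrap = lambda x: LoggingTensor(x) if isinstance(x, torch.Tensor) else x
    out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))
    return tree_map(wrap, out)


a = LoggingTensor(torch.ones(2))
b = LoggingTensor(torch.ones(2))
print((a + b)._elem)  # DISPATCH: aten.add.Tensor, then tensor([2., 2.])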
