Refactor a bit so that the ops/ is a standalone module (#7250)
qihqi authored Jun 12, 2024
1 parent 3f54fa2 commit 8cd6ad4
Showing 9 changed files with 133 additions and 141 deletions.
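
The diff below moves the dtype and tensor conversion helpers out of torch_xla2.tensor and into the standalone torch_xla2.ops package (including the new ops/mappings.py), so callers now import them from the ops module. A minimal sketch of the import change, with paths taken from the diff:

```python
# Before this commit, dtype helpers were reached through torch_xla2.tensor:
#   from torch_xla2 import tensor
#   jax_dtype = tensor.t2j_dtype(torch.float32)
# After this commit they live in the standalone ops package:
import torch
from torch_xla2.ops import mappings

jax_dtype = mappings.t2j_dtype(torch.float32)  # -> jnp.float32.dtype
```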
4 changes: 3 additions & 1 deletion experimental/torch_xla2/test/test_exports.py
@@ -2,8 +2,10 @@
import torch
import torch.nn.functional as F
import jax
import jax.experimental.export
import torch_xla2
from torch_xla2 import tensor
from torch_xla2.ops import mappings


class Interpolate(torch.nn.Module):
@@ -125,7 +127,7 @@ def test_export_dtypes(self):
}

model = TensorConstant()
for torch_dtype in torch_xla2.tensor.TORCH_DTYPE_TO_JAX.keys():
for torch_dtype in mappings.TORCH_DTYPE_TO_JAX.keys():
if torch_dtype == None:
## TODO: Figure out what the None mapping should be, seems like:
## torch.tensor(dtype=None) maps to f32
1 change: 0 additions & 1 deletion experimental/torch_xla2/test/test_symbolic_shapes.py
@@ -1,6 +1,5 @@
import unittest
import torch
import jax
import torch_xla2

class AddOne(torch.nn.Module):
15 changes: 9 additions & 6 deletions experimental/torch_xla2/torch_xla2/__init__.py
@@ -1,14 +1,18 @@
import jax
from jax._src import config
import os
import torch
from torch._functorch import make_functional
from torch.utils import _pytree as pytree
from torch_xla2 import export, tensor, tf_integration
from torch_xla2 import tensor

jax.config.update('jax_enable_x64', True)

config.update(
__all__ = [
'default_env',
'extract_jax',
]


jax.config.update('jax_enable_x64', True)
jax.config.update(
'jax_pjrt_client_create_options',
f'ml_framework_name:PyTorch/XLA2;ml_framework_version:{"v0.0.1"}'
)
@@ -29,7 +33,6 @@ def extract_jax(mod: torch.nn.Module, env=None):
"""Returns a pytree of jax.ndarray and a jax callable."""
if env is None:
env = default_env()
func, weights, buffer = make_functional.make_functional_with_buffers(mod)
states = mod.state_dict()

states = pytree.tree_map_only(torch.Tensor, tensor.t2j, states)
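
The __init__.py hunk above trims the public surface to default_env and extract_jax and drops the make_functional dependency in favor of mod.state_dict(). A hedged usage sketch (the Linear model and the calling convention of the returned JAX function are illustrative assumptions, not part of the commit):

```python
import jax
import jax.numpy as jnp
import torch
import torch_xla2

model = torch.nn.Linear(4, 2)

# Per the docstring above: returns a pytree of jax arrays plus a jax callable.
weights, jax_func = torch_xla2.extract_jax(model)

x = jnp.ones((1, 4), dtype=jnp.float32)
out = jax.jit(jax_func)(weights, (x,))  # calling convention assumed: (weights, args_tuple)
```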
6 changes: 3 additions & 3 deletions experimental/torch_xla2/torch_xla2/ops/__init__.py
@@ -1,9 +1,9 @@
def all_aten_jax_ops():
# to load the ops
import torch_xla2.jaten # type: ignore
import torch_xla2.ops_registry # type: ignore
import torch_xla2.ops.jaten # type: ignore
import torch_xla2.ops.ops_registry # type: ignore
return {
key: val.func
for key, val in torch_xla2.ops_registry.all_aten_ops
for key, val in torch_xla2.ops.ops_registry.all_aten_ops
if val.is_jax_function
}
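
With the registry relocated, all_aten_jax_ops now resolves it through torch_xla2.ops.ops_registry. A small hedged sketch of how this helper might be used to inspect which ATen ops have pure-JAX lowerings (assumes importing the package populates the registry, as the comment above indicates):

```python
from torch_xla2.ops import all_aten_jax_ops

jax_ops = all_aten_jax_ops()  # mapping from ATen OpOverload to its JAX-implemented callable
print(f"{len(jax_ops)} ATen ops have pure-JAX implementations")
```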
22 changes: 9 additions & 13 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
@@ -7,8 +7,7 @@
import numpy as np
import torch
from torch_xla2.ops import ops_registry
from torch_xla2 import tensor
from torch_xla2.ops import op_base
from torch_xla2.ops import op_base, mappings

# Keys are OpOverload, value is a callable that takes
# XLATensor2
@@ -75,10 +74,7 @@ def _aten_add(x, y, *, alpha=1):

@op(torch.ops.aten.copy_, torch.ops.aten.copy_.default, is_jax_function=False)
def _aten_copy(x, y, memory_format=None):
if isinstance(x, tensor.XLATensor2):
x._elem = y._elem
elif isinstance(x, tensor.SliceView):
x.mutate(y)
x._elem = y._elem
return x


@@ -293,7 +289,7 @@ def _aten_dot(x, y):

@op(torch.ops.aten._to_copy)
def _aten__to_copy(self, **kwargs):
dtype = tensor.t2j_dtype(kwargs["dtype"])
dtype = mappings.t2j_dtype(kwargs["dtype"])
if dtype != self.dtype:
return self.astype(dtype)
return jnp.copy(self)
@@ -379,7 +375,7 @@ def _aten_ne(x, y):
@op(torch.ops.aten.cumsum)
def _aten_cumsum(x, y, dtype=None):
if dtype:
dtype = tensor.t2j_dtype(dtype)
dtype = mappings.t2j_dtype(dtype)
res = jnp.cumsum(x, y, dtype)
return res

@@ -1325,7 +1321,7 @@ def _aten_arange(
pin_memory=False,
):
if dtype:
dtype = tensor.t2j_dtype(dtype)
dtype = mappings.t2j_dtype(dtype)
return jnp.arange(
start,
end,
@@ -1477,7 +1473,7 @@ def _aten_fill(x, value, dtype=None, pin_memory=None, memory_format=None):
if dtype is None:
dtype = x.dtype
else:
dtype = tensor.t2j_dtype(dtype)
dtype = mappings.t2j_dtype(dtype)
return jnp.full(x.shape, value, dtype)


@@ -1772,7 +1768,7 @@ def _aten_to_dtype(
a, dtype, non_blocking=False, copy=False, memory_format=None
):
if dtype:
jaxdtype = tensor.t2j_dtype(dtype)
jaxdtype = mappings.t2j_dtype(dtype)
return a.astype(jaxdtype)


@@ -1793,7 +1789,7 @@ def _aten_scalar_tensor(
s, dtype=None, layout=None, device=None, pin_memory=None
):
if dtype is not None:
dtype = tensor.t2j_dtype(dtype)
dtype = mappings.t2j_dtype(dtype)
return jnp.array(s, dtype=dtype)
return jnp.array(s)

@@ -1908,7 +1904,7 @@ def _rand(
@op(torch.ops.aten.scalar_tensor.default)
def _aten_scalar_tensor(val, **kwargs):
p = torch.ops.aten.scalar_tensor(val)
return tensor.t2j(p)
return mappings.t2j(p)


@op(torch.ops.aten.to.device)
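
The jaten.py changes are mechanical: the conversion helpers are now reached through mappings (mappings.t2j_dtype, mappings.t2j), and _aten_copy drops its isinstance checks on the tensor wrapper types. A hedged sketch of the resulting pattern in a lowering, convert the torch dtype at the boundary, then defer to jax.numpy (the function below is illustrative, not one added by this commit):

```python
import jax.numpy as jnp
import torch
from torch_xla2.ops import mappings

def _zeros_like_lowering(x, *, dtype=None, **kwargs):
  # Same shape as the lowerings above: translate the torch dtype (if any)
  # into its jnp equivalent, then call into jax.numpy.
  if dtype is not None:
    dtype = mappings.t2j_dtype(dtype)
  return jnp.zeros_like(x, dtype=dtype)
```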
14 changes: 5 additions & 9 deletions experimental/torch_xla2/torch_xla2/ops/jtorch.py
@@ -1,20 +1,16 @@
"""Tensor constructor overrides"""
import functools
from typing import Optional, Sequence
import numpy as np

import jax
import torch
import jax.numpy as jnp
import numpy as np
from torch_xla2 import tensor
from torch_xla2.ops.ops_registry import register_torch_function_op
from torch_xla2.ops import op_base
from torch_xla2 import interop

from jax.experimental.pallas.ops.tpu import flash_attention
from jax.experimental.shard_map import shard_map


import torch
from torch_xla2.ops.ops_registry import register_torch_function_op
from torch_xla2.ops import op_base, mappings


def register_function(torch_func, **kwargs):
@@ -36,7 +32,7 @@ def _tensor(data, *, dtype=None, **kwargs):
dtype = python_types_to_torch_types.get(type(leaves[0]))

return jnp.array(
data, dtype=dtype or tensor.t2j_dtype(torch.get_default_dtype()))
data, dtype=dtype or mappings.t2j_dtype(torch.get_default_dtype()))


@register_function(torch.ones)
94 changes: 94 additions & 0 deletions experimental/torch_xla2/torch_xla2/ops/mappings.py
@@ -0,0 +1,94 @@
from jax import dlpack as jaxdl
import jax.numpy as jnp
import numpy
import torch
import torch.func
import torch.utils.dlpack as torchdl


def t2j(t):
is_bool = False
if t.dtype == torch.bool:
is_bool = True
t = t.to(torch.int8)

if not t.is_contiguous():
t = t.contiguous()

try:
dl = torchdl.to_dlpack(t)
res = jaxdl.from_dlpack(dl)
except Exception:
# https://github.com/google/jax/issues/7657
# https://github.com/google/jax/issues/17784
if t.dtype == torch.bfloat16:
nparray = (t.cpu().detach().to(torch.float32).numpy()
) # numpy don't support bfloat16
else:
nparray = t.cpu().detach().numpy()
res = jnp.asarray(nparray)
if t.dtype == torch.bfloat16:
res = res.astype(jnp.bfloat16)

if is_bool:
res = res.astype(jnp.bool_)
return res


def j2t(x):
try:
dl = jaxdl.to_dlpack(x)
res = torchdl.from_dlpack(dl)
except Exception:
res = torch.from_numpy(numpy.asarray(x))
if x.dtype == jnp.bool_:
res = res.to(torch.bool)
return res

TORCH_DTYPE_TO_JAX = {
# NO_MAPPING : jnp.float0.dtype (signless scalar int),
torch.bool : jnp.bool_.dtype,
# NO_MAPPING : jnp.int4.dtype,
torch.int8 : jnp.int8.dtype,
torch.int16 : jnp.int16.dtype,
torch.int32 : jnp.int32.dtype,
torch.int64 : jnp.int64.dtype,
torch.long : jnp.int64.dtype,
# NO_MAPPING : jnp.uint4
torch.uint8 : jnp.uint8.dtype,
torch.uint16 : jnp.uint16.dtype,
torch.uint32 : jnp.uint32.dtype,
torch.uint64 : jnp.uint64.dtype,
# NO_MAPPING : jnp.float8_e4m3b11fnuz.dtype,
torch.float8_e4m3fn : jnp.float8_e4m3fn.dtype,
# NO_MAPPING : jnp.float8_e4m3fnuz.dtype,
torch.float8_e5m2 : jnp.float8_e5m2.dtype,
# NO_MAPPING : jnp.float8_e5m2fnuz.dtype,
torch.bfloat16 : jnp.bfloat16.dtype,
torch.half : jnp.float16.dtype,
torch.float16 : jnp.float16.dtype,
torch.float32 : jnp.float32.dtype,
torch.float64 : jnp.float64.dtype,
torch.double : jnp.double.dtype,
torch.complex64 : jnp.complex64.dtype,
torch.complex128 : jnp.complex128.dtype,
None : None,
}

JAX_DTYPE_TO_TORCH = {
value: key for key, value in TORCH_DTYPE_TO_JAX.items()
}
# Add imprecise mappings for some JAX dtypes which don't have torch analogues
JAX_DTYPE_TO_TORCH[jnp.dtype('int4')] = torch.int8
JAX_DTYPE_TO_TORCH[jnp.dtype('uint4')] = torch.uint8

def t2j_dtype(dtype):
if dtype not in TORCH_DTYPE_TO_JAX:
raise RuntimeError(f'Attempting to convert unknown type: {dtype} to torch type,')
return TORCH_DTYPE_TO_JAX[dtype]


def j2t_dtype(dtype):
if dtype not in JAX_DTYPE_TO_TORCH:
raise RuntimeError(f'Attempting to convert unknown type: {dtype} to torch type,')
return JAX_DTYPE_TO_TORCH[dtype]
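
The new mappings module gathers the torch/JAX conversion helpers in one place: t2j and j2t move data via DLPack with a NumPy fallback, and the two dtype tables back t2j_dtype and j2t_dtype. A hedged usage sketch (the values in the comments follow from the tables above):

```python
import torch
from torch_xla2.ops import mappings

t = torch.arange(6, dtype=torch.bfloat16).reshape(2, 3)
j = mappings.t2j(t)        # JAX array; bfloat16 may take the NumPy fallback path above
t_back = mappings.j2t(j)   # back to a torch.Tensor

# Dtype round-trips through the two tables:
assert mappings.t2j_dtype(torch.bool) == mappings.TORCH_DTYPE_TO_JAX[torch.bool]
assert mappings.j2t_dtype(mappings.t2j_dtype(torch.bool)) is torch.bool
```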
34 changes: 3 additions & 31 deletions experimental/torch_xla2/torch_xla2/ops/op_base.py
@@ -1,37 +1,9 @@
import functools
import torch
from torch_xla2 import interop, tensor
from torch_xla2.ops import mappings
from torch_xla2 import types

from typing import Callable, Optional, ParamSpec, Sequence


class BinaryOpWithPromotion:

def __init__(self, inner):
self.inner = inner

def _get_dtype(self, obj):
if isinstance(obj, torch.Tensor):
return obj.dtype
else:
if isinstance(obj, float):
return torch.float32
if isinstance(obj, int):
return torch.int32
return torch.float32


def __call__(self, *args, **kwargs):
# args are torch.Tensor
res = interop.torch_view(self.jax)(*args, **kwargs)

dtype = torch.promote_types(
self._get_dtype(args[0]),
self._get_dtype(args[1]))
if dtype != res.dtype:
res = res.to(dtype)
return res
from typing import Optional, ParamSpec


class InplaceOp:
@@ -75,7 +47,7 @@ def wrapper(*args: P.args,
**kwargs: P.kwargs):
if not dtype and use_default_dtype:
dtype = torch.get_default_dtype()
jax_dtype = tensor.t2j_dtype(dtype)
jax_dtype = mappings.t2j_dtype(dtype)

return func(*args, dtype=jax_dtype, **kwargs)

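
op_base.py loses BinaryOpWithPromotion and now reaches t2j_dtype through mappings. The surviving wrapper in the hunk above fills in torch.get_default_dtype() when no dtype is supplied, converts it to a jnp dtype, and forwards it to the wrapped constructor. A hedged reconstruction of that pattern (the decorator's real name is not visible in this excerpt; convert_dtype below is a hypothetical stand-in):

```python
import functools
import jax.numpy as jnp
import torch
from torch_xla2.ops import mappings

def convert_dtype(use_default_dtype: bool = True):
  """Hypothetical stand-in for the dtype-defaulting decorator in op_base.py."""
  def decorator(func):
    @functools.wraps(func)
    def wrapper(*args, dtype=None, **kwargs):
      if not dtype and use_default_dtype:
        dtype = torch.get_default_dtype()
      jax_dtype = mappings.t2j_dtype(dtype)  # torch dtype -> jnp dtype (None passes through)
      return func(*args, dtype=jax_dtype, **kwargs)
    return wrapper
  return decorator

@convert_dtype()
def _ones(size, *, dtype=None):
  return jnp.ones(size, dtype=dtype)
```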