Make sure torch_xla2 works with torch stable 2.2
qihqi committed Feb 22, 2024
1 parent 5145ea1 commit 0a24487
Showing 8 changed files with 111 additions and 32 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/torch_xla2.yml
@@ -35,7 +35,7 @@ jobs:
working-directory: experimental/torch_xla2
run: |
pip install pytest absl-py jax[cpu] flatbuffers tensorflow
- pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu
+ pip install torch --index-url https://download.pytorch.org/whl/cpu
pip install -e .
- name: Run tests
working-directory: experimental/torch_xla2
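A quick sanity check (not part of the workflow; just a hedged way to confirm the stable CPU wheel, rather than a nightly, was installed):

import torch

# Stable CPU wheels report versions like "2.2.x+cpu"; nightlies carry a ".dev" suffix.
assert torch.__version__.startswith("2.2"), torch.__version__
assert ".dev" not in torch.__version__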
8 changes: 6 additions & 2 deletions experimental/torch_xla2/pyproject.toml
@@ -7,9 +7,13 @@ build-backend = "hatchling.build"
version = "0.0.1"
name = "torch_xla2"
dependencies = [
"torch>=2.1",
"absl-py",
"flatbuffers",
"jax>=0.4.24",
"jaxlib",
"jaxlib>=0.4.24",
"pytest",
"tensorflow",
"torch",
]

requires-python = ">=3.10"
29 changes: 7 additions & 22 deletions experimental/torch_xla2/requirements.txt
@@ -1,22 +1,7 @@
-certifi==2022.12.7
-charset-normalizer==2.1.1
-filelock==3.9.0
-fsspec==2023.4.0
-idna==3.4
-jax==0.4.24
-jaxlib==0.4.24
-Jinja2==3.1.2
-MarkupSafe==2.1.3
-ml-dtypes==0.3.2
-mpmath==1.2.1
-networkx==3.0rc1
-numpy==1.24.1
-opt-einsum==3.3.0
-Pillow==9.3.0
-requests==2.28.1
-scipy==1.12.0
-sympy==1.11.1
---pre
-torch
-typing_extensions==4.8.0
-urllib3==1.26.13
+absl-py==2.0.0
+flatbuffers==23.5.26
+jax==0.4.23
+jaxlib==0.4.23
+pytest
+tensorflow
+torch==2.2.1+cpu
6 changes: 2 additions & 4 deletions experimental/torch_xla2/test/test_extra.py
@@ -15,16 +15,14 @@ def test_fori_loop(self):
    a = tensor.move_to_device(torch.ones((10, 10)))

    def body(i, c):
-     return c + a[i]
+     return c + a[i]

    init_val = tensor.move_to_device(torch.zeros(10))
    res = extra.fori_loop(0, 10, body, init_val)

    expect = torch.ones(10) * 10

    self.assertTrue(torch.allclose(tensor.j2t(res._elem), expect))


if __name__ == '__main__':
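For reference, the behavior the test expects from extra.fori_loop matches this plain-PyTorch loop (a reference sketch only, not the torch_xla2 API itself):

import torch

a = torch.ones((10, 10))
c = torch.zeros(10)              # init_val
for i in range(10):              # fori_loop(0, 10, body, init_val)
    c = c + a[i]                 # body(i, c)
assert torch.allclose(c, torch.ones(10) * 10)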
1 change: 1 addition & 0 deletions experimental/torch_xla2/torch_xla2/__init__.py
@@ -6,6 +6,7 @@
from torch_xla2 import export, ops, ops_registry, tensor, tf_integration



def extract_jax(mod: torch.nn.Module):
  """Returns a pytree of jax.ndarray and a jax callable."""
  func, weights, buffer = make_functional.make_functional_with_buffers(mod)
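Based only on the docstring above, extract_jax converts an nn.Module into a pytree of jax arrays plus a jax callable. A hedged usage sketch; the calling convention of the returned callable is an assumption, not confirmed by this diff:

import torch
import torch_xla2

model = torch.nn.Linear(4, 2)
states, jax_func = torch_xla2.extract_jax(model)  # pytree of jax.ndarray + a jax callable
# Assumed convention: jax_func takes the state pytree plus the model inputs.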
86 changes: 86 additions & 0 deletions experimental/torch_xla2/torch_xla2/decompositions.py
@@ -0,0 +1,86 @@
"""This file contains some decompositons that are not available in torch stable.
Most likely from Content of
https://github.com/pytorch/pytorch/blob/main/torch/_decomp/decompositions.py
at main branch HEAD that we find useful here.
Can also contain decompositions of a torch op in terms of other torch ops.
"""

from typing import Any, Callable, List, Tuple

import torch
from torch import Tensor
import torch._decomp as decomp
from torch._decomp import register_decomposition
import torch._prims_common as utils
from torch._prims_common.wrappers import out_wrapper


DispatchKey = torch._C.DispatchKey  # type: ignore[attr-defined]

# None of these functions are publicly accessible; get at them
# from torch._decomps
__all__: List[str] = []

aten = torch._ops.ops.aten


@register_decomposition(aten.reflection_pad1d)
@register_decomposition(aten.reflection_pad2d)
@register_decomposition(aten.reflection_pad3d)
@out_wrapper()
def _reflection_pad(a: Tensor, padding: Tuple[int, ...]) -> Tensor:
    def idx(left, middle, right):
        dim_idx = torch.arange(-left, middle + right, device=a.device)
        return middle - 1 - (middle - 1 - dim_idx.abs()).abs()

    return _reflection_or_replication_pad(
        a,
        padding,
        idx,
    )


@register_decomposition(aten.replication_pad1d)
@register_decomposition(aten.replication_pad3d)
@out_wrapper()
def _replication_pad(a: Tensor, padding: Tuple[int, ...]) -> Tensor:
    def idx(left, middle, right):
        dim_idx = torch.arange(-left, middle + right, device=a.device)
        return torch.clamp(dim_idx, 0, middle - 1)

    return _reflection_or_replication_pad(
        a,
        padding,
        idx,
    )


decomp.global_decomposition_table['post_autograd'][aten.replication_pad2d.default] = _replication_pad


def _reflection_or_replication_pad(
    a: Tensor,
    padding: Tuple[int, ...],
    idx_fn: Callable[[int, int, int], Tensor],
) -> Tensor:
    dim = len(padding) // 2
    torch._check(
        a.dim() in (dim + 1, dim + 2),
        lambda: f"reflection_pad{dim}d requires {dim + 1}D or {dim + 2}D input",
    )
    inp_shape = a.shape[-dim:]
    nc_dim = a.dim() - dim

    padding_left = [padding[2 * (dim - 1 - i)] for i in range(dim)]
    padding_right = [padding[2 * (dim - 1 - i) + 1] for i in range(dim)]

    result = a
    for i in range(dim):
        idx: List[Any] = [None] * result.dim()
        idx[i + nc_dim] = idx_fn(padding_left[i], inp_shape[i], padding_right[i])
        result = aten._unsafe_index(result, idx)

    # convert output to correct memory format, if necessary
    memory_format = utils.suggest_memory_format(result)
    result = result.contiguous(memory_format=memory_format)
    return result
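A small numeric illustration of the reflection index formula used in _reflection_pad above (plain torch; no decomposition registration involved):

import torch

left, middle, right = 2, 5, 2                       # pad 2 on each side of a length-5 dim
dim_idx = torch.arange(-left, middle + right)
idx = middle - 1 - (middle - 1 - dim_idx.abs()).abs()
# idx -> tensor([2, 1, 0, 1, 2, 3, 4, 3, 2]): the reflected gather indices.
a = torch.arange(5.0)
ref = torch.nn.functional.pad(a[None, None], (2, 2), mode='reflect')[0, 0]
assert torch.equal(a[idx], ref)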
3 changes: 1 addition & 2 deletions experimental/torch_xla2/torch_xla2/export.py
@@ -210,9 +210,8 @@ def exported_program_to_jax(exported_program, export_raw: bool = False):
  names, states = _extract_states_from_exported_program(exported_program)

  def _extract_args(args, kwargs):
-    flat_args_with_path, received_spec = pytree.tree_flatten_with_path(
+    flat_args, received_spec = pytree.tree_flatten(
        (args, kwargs))  # type: ignore[possibly-undefined]
-    flat_args = [x[1] for x in flat_args_with_path]
    return flat_args

  num_mutations = len(exported_program.graph_signature.buffers_to_mutate)
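The switch works because tree_flatten already returns bare leaves, so the step that stripped key paths is no longer needed. A small illustration with hypothetical inputs, assuming pytree here is torch.utils._pytree:

import torch.utils._pytree as pytree

args, kwargs = (1, [2, 3]), {"k": 4}                # hypothetical (args, kwargs)
flat_args, received_spec = pytree.tree_flatten((args, kwargs))
assert flat_args == [1, 2, 3, 4]                    # leaves only, no (key_path, leaf) pairs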
8 changes: 7 additions & 1 deletion experimental/torch_xla2/torch_xla2/ops_registry.py
@@ -1,6 +1,6 @@
import torch
import torch._decomp as decomp

+import torch_xla2.decompositions

class LoweringRegistry:

@@ -38,6 +38,12 @@ def register(self, op, lowering):
    torch.ops.aten._adaptive_avg_pool3d,
    torch.ops.aten.grid_sampler_2d,
    torch.ops.aten.native_dropout,
+   torch.ops.aten.reflection_pad1d,
+   torch.ops.aten.reflection_pad2d,
+   torch.ops.aten.reflection_pad3d,
+   torch.ops.aten.replication_pad1d,
+   torch.ops.aten.replication_pad2d,
+   torch.ops.aten.replication_pad3d,
])
CORE_ATEN_DECOMP = decomp.core_aten_decompositions()
CORE_ATEN_DECOMP.update(EXTRA_DECOMP)
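A hedged sketch of how these registrations are assumed to combine: the extra pad ops are looked up in the decomposition table and merged into the stock core-ATen set, mirroring the EXTRA_DECOMP / CORE_ATEN_DECOMP lines above:

import torch
import torch._decomp as decomp
import torch_xla2.decompositions                    # registers the pad decompositions from decompositions.py

extra = decomp.get_decompositions([
    torch.ops.aten.reflection_pad1d,
    torch.ops.aten.replication_pad2d,
])
table = decomp.core_aten_decompositions()
table.update(extra)                                 # the merged table now also lowers the pad ops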
