diff --git a/.github/workflows/torch_xla2.yml b/.github/workflows/torch_xla2.yml
index b93ecba50ee..8099eb0bd2a 100644
--- a/.github/workflows/torch_xla2.yml
+++ b/.github/workflows/torch_xla2.yml
@@ -35,7 +35,7 @@ jobs:
         working-directory: experimental/torch_xla2
         run: |
           pip install pytest absl-py jax[cpu] flatbuffers tensorflow
-          pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu
+          pip install torch --index-url https://download.pytorch.org/whl/cpu
           pip install -e .
       - name: Run tests
         working-directory: experimental/torch_xla2
diff --git a/experimental/torch_xla2/pyproject.toml b/experimental/torch_xla2/pyproject.toml
index 112e169e2c6..d87a5fc8530 100644
--- a/experimental/torch_xla2/pyproject.toml
+++ b/experimental/torch_xla2/pyproject.toml
@@ -7,9 +7,13 @@ build-backend = "hatchling.build"
 version = "0.0.1"
 name = "torch_xla2"
 dependencies = [
-  "torch>=2.1",
+  "absl-py",
+  "flatbuffers",
   "jax>=0.4.24",
-  "jaxlib",
+  "jaxlib>=0.4.24",
+  "pytest",
+  "tensorflow",
+  "torch",
 ]
 
 requires-python = ">=3.10"
diff --git a/experimental/torch_xla2/requirements.txt b/experimental/torch_xla2/requirements.txt
index 54d87fb6430..7656284e7bb 100644
--- a/experimental/torch_xla2/requirements.txt
+++ b/experimental/torch_xla2/requirements.txt
@@ -1,22 +1,7 @@
-certifi==2022.12.7
-charset-normalizer==2.1.1
-filelock==3.9.0
-fsspec==2023.4.0
-idna==3.4
-jax==0.4.24
-jaxlib==0.4.24
-Jinja2==3.1.2
-MarkupSafe==2.1.3
-ml-dtypes==0.3.2
-mpmath==1.2.1
-networkx==3.0rc1
-numpy==1.24.1
-opt-einsum==3.3.0
-Pillow==9.3.0
-requests==2.28.1
-scipy==1.12.0
-sympy==1.11.1
---pre
-torch
-typing_extensions==4.8.0
-urllib3==1.26.13
+absl-py==2.0.0
+flatbuffers==23.5.26
+jax==0.4.23
+jaxlib==0.4.23
+pytest
+tensorflow
+torch==2.2.1+cpu
\ No newline at end of file
diff --git a/experimental/torch_xla2/test/test_extra.py b/experimental/torch_xla2/test/test_extra.py
index 0d68762bc04..f0e9ea174a4 100644
--- a/experimental/torch_xla2/test/test_extra.py
+++ b/experimental/torch_xla2/test/test_extra.py
@@ -15,16 +15,14 @@ def test_fori_loop(self):
     a = tensor.move_to_device(torch.ones((10, 10)))
 
     def body(i, c):
-        return c + a[i]
+      return c + a[i]
 
     init_val = tensor.move_to_device(torch.zeros(10))
     res = extra.fori_loop(0, 10, body, init_val)
-    
+
     expect = torch.ones(10) * 10
     self.assertTrue(torch.allclose(tensor.j2t(res._elem), expect))
-
-
 
 
 if __name__ == '__main__':
diff --git a/experimental/torch_xla2/torch_xla2/__init__.py b/experimental/torch_xla2/torch_xla2/__init__.py
index 1cbc67cf650..b450d151bba 100644
--- a/experimental/torch_xla2/torch_xla2/__init__.py
+++ b/experimental/torch_xla2/torch_xla2/__init__.py
@@ -6,6 +6,7 @@
 
 from torch_xla2 import export, ops, ops_registry, tensor, tf_integration
 
+
 def extract_jax(mod: torch.nn.Module):
   """Returns a pytree of jax.ndarray and a jax callable."""
   func, weights, buffer = make_functional.make_functional_with_buffers(mod)
diff --git a/experimental/torch_xla2/torch_xla2/decompositions.py b/experimental/torch_xla2/torch_xla2/decompositions.py
new file mode 100644
index 00000000000..bbcf26e6240
--- /dev/null
+++ b/experimental/torch_xla2/torch_xla2/decompositions.py
@@ -0,0 +1,86 @@
+"""This file contains some decompositions that are not available in torch stable.
+
+Most likely copied from
+https://github.com/pytorch/pytorch/blob/main/torch/_decomp/decompositions.py
+at main branch HEAD, since we find them useful here.
+
+Can also contain decompositions of a torch op in terms of other torch ops.
+""" + +from typing import Any, Callable, List, Tuple + +import torch +from torch import Tensor +import torch._decomp as decomp +from torch._decomp import register_decomposition +import torch._prims_common as utils +from torch._prims_common.wrappers import out_wrapper + + +DispatchKey = torch._C.DispatchKey # type: ignore[attr-defined] + +# None of these functions are publicly accessible; get at them +# from torch._decomps +__all__: List[str] = [] + +aten = torch._ops.ops.aten + +@register_decomposition(aten.reflection_pad1d) +@register_decomposition(aten.reflection_pad2d) +@register_decomposition(aten.reflection_pad3d) +@out_wrapper() +def _reflection_pad(a: Tensor, padding: Tuple[int, ...]) -> Tensor: + def idx(left, middle, right): + dim_idx = torch.arange(-left, middle + right, device=a.device) + return middle - 1 - (middle - 1 - dim_idx.abs()).abs() + + return _reflection_or_replication_pad( + a, + padding, + idx, + ) + + +@register_decomposition(aten.replication_pad1d) +@register_decomposition(aten.replication_pad3d) +@out_wrapper() +def _replication_pad(a: Tensor, padding: Tuple[int, ...]) -> Tensor: + def idx(left, middle, right): + dim_idx = torch.arange(-left, middle + right, device=a.device) + return torch.clamp(dim_idx, 0, middle - 1) + + return _reflection_or_replication_pad( + a, + padding, + idx, + ) + +decomp.global_decomposition_table['post_autograd'][aten.replication_pad2d.default] = _replication_pad + + +def _reflection_or_replication_pad( + a: Tensor, + padding: Tuple[int, ...], + idx_fn: Callable[[int, int, int], Tensor], +) -> Tensor: + dim = len(padding) // 2 + torch._check( + a.dim() in (dim + 1, dim + 2), + lambda: f"reflection_pad{dim}d requires {dim + 1}D or {dim + 2}D input", + ) + inp_shape = a.shape[-dim:] + nc_dim = a.dim() - dim + + padding_left = [padding[2 * (dim - 1 - i)] for i in range(dim)] + padding_right = [padding[2 * (dim - 1 - i) + 1] for i in range(dim)] + + result = a + for i in range(dim): + idx: List[Any] = [None] * result.dim() + idx[i + nc_dim] = idx_fn(padding_left[i], inp_shape[i], padding_right[i]) + result = aten._unsafe_index(result, idx) + + # convert output to correct memory format, if necessary + memory_format = utils.suggest_memory_format(result) + result = result.contiguous(memory_format=memory_format) + return result \ No newline at end of file diff --git a/experimental/torch_xla2/torch_xla2/export.py b/experimental/torch_xla2/torch_xla2/export.py index 01f7285a814..64a3f9d175c 100644 --- a/experimental/torch_xla2/torch_xla2/export.py +++ b/experimental/torch_xla2/torch_xla2/export.py @@ -210,9 +210,8 @@ def exported_program_to_jax(exported_program, export_raw: bool = False): names, states = _extract_states_from_exported_program(exported_program) def _extract_args(args, kwargs): - flat_args_with_path, received_spec = pytree.tree_flatten_with_path( + flat_args, received_spec = pytree.tree_flatten( (args, kwargs)) # type: ignore[possibly-undefined] - flat_args = [x[1] for x in flat_args_with_path] return flat_args num_mutations = len(exported_program.graph_signature.buffers_to_mutate) diff --git a/experimental/torch_xla2/torch_xla2/ops_registry.py b/experimental/torch_xla2/torch_xla2/ops_registry.py index 53ac3a4f6d9..f1d115864d3 100644 --- a/experimental/torch_xla2/torch_xla2/ops_registry.py +++ b/experimental/torch_xla2/torch_xla2/ops_registry.py @@ -1,6 +1,6 @@ import torch import torch._decomp as decomp - +import torch_xla2.decompositions class LoweringRegistry: @@ -38,6 +38,12 @@ def register(self, op, lowering): 
     torch.ops.aten._adaptive_avg_pool3d,
     torch.ops.aten.grid_sampler_2d,
     torch.ops.aten.native_dropout,
+    torch.ops.aten.reflection_pad1d,
+    torch.ops.aten.reflection_pad2d,
+    torch.ops.aten.reflection_pad3d,
+    torch.ops.aten.replication_pad1d,
+    torch.ops.aten.replication_pad2d,
+    torch.ops.aten.replication_pad3d,
 ])
 CORE_ATEN_DECOMP = decomp.core_aten_decompositions()
 CORE_ATEN_DECOMP.update(EXTRA_DECOMP)
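
Reviewer note: the gather-index arithmetic in `_reflection_pad` and `_replication_pad` is easy to sanity-check against eager PyTorch in isolation. Below is a minimal standalone sketch (plain PyTorch only; the helper names `reflect_idx` and `replicate_idx` are illustrative and do not appear in the patch):

```python
import torch
import torch.nn.functional as F

def reflect_idx(left: int, middle: int, right: int) -> torch.Tensor:
    # Fold out-of-range positions back across both edges (reflection).
    dim_idx = torch.arange(-left, middle + right)
    return middle - 1 - (middle - 1 - dim_idx.abs()).abs()

def replicate_idx(left: int, middle: int, right: int) -> torch.Tensor:
    # Clamp out-of-range positions to the nearest edge (replication).
    dim_idx = torch.arange(-left, middle + right)
    return torch.clamp(dim_idx, 0, middle - 1)

a = torch.arange(5.0)   # tensor([0., 1., 2., 3., 4.])
x = a.view(1, 1, -1)    # F.pad's 1-D modes expect an (N, C, W) input

assert torch.equal(a[reflect_idx(2, 5, 2)],
                   F.pad(x, (2, 2), mode='reflect').view(-1))
assert torch.equal(a[replicate_idx(2, 5, 2)],
                   F.pad(x, (2, 2), mode='replicate').view(-1))

print(a[reflect_idx(2, 5, 2)])    # tensor([2., 1., 0., 1., 2., 3., 4., 3., 2.])
print(a[replicate_idx(2, 5, 2)])  # tensor([0., 0., 0., 1., 2., 3., 4., 4., 4.])
```

`_reflection_or_replication_pad` applies one such index map per padded dimension via `aten._unsafe_index`, which is why a single helper serves all six ops. Note the reflection map is only valid when each pad amount is smaller than the corresponding input dimension (the usual `reflect` constraint); the clamp-based replication map has no such restriction.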
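
The registry wiring can be smoke-tested end to end as well. This is a sketch, assuming the patched `torch_xla2` package is importable; it relies only on `CORE_ATEN_DECOMP` as defined in `ops_registry.py` above:

```python
import torch
import torch_xla2.ops_registry as ops_registry

# Importing ops_registry imports torch_xla2.decompositions first, so the
# pad decompositions should be present in the merged table.
for op in (torch.ops.aten.reflection_pad2d.default,
           torch.ops.aten.replication_pad2d.default):
    assert op in ops_registry.CORE_ATEN_DECOMP, f"{op} has no decomposition"
print("pad decompositions registered")
```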