Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make sure torch_xla2 works with torch stable 2.2 #6593

Merged
merged 1 commit into from
Feb 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/torch_xla2.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ jobs:
working-directory: experimental/torch_xla2
run: |
pip install pytest absl-py jax[cpu] flatbuffers tensorflow
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu
pip install torch --index-url https://download.pytorch.org/whl/cpu
pip install -e .
- name: Run tests
working-directory: experimental/torch_xla2
Expand Down
8 changes: 6 additions & 2 deletions experimental/torch_xla2/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,13 @@ build-backend = "hatchling.build"
version = "0.0.1"
name = "torch_xla2"
dependencies = [
"torch>=2.1",
"absl-py",
"flatbuffers",
"jax>=0.4.24",
"jaxlib",
"jaxlib>=0.4.24",
"pytest",
"tensorflow",
"torch",
]

requires-python = ">=3.10"
Expand Down
29 changes: 7 additions & 22 deletions experimental/torch_xla2/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,22 +1,7 @@
certifi==2022.12.7
charset-normalizer==2.1.1
filelock==3.9.0
fsspec==2023.4.0
idna==3.4
jax==0.4.24
jaxlib==0.4.24
Jinja2==3.1.2
MarkupSafe==2.1.3
ml-dtypes==0.3.2
mpmath==1.2.1
networkx==3.0rc1
numpy==1.24.1
opt-einsum==3.3.0
Pillow==9.3.0
requests==2.28.1
scipy==1.12.0
sympy==1.11.1
--pre
torch
typing_extensions==4.8.0
urllib3==1.26.13
absl-py==2.0.0
flatbuffers==23.5.26
jax==0.4.23
jaxlib==0.4.23
pytest
tensorflow
torch==2.2.1+cpu
6 changes: 2 additions & 4 deletions experimental/torch_xla2/test/test_extra.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,14 @@ def test_fori_loop(self):
a = tensor.move_to_device(torch.ones((10, 10)))

def body(i, c):
return c + a[i]
return c + a[i]

init_val = tensor.move_to_device(torch.zeros(10))
res = extra.fori_loop(0, 10, body, init_val)

expect = torch.ones(10) * 10

self.assertTrue(torch.allclose(tensor.j2t(res._elem), expect))




if __name__ == '__main__':
Expand Down
1 change: 1 addition & 0 deletions experimental/torch_xla2/torch_xla2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from torch_xla2 import export, ops, ops_registry, tensor, tf_integration



def extract_jax(mod: torch.nn.Module):
"""Returns a pytree of jax.ndarray and a jax callable."""
func, weights, buffer = make_functional.make_functional_with_buffers(mod)
Expand Down
86 changes: 86 additions & 0 deletions experimental/torch_xla2/torch_xla2/decompositions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
"""This file contains some decompositons that are not available in torch stable.

Most likely from Content of
https://github.com/pytorch/pytorch/blob/main/torch/_decomp/decompositions.py
at main branch HEAD that we find useful here.

Can also contain decompositions of a torch op in terms of other torch ops.
"""

from typing import Any, Callable, List, Tuple

import torch
from torch import Tensor
import torch._decomp as decomp
from torch._decomp import register_decomposition
import torch._prims_common as utils
from torch._prims_common.wrappers import out_wrapper


DispatchKey = torch._C.DispatchKey # type: ignore[attr-defined]

# None of these functions are publicly accessible; get at them
# from torch._decomps
__all__: List[str] = []

aten = torch._ops.ops.aten

@register_decomposition(aten.reflection_pad1d)
@register_decomposition(aten.reflection_pad2d)
@register_decomposition(aten.reflection_pad3d)
@out_wrapper()
def _reflection_pad(a: Tensor, padding: Tuple[int, ...]) -> Tensor:
def idx(left, middle, right):
dim_idx = torch.arange(-left, middle + right, device=a.device)
return middle - 1 - (middle - 1 - dim_idx.abs()).abs()

return _reflection_or_replication_pad(
a,
padding,
idx,
)


@register_decomposition(aten.replication_pad1d)
@register_decomposition(aten.replication_pad3d)
@out_wrapper()
def _replication_pad(a: Tensor, padding: Tuple[int, ...]) -> Tensor:
def idx(left, middle, right):
dim_idx = torch.arange(-left, middle + right, device=a.device)
return torch.clamp(dim_idx, 0, middle - 1)

return _reflection_or_replication_pad(
a,
padding,
idx,
)

decomp.global_decomposition_table['post_autograd'][aten.replication_pad2d.default] = _replication_pad


def _reflection_or_replication_pad(
a: Tensor,
padding: Tuple[int, ...],
idx_fn: Callable[[int, int, int], Tensor],
) -> Tensor:
dim = len(padding) // 2
torch._check(
a.dim() in (dim + 1, dim + 2),
lambda: f"reflection_pad{dim}d requires {dim + 1}D or {dim + 2}D input",
)
inp_shape = a.shape[-dim:]
nc_dim = a.dim() - dim

padding_left = [padding[2 * (dim - 1 - i)] for i in range(dim)]
padding_right = [padding[2 * (dim - 1 - i) + 1] for i in range(dim)]

result = a
for i in range(dim):
idx: List[Any] = [None] * result.dim()
idx[i + nc_dim] = idx_fn(padding_left[i], inp_shape[i], padding_right[i])
result = aten._unsafe_index(result, idx)

# convert output to correct memory format, if necessary
memory_format = utils.suggest_memory_format(result)
result = result.contiguous(memory_format=memory_format)
return result
3 changes: 1 addition & 2 deletions experimental/torch_xla2/torch_xla2/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,9 +210,8 @@ def exported_program_to_jax(exported_program, export_raw: bool = False):
names, states = _extract_states_from_exported_program(exported_program)

def _extract_args(args, kwargs):
flat_args_with_path, received_spec = pytree.tree_flatten_with_path(
flat_args, received_spec = pytree.tree_flatten(
(args, kwargs)) # type: ignore[possibly-undefined]
flat_args = [x[1] for x in flat_args_with_path]
return flat_args

num_mutations = len(exported_program.graph_signature.buffers_to_mutate)
Expand Down
8 changes: 7 additions & 1 deletion experimental/torch_xla2/torch_xla2/ops_registry.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch
import torch._decomp as decomp

import torch_xla2.decompositions

class LoweringRegistry:

Expand Down Expand Up @@ -38,6 +38,12 @@ def register(self, op, lowering):
torch.ops.aten._adaptive_avg_pool3d,
torch.ops.aten.grid_sampler_2d,
torch.ops.aten.native_dropout,
torch.ops.aten.reflection_pad1d,
torch.ops.aten.reflection_pad2d,
torch.ops.aten.reflection_pad3d,
torch.ops.aten.replication_pad1d,
torch.ops.aten.replication_pad2d,
torch.ops.aten.replication_pad3d,
])
CORE_ATEN_DECOMP = decomp.core_aten_decompositions()
CORE_ATEN_DECOMP.update(EXTRA_DECOMP)
Expand Down
Loading