From b31f5c51862d88e079ee14bf22efbe79d5645cf3 Mon Sep 17 00:00:00 2001 From: Manfei <41607353+ManfeiBai@users.noreply.github.com> Date: Tue, 17 Sep 2024 16:29:22 -0700 Subject: [PATCH 01/12] add `resize_as_` --- experimental/torch_xla2/test/test_ops.py | 1 - 1 file changed, 1 deletion(-) diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index dc7eb8eee9c..bbf3ff1048f 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -151,7 +151,6 @@ "quantile", "repeat_interleave", "resize_", - "resize_as_", "rot90", "rsub", "scatter_add", From 9c8ae53ed98e008534a6875ebf1f6941080e0973 Mon Sep 17 00:00:00 2001 From: Manfei <41607353+ManfeiBai@users.noreply.github.com> Date: Tue, 17 Sep 2024 16:30:45 -0700 Subject: [PATCH 02/12] Update test_ops.py --- experimental/torch_xla2/test/test_ops.py | 1 - 1 file changed, 1 deletion(-) diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index bbf3ff1048f..ed2e25003c4 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -150,7 +150,6 @@ "put", "quantile", "repeat_interleave", - "resize_", "rot90", "rsub", "scatter_add", From f66f8a48df8d7ae6c8c065d483a75eb1855ae2e4 Mon Sep 17 00:00:00 2001 From: Manfei <41607353+ManfeiBai@users.noreply.github.com> Date: Tue, 17 Sep 2024 16:32:35 -0700 Subject: [PATCH 03/12] Update jaten.py --- experimental/torch_xla2/torch_xla2/ops/jaten.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index ea7b60484f9..ccddcfd2122 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -239,6 +239,16 @@ def _aten_real(x): return jnp.real(x) +@op(torch.ops.aten.resize_) +def _aten_resize_as_(x, y): + return jax.image.resize(x, size, method=interpolation) + + +@op(torch.ops.aten.resize_as_) +def _aten_resize_as_(x, y): + return jax.image.resize(x, y.shape, method='linear') + + @op(torch.ops.aten.view_as_real) def _aten_view_as_real(x): real = jnp.real(x) From e3d48880ae79e6b3f3e6d481edef98e9a9628cd1 Mon Sep 17 00:00:00 2001 From: Manfei <41607353+ManfeiBai@users.noreply.github.com> Date: Mon, 23 Sep 2024 12:16:05 -0700 Subject: [PATCH 04/12] Update jaten.py --- experimental/torch_xla2/torch_xla2/ops/jaten.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index ccddcfd2122..5076bef536e 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -239,14 +239,15 @@ def _aten_real(x): return jnp.real(x) -@op(torch.ops.aten.resize_) -def _aten_resize_as_(x, y): - return jax.image.resize(x, size, method=interpolation) +@op(torch.Tensor.resize_) +def _aten_resize_(x, size, interpolation='linear'): + new_size = tuple(size) + return jax.numpy.resize(x, new_size) @op(torch.ops.aten.resize_as_) def _aten_resize_as_(x, y): - return jax.image.resize(x, y.shape, method='linear') + return jax.numpy.resize(x, y.shape) @op(torch.ops.aten.view_as_real) From 7c6ca1d0aff0a0dee8a8afb566afc0163136b40a Mon Sep 17 00:00:00 2001 From: Manfei <41607353+ManfeiBai@users.noreply.github.com> Date: Mon, 23 Sep 2024 13:52:51 -0700 Subject: [PATCH 05/12] Update test_ops.py --- experimental/torch_xla2/test/test_ops.py | 1 - 1 file changed, 1 deletion(-) diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index ed2e25003c4..3a904e0f0e8 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -150,7 +150,6 @@ "put", "quantile", "repeat_interleave", - "rot90", "rsub", "scatter_add", "scatter", From 5490b37dcb1644201bc34d84df41455c8ec9a836 Mon Sep 17 00:00:00 2001 From: Manfei <41607353+ManfeiBai@users.noreply.github.com> Date: Mon, 23 Sep 2024 13:53:58 -0700 Subject: [PATCH 06/12] Update tensor.py --- experimental/torch_xla2/torch_xla2/tensor.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/experimental/torch_xla2/torch_xla2/tensor.py b/experimental/torch_xla2/torch_xla2/tensor.py index f503f705f45..8089c918921 100644 --- a/experimental/torch_xla2/torch_xla2/tensor.py +++ b/experimental/torch_xla2/torch_xla2/tensor.py @@ -205,6 +205,10 @@ def __torch_function__(self, return self.env.dispatch(func, types, args, kwargs) except OperatorNotFound: pass + if _name_of_func(func) in ('rot90'): # skip rot90 with k%4==0 due to no change + if len(args) >= 2 and type(args[1]) == int: + if ((args[1])%4 == 0): + return args[0] return func(*args, **(kwargs or {})) From 6d7baf383b7933857250f78375e4153e7ffc854e Mon Sep 17 00:00:00 2001 From: Manfei <41607353+ManfeiBai@users.noreply.github.com> Date: Tue, 24 Sep 2024 13:34:33 -0700 Subject: [PATCH 07/12] Update jaten.py --- experimental/torch_xla2/torch_xla2/ops/jaten.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index 5076bef536e..d56988831e3 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -250,6 +250,11 @@ def _aten_resize_as_(x, y): return jax.numpy.resize(x, y.shape) +@op(torch.ops.aten.repeat_interleave.Tensor) +def repeat_interleave(repeats, dim=0): + return jnp.repeat(jnp.arange(repeats.shape[dim]), repeats) + + @op(torch.ops.aten.view_as_real) def _aten_view_as_real(x): real = jnp.real(x) From 595c42f857d7fe0ea8bfd999d72981a344d72f00 Mon Sep 17 00:00:00 2001 From: Manfei <41607353+ManfeiBai@users.noreply.github.com> Date: Tue, 24 Sep 2024 13:34:53 -0700 Subject: [PATCH 08/12] Update test_ops.py --- experimental/torch_xla2/test/test_ops.py | 1 - 1 file changed, 1 deletion(-) diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index 3a904e0f0e8..7361e727b9e 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -149,7 +149,6 @@ "prod", "put", "quantile", - "repeat_interleave", "rsub", "scatter_add", "scatter", From 479d22e398aa363848ca5bb394d37786b805bf30 Mon Sep 17 00:00:00 2001 From: Manfei <41607353+ManfeiBai@users.noreply.github.com> Date: Tue, 24 Sep 2024 13:37:00 -0700 Subject: [PATCH 09/12] Update test_ops.py --- experimental/torch_xla2/test/test_ops.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index 8ddccccdcaf..6b5380e7988 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -138,10 +138,6 @@ "polygamma", "prod", "put", - "repeat_interleave", - "resize_", - "resize_as_", - "rot90", "rsub", "scatter_reduce", "searchsorted", From 9c6c02cf95c9a1d446d3c14241766737453fcbc6 Mon Sep 17 00:00:00 2001 From: Manfei <41607353+ManfeiBai@users.noreply.github.com> Date: Tue, 24 Sep 2024 13:42:17 -0700 Subject: [PATCH 10/12] Update test_ops.py --- experimental/torch_xla2/test/test_ops.py | 1 - 1 file changed, 1 deletion(-) diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index 6b5380e7988..32c853b9b65 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -19,7 +19,6 @@ "cat", "cauchy", "cdist", - "ceil", "cholesky", "cholesky_inverse", "cholesky_solve", From 4abaaf4db0a95725ca370865e25ed07a4e3f14cd Mon Sep 17 00:00:00 2001 From: Manfei <41607353+ManfeiBai@users.noreply.github.com> Date: Tue, 24 Sep 2024 13:48:19 -0700 Subject: [PATCH 11/12] Update test_ops.py --- experimental/torch_xla2/test/test_ops.py | 1 + 1 file changed, 1 insertion(+) diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index 32c853b9b65..6b5380e7988 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -19,6 +19,7 @@ "cat", "cauchy", "cdist", + "ceil", "cholesky", "cholesky_inverse", "cholesky_solve", From 63956fc2e447006fb9143b651a36fa9afdd31f2c Mon Sep 17 00:00:00 2001 From: Manfei <41607353+ManfeiBai@users.noreply.github.com> Date: Tue, 24 Sep 2024 13:51:30 -0700 Subject: [PATCH 12/12] Update test_ops.py --- experimental/torch_xla2/test/test_ops.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index 6b5380e7988..3654b9f689e 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -23,9 +23,7 @@ "cholesky", "cholesky_inverse", "cholesky_solve", - "combinations", "complex", - "diag_embed", "diagonal_copy", "diagonal_scatter", "digamma",