diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index 8ddccccdcaf..3654b9f689e 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -23,9 +23,7 @@ "cholesky", "cholesky_inverse", "cholesky_solve", - "combinations", "complex", - "diag_embed", "diagonal_copy", "diagonal_scatter", "digamma", @@ -138,10 +136,6 @@ "polygamma", "prod", "put", - "repeat_interleave", - "resize_", - "resize_as_", - "rot90", "rsub", "scatter_reduce", "searchsorted", diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index 54adbd30e65..6c3accd7de2 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -243,6 +243,22 @@ def _aten_real(x): return jnp.real(x) +@op(torch.Tensor.resize_) +def _aten_resize_(x, size, interpolation='linear'): + new_size = tuple(size) + return jax.numpy.resize(x, new_size) + + +@op(torch.ops.aten.resize_as_) +def _aten_resize_as_(x, y): + return jax.numpy.resize(x, y.shape) + + +@op(torch.ops.aten.repeat_interleave.Tensor) +def repeat_interleave(repeats, dim=0): + return jnp.repeat(jnp.arange(repeats.shape[dim]), repeats) + + @op(torch.ops.aten.view_as_real) def _aten_view_as_real(x): real = jnp.real(x) diff --git a/experimental/torch_xla2/torch_xla2/tensor.py b/experimental/torch_xla2/torch_xla2/tensor.py index ddd79af627a..861dd8aaf89 100644 --- a/experimental/torch_xla2/torch_xla2/tensor.py +++ b/experimental/torch_xla2/torch_xla2/tensor.py @@ -205,6 +205,10 @@ def __torch_function__(self, return self.env.dispatch(func, types, args, kwargs) except OperatorNotFound: pass + if _name_of_func(func) in ('rot90'): # skip rot90 with k%4==0 due to no change + if len(args) >= 2 and type(args[1]) == int: + if ((args[1])%4 == 0): + return args[0] return func(*args, **(kwargs or {}))