From 95ebe5558f1e540556048e2f3807ecb4fc59aa44 Mon Sep 17 00:00:00 2001 From: Matin Akhlaghinia Date: Tue, 17 Sep 2024 17:25:59 +0100 Subject: [PATCH 1/4] Add support for native_layer_norm --- experimental/torch_xla2/test/test_ops.py | 1 - 1 file changed, 1 deletion(-) diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index 3b95476d413..80f250e9c83 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -97,7 +97,6 @@ "mvlgamma", "nanmedian", "nanquantile", - "native_layer_norm", "new_empty", "new_empty_strided", "nextafter", From 031c3bcd354a7b45da7f45c07cfc6cf4f7307b0e Mon Sep 17 00:00:00 2001 From: Matin Akhlaghinia Date: Wed, 18 Sep 2024 16:57:43 +0100 Subject: [PATCH 2/4] Add support for quantile and nanquantile. --- experimental/torch_xla2/test/test_ops.py | 2 -- experimental/torch_xla2/torch_xla2/ops/jaten.py | 11 ++++++----- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index 80f250e9c83..58f6938e952 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -96,7 +96,6 @@ "multinomial", "mvlgamma", "nanmedian", - "nanquantile", "new_empty", "new_empty_strided", "nextafter", @@ -167,7 +166,6 @@ "polygamma", "prod", "put", - "quantile", "repeat_interleave", "resize_", "resize_as_", diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index a7ffa8ddec9..8d44f8f994e 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -40,6 +40,10 @@ torch.ops.aten.bernoulli_: torch.ops.aten.bernoulli.p, torch.ops.aten.clamp_: torch.ops.aten.clamp, torch.ops.aten.random_: torch.ops.aten.uniform, + torch.ops.aten.ceil_: torch.ops.aten.ceil, + torch.ops.aten.logical_not_: torch.ops.aten.logical_not, + torch.ops.aten.unsqueeze_: torch.ops.aten.unsqueeze, + torch.ops.aten.transpose_: torch.ops.aten.transpose, } @@ -1476,11 +1480,6 @@ def _aten_pixel_shuffle(x, upscale_factor): def _aten_lt(self, other): return self < other -# aten.logical_not_ -@op(torch.ops.aten.logical_not_) -def _aten_logical_not_(input): - return jnp.logical_not(input) - def pool(inputs, init, reduce_fn, window_shape, strides, padding): """Helper function to define pooling functions. @@ -2043,6 +2042,8 @@ def _aten_frexp(input): # aten.gather @op(torch.ops.aten.gather) def _aten_gather(input, dim, index): + if dim < 0: + dim += input.ndim input_indexes, source_indexes = _scatter_index(dim, index) return input[input_indexes] From f2e90ceae8162a2ffb88438a2bb7d85ac42bfc9d Mon Sep 17 00:00:00 2001 From: Matin Akhlaghinia Date: Wed, 18 Sep 2024 19:01:10 +0100 Subject: [PATCH 3/4] Fix gather --- experimental/torch_xla2/test/test_ops.py | 1 - experimental/torch_xla2/torch_xla2/ops/jaten.py | 4 +++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index 58f6938e952..a306a1f3340 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -36,7 +36,6 @@ "expand", "exponential", "floor_divide", - "gather", "gcd", "geometric", "geqrf", diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index 8d44f8f994e..99bf89f268d 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -2042,8 +2042,10 @@ def _aten_frexp(input): # aten.gather @op(torch.ops.aten.gather) def _aten_gather(input, dim, index): + if input.ndim == 0: + return jnp.broadcast_to(input, index.shape) if dim < 0: - dim += input.ndim + dim += input.ndim input_indexes, source_indexes = _scatter_index(dim, index) return input[input_indexes] From b741db1a244256a8889b0570f08697aa876a3f78 Mon Sep 17 00:00:00 2001 From: Matin Akhlaghinia Date: Fri, 27 Sep 2024 16:37:21 +0100 Subject: [PATCH 4/4] Add support for erfinv --- experimental/torch_xla2/test/test_ops.py | 1 - experimental/torch_xla2/torch_xla2/ops/jaten.py | 6 ++++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index a306a1f3340..91e0a4546d9 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -32,7 +32,6 @@ "diff", "digamma", "erfc", - "erfinv", "expand", "exponential", "floor_divide", diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index 99bf89f268d..8cfef24f0ba 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -1959,6 +1959,12 @@ def _aten_erf(x): return jax.lax.erf(x) +@op(torch.ops.aten.erfinv) +@op_base.promote_int_input +def _aten_erfinv(input): + return jax.lax.erf_inv(input) + + # aten.exp @op(torch.ops.aten.exp) def _aten_exp(input):