From ed6838172d6f883970a2ead0f811d0aff1f827c9 Mon Sep 17 00:00:00 2001 From: matinehAkhlaghinia Date: Fri, 27 Sep 2024 18:14:56 +0100 Subject: [PATCH] Add support for erfinv (#8081) --- experimental/torch_xla2/test/test_ops.py | 1 - experimental/torch_xla2/torch_xla2/ops/jaten.py | 6 ++++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index 2a96b010ab7..5d12932012f 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -28,7 +28,6 @@ "diagonal_copy", "diagonal_scatter", "digamma", - "erfinv", "exponential", "gcd", "geometric", diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index 8f34c675df7..eb9b04cf1c6 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -2057,6 +2057,12 @@ def _aten_erf(x): return jax.lax.erf(x) +@op(torch.ops.aten.erfinv) +@op_base.promote_int_input +def _aten_erfinv(input): + return jax.lax.erf_inv(input) + + # aten.exp @op(torch.ops.aten.exp) def _aten_exp(input):