From 3006aded747f073b6eb05356c1b05b5ee0e7494b Mon Sep 17 00:00:00 2001 From: anishfish2 Date: Wed, 25 Sep 2024 19:47:05 +0000 Subject: [PATCH 1/2] Added base implementation of linalg.eigvalsh --- experimental/torch_xla2/test/test_ops.py | 1 - experimental/torch_xla2/torch_xla2/ops/jaten.py | 4 ++++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index b71d52f055f..ecc455262df 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -42,7 +42,6 @@ "linalg.cholesky_ex", "linalg.cond", "linalg.det", - "linalg.eigvalsh", "linalg.householder_product", "linalg.inv", "linalg.inv_ex", diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index 8f34c675df7..eb9a8e1c43e 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -2202,6 +2202,10 @@ def _aten_linalg_eig(A): def _aten_linalg_eigh(A, UPLO='L'): return jnp.linalg.eigh(A, UPLO) +@op(torch.ops.aten._linalg_eigvalsh) +def _aten_linalg_eigvalsh(A, UPLO='L'): + return jnp.linalg.eigvalsh(A, UPLO) + # aten.lcm @op(torch.ops.aten.lcm) def _aten_lcm(input, other): From 3b867221d300faaa31096e5a9fab55b69f0ad953 Mon Sep 17 00:00:00 2001 From: anishfish2 Date: Wed, 25 Sep 2024 23:54:00 +0000 Subject: [PATCH 2/2] Added linalg.eigvalsh custom atol/rtol to fix removal from skiplist --- experimental/torch_xla2/test/test_ops.py | 2 +- experimental/torch_xla2/torch_xla2/ops/jaten.py | 4 ---- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index ecc455262df..ad705a493ac 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -181,7 +181,7 @@ 'nn.functional.feature_alpha_dropout', } -atol_dict = {"matrix_exp": (2e-1, 2e-4), "linalg.pinv": (8e-1, 2e0), "linalg.eig": (2e0, 3e0), "linalg.eigh": (5e1, 3e0)} +atol_dict = {"matrix_exp": (2e-1, 2e-4), "linalg.pinv": (8e-1, 2e0), "linalg.eig": (2e0, 3e0), "linalg.eigh": (5e1, 3e0), "linalg.eigvalsh": (5e1, 3e0)} def diff_output(testcase, output1, output2, rtol, atol, equal_nan=True, check_output=True): if isinstance(output1, torch.Tensor): diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index eb9a8e1c43e..8f34c675df7 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -2202,10 +2202,6 @@ def _aten_linalg_eig(A): def _aten_linalg_eigh(A, UPLO='L'): return jnp.linalg.eigh(A, UPLO) -@op(torch.ops.aten._linalg_eigvalsh) -def _aten_linalg_eigvalsh(A, UPLO='L'): - return jnp.linalg.eigvalsh(A, UPLO) - # aten.lcm @op(torch.ops.aten.lcm) def _aten_lcm(input, other):