From 38434404f2546e9a6837b07ceec3c733abe27885 Mon Sep 17 00:00:00 2001 From: anishfish2 Date: Wed, 25 Sep 2024 18:31:20 +0000 Subject: [PATCH] Fixed linalg.eig by adding lowering and adding custom atol/rtol --- experimental/torch_xla2/test/test_ops.py | 3 +-- experimental/torch_xla2/torch_xla2/ops/jaten.py | 3 +++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index 1c67da372c1..b71d52f055f 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -42,7 +42,6 @@ "linalg.cholesky_ex", "linalg.cond", "linalg.det", - "linalg.eigh", "linalg.eigvalsh", "linalg.householder_product", "linalg.inv", @@ -183,7 +182,7 @@ 'nn.functional.feature_alpha_dropout', } -atol_dict = {"matrix_exp": (2e-1, 2e-4), "linalg.pinv": (8e-1, 2e0), "linalg.eig":(2e0, 3e0)} +atol_dict = {"matrix_exp": (2e-1, 2e-4), "linalg.pinv": (8e-1, 2e0), "linalg.eig": (2e0, 3e0), "linalg.eigh": (5e1, 3e0)} def diff_output(testcase, output1, output2, rtol, atol, equal_nan=True, check_output=True): if isinstance(output1, torch.Tensor): diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index ac56cc03fa9..8f34c675df7 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -2198,6 +2198,9 @@ def _aten_igamma(input, other): def _aten_linalg_eig(A): return jnp.linalg.eig(A) +@op(torch.ops.aten._linalg_eigh) +def _aten_linalg_eigh(A, UPLO='L'): + return jnp.linalg.eigh(A, UPLO) # aten.lcm @op(torch.ops.aten.lcm)