From 3f4b31bb53aeddd4fa6e1875227971d338e7ac1a Mon Sep 17 00:00:00 2001 From: qihqi Date: Mon, 16 Sep 2024 10:18:51 -0700 Subject: [PATCH] [torchxla2] Op Info test for fft (#8018) --- experimental/torch_xla2/test/test_ops.py | 19 +------ .../torch_xla2/torch_xla2/ops/jaten.py | 50 ++++++++++++++++++- 2 files changed, 50 insertions(+), 19 deletions(-) diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index 0ed93d2df13..fa7d9dada4c 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -35,24 +35,6 @@ "erfinv", "expand", "exponential", - "fft.fft2", - "fft.fft", - "fft.fftn", - "fft.hfft2", - "fft.hfft", - "fft.hfftn", - "fft.ifft2", - "fft.ifft", - "fft.ifftn", - "fft.ihfft2", - "fft.ihfft", - "fft.ihfftn", - "fft.irfft2", - "fft.irfft", - "fft.irfftn", - "fft.rfft2", - "fft.rfft", - "fft.rfftn", "floor_divide", "gather", "gcd", @@ -366,6 +348,7 @@ def setUpClass(cls): def setUp(self): self.env = tensor.Environment() + #self.env.config.debug_accuracy_for_each_op = True torch.manual_seed(0) # Replaces all values in the input torch_tensor that are less than the given threshold diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index 0d015135786..c43232444cb 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -3828,7 +3828,10 @@ def _aten_unsafe_index_put(self, indices, values, accumulate=False): return self.index_put_(indices, values, accumulate) -@op(torch.ops.aten.conj_physical) +@op(torch.ops.aten.conj_physical, + torch.ops.aten.conj, + torch.ops.aten._conj_physical, + torch.ops.aten._conj) def _aten_conj_physical(self): return jnp.conjugate(self) @@ -3894,3 +3897,48 @@ def _get_median_index(x, axis=None, keepdims=False): @op(torch.ops.aten.triangular_solve) def _aten_triangular_solve(b, a, upper=True, transpose=False, unittriangular=False): return (jax.lax.linalg.triangular_solve(a, b, left_side=True, lower=not upper, transpose_a=transpose, unit_diagonal=unittriangular), a) + + +# func: _fft_c2c(Tensor self, SymInt[] dim, int normalization, bool forward) -> Tensor +@op(torch.ops.aten._fft_c2c) +def _aten__fft_c2c(self, dim, normalization, forward): + if forward: + norm = [ + 'backward', + 'ortho', + 'forward', + ][normalization] + return jnp.fft.fftn(self, axes=dim, norm=norm) + else: + norm = [ + 'forward', + 'ortho', + 'backward', + ][normalization] + return jnp.fft.ifftn(self, axes=dim, norm=norm) + + +@op(torch.ops.aten._fft_r2c) +def _aten__fft_r2c(self, dim, normalization, onesided): + norm = [ + 'backward', + 'ortho', + 'forward', + ][normalization] + if onesided: + return jnp.fft.rfftn(self, axes=dim, norm=norm) + else: + return jnp.fft.fftn(self, axes=dim, norm=norm) + +@op(torch.ops.aten._fft_c2r) +def _aten__fft_c2r(self, dim, normalization, last_dim_size): + norm = [ + 'forward', + 'ortho', + 'backward', + ][normalization] + if len(dim) == 1: + s = [last_dim_size] + else: + s = None + return jnp.fft.irfftn(self, norm=norm, axes=dim, s=s)