diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index 2b88ae324fa..d8f18432ba8 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -36,24 +36,10 @@ "erfinv", "expand", "exponential", - "fft.fft2", - "fft.fft", - "fft.fftn", "fft.hfft2", - "fft.hfft", "fft.hfftn", - "fft.ifft2", - "fft.ifft", - "fft.ifftn", "fft.ihfft2", - "fft.ihfft", "fft.ihfftn", - "fft.irfft2", - "fft.irfft", - "fft.irfftn", - "fft.rfft2", - "fft.rfft", - "fft.rfftn", "floor_divide", "gather", "gcd", @@ -370,6 +356,7 @@ def setUpClass(cls): def setUp(self): self.env = tensor.Environment() + #self.env.config.debug_accuracy_for_each_op = True torch.manual_seed(0) # Replaces all values in the input torch_tensor that are less than the given threshold diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index 5e094d39927..0fd34dbcb0d 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -3821,7 +3821,10 @@ def _aten_unsafe_index_put(self, indices, values, accumulate=False): return self.index_put_(indices, values, accumulate) -@op(torch.ops.aten.conj_physical) +@op(torch.ops.aten.conj_physical, + torch.ops.aten.conj, + torch.ops.aten._conj_physical, + torch.ops.aten._conj) def _aten_conj_physical(self): return jnp.conjugate(self) @@ -3887,3 +3890,51 @@ def _get_median_index(x, axis=None, keepdims=False): @op(torch.ops.aten.triangular_solve) def _aten_triangular_solve(b, a, upper=True, transpose=False, unittriangular=False): return (jax.lax.linalg.triangular_solve(a, b, left_side=True, lower=not upper, transpose_a=transpose, unit_diagonal=unittriangular), a) + + +# func: _fft_c2c(Tensor self, SymInt[] dim, int normalization, bool forward) -> Tensor +@op(torch.ops.aten._fft_c2c) +def _aten__fft_c2c(self, dim, normalization, forward): + if forward: + norm = [ + 'backward', + 'ortho', + 'forward', + ][normalization] + return jnp.fft.fftn(self, axes=dim, norm=norm) + else: + norm = [ + 'forward', + 'ortho', + 'backward', + ][normalization] + return jnp.fft.ifftn(self, axes=dim, norm=norm) + + +@op(torch.ops.aten._fft_r2c) +def _aten__fft_r2c(self, dim, normalization, onesided): + norm = [ + 'backward', + 'ortho', + 'forward', + ][normalization] + if onesided: + return jnp.fft.rfftn(self, axes=dim, norm=norm) + else: + return jnp.fft.fftn(self, axes=dim, norm=norm) + +@op(torch.ops.aten._fft_c2r) +def _aten__fft_c2r(self, dim, normalization, last_dim_size): + norm = [ + 'forward', + 'ortho', + 'backward', + ][normalization] + if len(dim) == 1: + s = [last_dim_size] + else: + s = None + return jnp.fft.irfftn(self, norm=norm, axes=dim, s=s) + + + diff --git a/experimental/torch_xla2/torch_xla2/ops/jtorch.py b/experimental/torch_xla2/torch_xla2/ops/jtorch.py index d53e522bc07..5c186cf85f3 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jtorch.py +++ b/experimental/torch_xla2/torch_xla2/ops/jtorch.py @@ -283,3 +283,84 @@ def linalg_slogdet(input): @register_function(torch.tensor_split) def tensor_split(input, indices_or_sections, dim=0): return jnp.array_split(input, indices_or_sections, axis=dim) + + +# @register_function(torch.fft.irfft) +# def irfft(input, n=None, dim=-1, norm=None): +# # int64 casts to float32 +# if input.dtype == jnp.int64.dtype: +# input = input.astype(jnp.float32.dtype) +# return jnp.fft.irfft(input, n, dim, norm) + +# @register_function(torch.fft.rfft) +# def rfft(input, n=None, dim=-1, norm=None): +# # int64 casts to float32 +# if input.dtype == jnp.int64.dtype: +# input = input.astype(jnp.float32.dtype) +# return jnp.fft.rfft(input, n, dim, norm) + +# @register_function(torch.fft.fft) +# def fft(input, n=None, dim=-1, norm=None): +# # int64 casts to float32 +# if input.dtype == jnp.int64.dtype: +# input = input.astype(jnp.float32.dtype) +# return jnp.fft.fft(input, n, dim, norm) + + +# @register_function(torch.fft.ihfft) +# def ihfft(input, n=None, dim=-1, norm=None): +# if input.dtype == jnp.int64.dtype: +# input = input.astype(jnp.float32.dtype) +# return jnp.fft.ihfft(input, n, dim, norm) + +# @register_function(torch.fft.hfft) +# def hfft(input, n=None, dim=-1, norm=None): +# if input.dtype == jnp.int64.dtype: +# input = input.astype(jnp.float32.dtype) +# return jnp.fft.hfft(input, n, dim, norm) + + +# @register_function(torch.fft.hfft2) +# def hfft2(input, s=None, dim=(-2, -1), norm=None, *, out=None): +# return hfftn(input, s, dim, norm) + +# _SWAP_DIRECTION_MAP = {"backward": "forward", None: "forward", +# "ortho": "ortho", "forward": "backward"} + +# @register_function(torch.fft.hfftn) +# def hfftn(input, s=None, dim=None, norm=None, *, out=None): +# dim = dim or list(range(input.ndim)) +# if isinstance(dim, int): +# dim = [dim] +# if input.dtype == jnp.int64.dtype: +# input = input.astype(jnp.float32.dtype) + + +# res = input +# for i, d in enumerate(dim): +# size = s[i] if s else None +# res = hfft(res, size, dim=d) # no normalize + +# if norm == 'forward': +# res = res / np.prod(res.shape) +# if norm == 'ortho': +# res = res / np.sqrt(np.prod(res.shape)) +# if input.dtype == jnp.float32.dtype: +# res = res.astype(jnp.float32.dtype) +# return res + + +# @register_function(torch.fft.ihfft2) +# def ihfft2(input, s=None, dim=(-2, -1), norm=None, *, out=None): +# if isinstance(dim, tuple): +# d1, d2 = dim +# else: +# d1, d2 = dim, dim +# shape = input.shape +# if s is not None: +# s1, s2 = s +# else: +# s1, s2 = shape[d1] // 2 + 1, shape[d2] // 2 + 1 +# return jnp.fft.ifft2( +# input, s, dim, norm +# )[..., :s1, :s2]