[torchxla2] Op Info test for fft (#8018)
qihqi committed Sep 16, 2024
1 parent 008004d commit 3f4b31b
Showing 2 changed files with 50 additions and 19 deletions.
experimental/torch_xla2/test/test_ops.py (19 changes: 1 addition, 18 deletions)
@@ -35,24 +35,6 @@
     "erfinv",
     "expand",
     "exponential",
-    "fft.fft2",
-    "fft.fft",
-    "fft.fftn",
-    "fft.hfft2",
-    "fft.hfft",
-    "fft.hfftn",
-    "fft.ifft2",
-    "fft.ifft",
-    "fft.ifftn",
-    "fft.ihfft2",
-    "fft.ihfft",
-    "fft.ihfftn",
-    "fft.irfft2",
-    "fft.irfft",
-    "fft.irfftn",
-    "fft.rfft2",
-    "fft.rfft",
-    "fft.rfftn",
     "floor_divide",
     "gather",
     "gcd",
@@ -366,6 +348,7 @@ def setUpClass(cls):
 
   def setUp(self):
     self.env = tensor.Environment()
+    #self.env.config.debug_accuracy_for_each_op = True
     torch.manual_seed(0)
 
   # Replaces all values in the input torch_tensor that are less than the given threshold
experimental/torch_xla2/torch_xla2/ops/jaten.py (50 changes: 49 additions, 1 deletion)
@@ -3828,7 +3828,10 @@ def _aten_unsafe_index_put(self, indices, values, accumulate=False):
   return self.index_put_(indices, values, accumulate)
 
 
-@op(torch.ops.aten.conj_physical)
+@op(torch.ops.aten.conj_physical,
+    torch.ops.aten.conj,
+    torch.ops.aten._conj_physical,
+    torch.ops.aten._conj)
 def _aten_conj_physical(self):
   return jnp.conjugate(self)
 
@@ -3894,3 +3897,48 @@ def _get_median_index(x, axis=None, keepdims=False):
 @op(torch.ops.aten.triangular_solve)
 def _aten_triangular_solve(b, a, upper=True, transpose=False, unittriangular=False):
   return (jax.lax.linalg.triangular_solve(a, b, left_side=True, lower=not upper, transpose_a=transpose, unit_diagonal=unittriangular), a)
+
+
+# func: _fft_c2c(Tensor self, SymInt[] dim, int normalization, bool forward) -> Tensor
+@op(torch.ops.aten._fft_c2c)
+def _aten__fft_c2c(self, dim, normalization, forward):
+  if forward:
+    norm = [
+        'backward',
+        'ortho',
+        'forward',
+    ][normalization]
+    return jnp.fft.fftn(self, axes=dim, norm=norm)
+  else:
+    norm = [
+        'forward',
+        'ortho',
+        'backward',
+    ][normalization]
+    return jnp.fft.ifftn(self, axes=dim, norm=norm)
+
+
+@op(torch.ops.aten._fft_r2c)
+def _aten__fft_r2c(self, dim, normalization, onesided):
+  norm = [
+      'backward',
+      'ortho',
+      'forward',
+  ][normalization]
+  if onesided:
+    return jnp.fft.rfftn(self, axes=dim, norm=norm)
+  else:
+    return jnp.fft.fftn(self, axes=dim, norm=norm)
+
+@op(torch.ops.aten._fft_c2r)
+def _aten__fft_c2r(self, dim, normalization, last_dim_size):
+  norm = [
+      'forward',
+      'ortho',
+      'backward',
+  ][normalization]
+  if len(dim) == 1:
+    s = [last_dim_size]
+  else:
+    s = None
+  return jnp.fft.irfftn(self, norm=norm, axes=dim, s=s)
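
As a quick sanity check on the new lowerings (not part of this commit), the enum-to-string mapping above can be exercised directly against PyTorch's own CPU implementation of aten::_fft_c2c, whose schema is quoted in the comment in the diff. The loop below is a minimal sketch assuming standard torch, jax, and numpy installs; it checks that normalization 0/1/2 maps to 'backward'/'ortho'/'forward' for the forward transform and to the reversed table for the inverse.

import jax.numpy as jnp
import numpy as np
import torch

x = torch.randn(4, 8, dtype=torch.complex64)
dim = [0, 1]

for normalization in (0, 1, 2):  # ATen enum (assumed): 0 = no scaling, 1 = 1/sqrt(n), 2 = 1/n
  for forward in (True, False):
    # Reference result from PyTorch's CPU kernel.
    expected = torch.ops.aten._fft_c2c(x, dim, normalization, forward)
    if forward:
      norm = ['backward', 'ortho', 'forward'][normalization]
      actual = jnp.fft.fftn(x.numpy(), axes=dim, norm=norm)
    else:
      norm = ['forward', 'ortho', 'backward'][normalization]
      actual = jnp.fft.ifftn(x.numpy(), axes=dim, norm=norm)
    np.testing.assert_allclose(expected.numpy(), np.asarray(actual), rtol=1e-4, atol=1e-5)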
