Fix fft ops
qihqi committed Sep 16, 2024
1 parent 2a5897d commit fc5d475
Showing 3 changed files with 134 additions and 15 deletions.
15 changes: 1 addition & 14 deletions experimental/torch_xla2/test/test_ops.py
@@ -36,24 +36,10 @@
"erfinv",
"expand",
"exponential",
"fft.fft2",
"fft.fft",
"fft.fftn",
"fft.hfft2",
"fft.hfft",
"fft.hfftn",
"fft.ifft2",
"fft.ifft",
"fft.ifftn",
"fft.ihfft2",
"fft.ihfft",
"fft.ihfftn",
"fft.irfft2",
"fft.irfft",
"fft.irfftn",
"fft.rfft2",
"fft.rfft",
"fft.rfftn",
"floor_divide",
"gather",
"gcd",
@@ -370,6 +356,7 @@ def setUpClass(cls):

  def setUp(self):
    self.env = tensor.Environment()
    # self.env.config.debug_accuracy_for_each_op = True
    torch.manual_seed(0)

# Replaces all values in the input torch_tensor that are less than the given threshold
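
With those skip-list entries gone, the fft ops go through the standard OpInfo comparison against eager PyTorch. For reference, a minimal standalone check in the same spirit (a sketch only; it assumes the t2j/j2t torch-to-jax conversion helpers from torch_xla2.tensor):

import torch
import jax.numpy as jnp
from torch_xla2 import tensor

torch.manual_seed(0)
x = torch.randn(4, 8)

# Eager PyTorch reference.
expected = torch.fft.rfft2(x)

# Same transform on the jax.numpy backend these lowerings target.
actual = jnp.fft.rfft2(tensor.t2j(x))

torch.testing.assert_close(expected, tensor.j2t(actual))
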
53 changes: 52 additions & 1 deletion experimental/torch_xla2/torch_xla2/ops/jaten.py
@@ -3821,7 +3821,10 @@ def _aten_unsafe_index_put(self, indices, values, accumulate=False):
  return self.index_put_(indices, values, accumulate)


@op(torch.ops.aten.conj_physical)
@op(torch.ops.aten.conj_physical,
    torch.ops.aten.conj,
    torch.ops.aten._conj_physical,
    torch.ops.aten._conj)
def _aten_conj_physical(self):
  return jnp.conjugate(self)
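
All four conj variants share one lowering: jnp.conjugate flips the sign of the imaginary part and passes real inputs through unchanged. A tiny illustration:

import jax.numpy as jnp

z = jnp.array([1 + 2j, 3 - 4j])
print(jnp.conjugate(z))                # [1.-2.j 3.+4.j]
print(jnp.conjugate(jnp.arange(3.0)))  # [0. 1. 2.] -- real input unchanged
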

@@ -3887,3 +3890,51 @@ def _get_median_index(x, axis=None, keepdims=False):
@op(torch.ops.aten.triangular_solve)
def _aten_triangular_solve(b, a, upper=True, transpose=False, unittriangular=False):
  return (jax.lax.linalg.triangular_solve(a, b, left_side=True, lower=not upper, transpose_a=transpose, unit_diagonal=unittriangular), a)


# func: _fft_c2c(Tensor self, SymInt[] dim, int normalization, bool forward) -> Tensor
@op(torch.ops.aten._fft_c2c)
def _aten__fft_c2c(self, dim, normalization, forward):
  # PyTorch's normalization enum is 0=backward, 1=ortho, 2=forward; jnp's norm
  # strings name the direction that carries the scaling, so the table is
  # reversed for the inverse transform.
  if forward:
    norm = [
        'backward',
        'ortho',
        'forward',
    ][normalization]
    return jnp.fft.fftn(self, axes=dim, norm=norm)
  else:
    norm = [
        'forward',
        'ortho',
        'backward',
    ][normalization]
    return jnp.fft.ifftn(self, axes=dim, norm=norm)
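
A quick parity check of the forward mapping against eager PyTorch (illustrative only):

import numpy as np
import torch
import jax.numpy as jnp

x = torch.randn(4, 6, dtype=torch.complex64)
# forward=True with normalization=0 selects norm='backward' above.
expected = torch.fft.fftn(x, dim=(0, 1))
actual = jnp.fft.fftn(jnp.asarray(x.numpy()), axes=(0, 1), norm='backward')
np.testing.assert_allclose(expected.numpy(), np.asarray(actual), rtol=1e-4, atol=1e-4)
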


@op(torch.ops.aten._fft_r2c)
def _aten__fft_r2c(self, dim, normalization, onesided):
  norm = [
      'backward',
      'ortho',
      'forward',
  ][normalization]
  if onesided:
    # Keep only the non-redundant half of the spectrum along the last
    # transformed axis.
    return jnp.fft.rfftn(self, axes=dim, norm=norm)
  else:
    return jnp.fft.fftn(self, axes=dim, norm=norm)
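
The effect of onesided on the output shape, for reference:

import jax.numpy as jnp

x = jnp.ones((4, 10))
print(jnp.fft.rfftn(x, axes=(0, 1)).shape)  # (4, 6): last axis keeps 10 // 2 + 1 bins
print(jnp.fft.fftn(x, axes=(0, 1)).shape)   # (4, 10): full complex spectrum
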

@op(torch.ops.aten._fft_c2r)
def _aten__fft_c2r(self, dim, normalization, last_dim_size):
  # Inverse transform, so the enum is reversed as in _aten__fft_c2c above.
  norm = [
      'forward',
      'ortho',
      'backward',
  ][normalization]
  if len(dim) == 1:
    # The real output length is ambiguous given only the half spectrum, so
    # pass the size PyTorch recorded.
    s = [last_dim_size]
  else:
    # Multi-axis case: let jnp infer sizes (assumes an even last dimension).
    s = None
  return jnp.fft.irfftn(self, norm=norm, axes=dim, s=s)
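
Why the explicit last_dim_size matters: an odd output length is not recoverable from the half spectrum alone.

import jax.numpy as jnp

x = jnp.arange(5.0)                    # odd-length real signal
spec = jnp.fft.rfft(x)                 # shape (3,)
print(jnp.fft.irfft(spec).shape)       # (4,): inferred as 2 * (3 - 1)
print(jnp.fft.irfft(spec, n=5).shape)  # (5,): recovered with the size hint
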



81 changes: 81 additions & 0 deletions experimental/torch_xla2/torch_xla2/ops/jtorch.py
@@ -283,3 +283,84 @@ def linalg_slogdet(input):
@register_function(torch.tensor_split)
def tensor_split(input, indices_or_sections, dim=0):
  return jnp.array_split(input, indices_or_sections, axis=dim)


# @register_function(torch.fft.irfft)
# def irfft(input, n=None, dim=-1, norm=None):
#   # int64 casts to float32
#   if input.dtype == jnp.int64.dtype:
#     input = input.astype(jnp.float32.dtype)
#   return jnp.fft.irfft(input, n, dim, norm)

# @register_function(torch.fft.rfft)
# def rfft(input, n=None, dim=-1, norm=None):
#   # int64 casts to float32
#   if input.dtype == jnp.int64.dtype:
#     input = input.astype(jnp.float32.dtype)
#   return jnp.fft.rfft(input, n, dim, norm)

# @register_function(torch.fft.fft)
# def fft(input, n=None, dim=-1, norm=None):
#   # int64 casts to float32
#   if input.dtype == jnp.int64.dtype:
#     input = input.astype(jnp.float32.dtype)
#   return jnp.fft.fft(input, n, dim, norm)


# @register_function(torch.fft.ihfft)
# def ihfft(input, n=None, dim=-1, norm=None):
#   if input.dtype == jnp.int64.dtype:
#     input = input.astype(jnp.float32.dtype)
#   return jnp.fft.ihfft(input, n, dim, norm)

# @register_function(torch.fft.hfft)
# def hfft(input, n=None, dim=-1, norm=None):
#   if input.dtype == jnp.int64.dtype:
#     input = input.astype(jnp.float32.dtype)
#   return jnp.fft.hfft(input, n, dim, norm)


# @register_function(torch.fft.hfft2)
# def hfft2(input, s=None, dim=(-2, -1), norm=None, *, out=None):
#   return hfftn(input, s, dim, norm)

# _SWAP_DIRECTION_MAP = {"backward": "forward", None: "forward",
#                        "ortho": "ortho", "forward": "backward"}

# @register_function(torch.fft.hfftn)
# def hfftn(input, s=None, dim=None, norm=None, *, out=None):
#   dim = dim or list(range(input.ndim))
#   if isinstance(dim, int):
#     dim = [dim]
#   if input.dtype == jnp.int64.dtype:
#     input = input.astype(jnp.float32.dtype)

#   res = input
#   for i, d in enumerate(dim):
#     size = s[i] if s else None
#     res = hfft(res, size, dim=d)  # no normalize

#   if norm == 'forward':
#     res = res / np.prod(res.shape)
#   if norm == 'ortho':
#     res = res / np.sqrt(np.prod(res.shape))
#   if input.dtype == jnp.float32.dtype:
#     res = res.astype(jnp.float32.dtype)
#   return res


# @register_function(torch.fft.ihfft2)
# def ihfft2(input, s=None, dim=(-2, -1), norm=None, *, out=None):
#   if isinstance(dim, tuple):
#     d1, d2 = dim
#   else:
#     d1, d2 = dim, dim
#   shape = input.shape
#   if s is not None:
#     s1, s2 = s
#   else:
#     s1, s2 = shape[d1] // 2 + 1, shape[d2] // 2 + 1
#   return jnp.fft.ifft2(
#       input, s, dim, norm
#   )[..., :s1, :s2]
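
The block above is kept only as commented-out history; the aten-level _fft_* lowerings added in jaten.py now cover dispatch for these functions. The int64-to-float32 casts it carried are also unnecessary on the jnp side, which promotes integer inputs on its own (sketch assumes jax's default 32-bit precision):

import torch
import jax.numpy as jnp

print(torch.fft.fft(torch.arange(8)).dtype)  # torch.complex64
print(jnp.fft.fft(jnp.arange(8)).dtype)      # complex64
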
