Skip to content

Commit

Permalink
Fix quantile, nanquantile and gather (#8040)
Browse files Browse the repository at this point in the history
  • Loading branch information
matinehAkhlaghinia committed Sep 18, 2024
1 parent 0415cc6 commit 6b33f8f
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
3 changes: 0 additions & 3 deletions experimental/torch_xla2/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
"digamma",
"erfinv",
"exponential",
"gather",
"gcd",
"geometric",
"geqrf",
Expand Down Expand Up @@ -88,7 +87,6 @@
"multinomial",
"mvlgamma",
"nanmedian",
"nanquantile",
"new_empty",
"new_empty_strided",
"nextafter",
Expand Down Expand Up @@ -148,7 +146,6 @@
"polygamma",
"prod",
"put",
"quantile",
"repeat_interleave",
"resize_",
"resize_as_",
Expand Down
13 changes: 8 additions & 5 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@
torch.ops.aten.bernoulli_: torch.ops.aten.bernoulli.p,
torch.ops.aten.clamp_: torch.ops.aten.clamp,
torch.ops.aten.random_: torch.ops.aten.uniform,
torch.ops.aten.ceil_: torch.ops.aten.ceil,
torch.ops.aten.logical_not_: torch.ops.aten.logical_not,
torch.ops.aten.unsqueeze_: torch.ops.aten.unsqueeze,
torch.ops.aten.transpose_: torch.ops.aten.transpose,
}


Expand Down Expand Up @@ -1476,11 +1480,6 @@ def _aten_pixel_shuffle(x, upscale_factor):
def _aten_lt(self, other):
return self < other

# aten.logical_not_
@op(torch.ops.aten.logical_not_)
def _aten_logical_not_(input):
return jnp.logical_not(input)


def pool(inputs, init, reduce_fn, window_shape, strides, padding):
"""Helper function to define pooling functions.
Expand Down Expand Up @@ -2043,6 +2042,10 @@ def _aten_frexp(input):
# aten.gather
@op(torch.ops.aten.gather)
def _aten_gather(input, dim, index):
if input.ndim == 0:
return jnp.broadcast_to(input, index.shape)
if dim < 0:
dim += input.ndim
input_indexes, source_indexes = _scatter_index(dim, index)
return input[input_indexes]

Expand Down

0 comments on commit 6b33f8f

Please sign in to comment.