Skip to content

Commit

Permalink
Fix torch.mode tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Simon Teo committed Sep 28, 2024
1 parent 35f510d commit 3b02527
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 2 deletions.
2 changes: 0 additions & 2 deletions experimental/torch_xla2/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,6 @@
"lu_unpack",
"masked.median",
"max_pool2d_with_indices_backward",
"min",
"mode",
"multinomial",
"mvlgamma",
"nanmedian",
Expand Down
22 changes: 22 additions & 0 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -1067,6 +1067,28 @@ def _aten_min(x, dim=None, keepdim=False):
return _with_reduction_scalar(jnp.min, x, dim, keepdim)


@op(torch.ops.aten.mode)
def _aten_mode(input, dim=-1, keepdim=False, *, out=None):
print("INPUT SHAPE: ", input.shape)
if input.ndim == 0:
keepdim = False # no dimensions to keep
dim = 0
input = jnp.expand_dims(input, 0)
else:
dim = (input.ndim + dim) % input.ndim
# keepdims must be True for accurate broadcasting
mode, _ = jax.scipy.stats.mode(input, axis=dim, keepdims=True)
mode_broadcast = jnp.broadcast_to(mode, input.shape)
# find last occurence of mode value
flip_input = jnp.flip(input, axis=dim)
indices = jnp.argmax(jnp.equal(mode_broadcast, flip_input), axis=dim, keepdims=keepdim)
len_array = jnp.ones(indices.shape).astype(jnp.int64) * (input.shape[dim] - 1)
res_indices = len_array - indices
if not keepdim:
mode, _ = jax.scipy.stats.mode(input, axis=dim, keepdims=False)
return mode, res_indices


@op(torch.ops.aten.amin)
def _aten_amin(x, dim=None, keepdim=False):
return _with_reduction_scalar(jnp.amin, x, dim, keepdim)
Expand Down

0 comments on commit 3b02527

Please sign in to comment.