Skip to content

Commit

Permalink
Fix torch.multinomial
Browse files Browse the repository at this point in the history
  • Loading branch information
Simon Teo committed Sep 30, 2024
1 parent 2036806 commit d69287c
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
5 changes: 1 addition & 4 deletions experimental/torch_xla2/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
"igammac",
"index_reduce",
"kthvalue",
"lgamma",
"linalg.cholesky",
"linalg.cholesky_ex",
"linalg.det",
Expand Down Expand Up @@ -72,9 +71,6 @@
"max_pool2d_with_indices_backward",
"min",
"mode",
"multinomial",
"mvlgamma",
"nanmedian",
"new_empty_strided",
"nextafter",
"nn.functional.adaptive_avg_pool3d",
Expand Down Expand Up @@ -173,6 +169,7 @@
'rand',
'rand_like',
'uniform',
'multinomial',
# Dropout is not deterministic https://pytorch.org/docs/stable/generated/torch.nn.functional.feature_alpha_dropout.html
'nn.functional.feature_alpha_dropout',
}
Expand Down
11 changes: 11 additions & 0 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -4032,6 +4032,17 @@ def _aten_median(self, dim=None, keepdim=False):
index = _with_reduction_scalar(_get_median_index, self, dim, keepdim).astype(jnp.int64)
return output, index


@op(torch.ops.aten.nanmedian)
def _aten_nanmedian(input, dim=None, keepdim=False, *, out=None):
output = _with_reduction_scalar(functools.partial(jnp.nanquantile, q=0.5, method='lower'), input, dim=dim, keepdim=keepdim).astype(input.dtype)
if dim is None:
return output
else:
index = _with_reduction_scalar(_get_median_index, input, dim, keepdim).astype(jnp.int64)
return output, index


def _get_median_index(x, axis=None, keepdims=False):
sorted_arg = jnp.argsort(x, axis=axis)
n = x.shape[axis] if axis is not None else x.size
Expand Down

0 comments on commit d69287c

Please sign in to comment.