From d69287cd36633527e0d7edb584c52b259fd8127d Mon Sep 17 00:00:00 2001 From: Simon Teo Date: Mon, 30 Sep 2024 11:14:22 +0800 Subject: [PATCH] Fix torch.multinomial --- experimental/torch_xla2/test/test_ops.py | 5 +---- experimental/torch_xla2/torch_xla2/ops/jaten.py | 11 +++++++++++ 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index 2d50754829b..46688b81caf 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -37,7 +37,6 @@ "igammac", "index_reduce", "kthvalue", - "lgamma", "linalg.cholesky", "linalg.cholesky_ex", "linalg.det", @@ -72,9 +71,6 @@ "max_pool2d_with_indices_backward", "min", "mode", - "multinomial", - "mvlgamma", - "nanmedian", "new_empty_strided", "nextafter", "nn.functional.adaptive_avg_pool3d", @@ -173,6 +169,7 @@ 'rand', 'rand_like', 'uniform', + 'multinomial', # Dropout is not deterministic https://pytorch.org/docs/stable/generated/torch.nn.functional.feature_alpha_dropout.html 'nn.functional.feature_alpha_dropout', } diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index 9195cb23759..a5bd7382c04 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -4032,6 +4032,17 @@ def _aten_median(self, dim=None, keepdim=False): index = _with_reduction_scalar(_get_median_index, self, dim, keepdim).astype(jnp.int64) return output, index + +@op(torch.ops.aten.nanmedian) +def _aten_nanmedian(input, dim=None, keepdim=False, *, out=None): + output = _with_reduction_scalar(functools.partial(jnp.nanquantile, q=0.5, method='lower'), input, dim=dim, keepdim=keepdim).astype(input.dtype) + if dim is None: + return output + else: + index = _with_reduction_scalar(_get_median_index, input, dim, keepdim).astype(jnp.int64) + return output, index + + def _get_median_index(x, axis=None, keepdims=False): sorted_arg = jnp.argsort(x, axis=axis) n = x.shape[axis] if axis is not None else x.size