Skip to content

Commit

Permalink
Fix torch.multinomial, torch.mvlgamma, torch.lgamma, torch.nanmedian
Browse files Browse the repository at this point in the history
  • Loading branch information
Simon Teo committed Oct 1, 2024
1 parent 940bee4 commit d7f304e
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 4 deletions.
5 changes: 1 addition & 4 deletions experimental/torch_xla2/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
"igammac",
"index_reduce",
"kthvalue",
"lgamma",
"linalg.cholesky",
"linalg.cholesky_ex",
"linalg.det",
Expand Down Expand Up @@ -69,9 +68,6 @@
"lu_unpack",
"masked.median",
"max_pool2d_with_indices_backward",
"multinomial",
"mvlgamma",
"nanmedian",
"new_empty_strided",
"nextafter",
"nn.functional.adaptive_avg_pool3d",
Expand Down Expand Up @@ -169,6 +165,7 @@
'rand',
'rand_like',
'uniform',
'multinomial',
# Dropout is not deterministic https://pytorch.org/docs/stable/generated/torch.nn.functional.feature_alpha_dropout.html
'nn.functional.feature_alpha_dropout',
}
Expand Down
31 changes: 31 additions & 0 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -2232,6 +2232,13 @@ def _aten_hypot(input, other):
def _aten_igamma(input, other):
return jax.scipy.special.gammainc(input, other)

@op(torch.ops.aten.lgamma)
def _aten_lgamma(input, *, out=None):
return jax.scipy.special.gammaln(input).astype(jnp.float32)

@op(torch.ops.aten.mvlgamma)
def _aten_mvlgamma(input, p, *, out=None):
return jax.scipy.special.multigammaln(input, d)

@op(torch.ops.aten.linalg_eig)
def _aten_linalg_eig(A):
Expand Down Expand Up @@ -3935,6 +3942,19 @@ def f(k, carry):
return vectorized(self, n.astype(jnp.int64))


@op(torch.ops.aten.multinomial, needs_env=True)
def _aten_multinomial(input, num_samples, replacement=False, *, generator=None, out=None, env=None):
assert num_samples <= input.shape[-1] or replacement, "cannot take a larger sample than population when replacement=False"
assert jnp.all(input >= 0), "inputs must be non-negative"
key = env.get_and_rotate_prng_key(generator)
if input.ndim == 1:
assert jnp.sum(input) > 0, "rows of input must have non-zero sum"
return jax.random.choice(key, input.shape[-1], (num_samples,), replace=replacement, p=input)
else:
assert jnp.all(jnp.sum(input, axis=1) > 0), "rows of input must have non-zero sum"
return jnp.array([jax.random.choice(key, input.shape[-1], (num_samples,), replace=replacement, p=input[i, :]) for i in range(input.shape[0])])


@op(torch.ops.aten.narrow)
@op(torch.ops.aten.narrow_copy)
def _aten_narrow(input, dim, start, length):
Expand Down Expand Up @@ -4047,6 +4067,17 @@ def _aten_median(self, dim=None, keepdim=False):
index = _with_reduction_scalar(_get_median_index, self, dim, keepdim).astype(jnp.int64)
return output, index


@op(torch.ops.aten.nanmedian)
def _aten_nanmedian(input, dim=None, keepdim=False, *, out=None):
output = _with_reduction_scalar(functools.partial(jnp.nanquantile, q=0.5, method='lower'), input, dim=dim, keepdim=keepdim).astype(input.dtype)
if dim is None:
return output
else:
index = _with_reduction_scalar(_get_median_index, input, dim, keepdim).astype(jnp.int64)
return output, index


def _get_median_index(x, axis=None, keepdims=False):
sorted_arg = jnp.argsort(x, axis=axis)
n = x.shape[axis] if axis is not None else x.size
Expand Down

0 comments on commit d7f304e

Please sign in to comment.