Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix lgamma, mvlgamma, multinomial, and nanmedian (#7513) #8095

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions experimental/torch_xla2/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
"igammac",
"index_reduce",
"kthvalue",
"lgamma",
"linalg.cholesky",
"linalg.cholesky_ex",
"linalg.det",
Expand Down Expand Up @@ -69,9 +68,6 @@
"lu_unpack",
"masked.median",
"max_pool2d_with_indices_backward",
"multinomial",
"mvlgamma",
"nanmedian",
"new_empty_strided",
"nextafter",
"nn.functional.adaptive_avg_pool3d",
Expand Down Expand Up @@ -169,6 +165,7 @@
'rand',
'rand_like',
'uniform',
'multinomial',
# Dropout is not deterministic https://pytorch.org/docs/stable/generated/torch.nn.functional.feature_alpha_dropout.html
'nn.functional.feature_alpha_dropout',
}
Expand Down
31 changes: 31 additions & 0 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -2232,6 +2232,13 @@ def _aten_hypot(input, other):
def _aten_igamma(input, other):
return jax.scipy.special.gammainc(input, other)

@op(torch.ops.aten.lgamma)
def _aten_lgamma(input, *, out=None):
return jax.scipy.special.gammaln(input).astype(jnp.float32)

@op(torch.ops.aten.mvlgamma)
def _aten_mvlgamma(input, p, *, out=None):
return jax.scipy.special.multigammaln(input, d)

@op(torch.ops.aten.linalg_eig)
def _aten_linalg_eig(A):
Expand Down Expand Up @@ -3935,6 +3942,19 @@ def f(k, carry):
return vectorized(self, n.astype(jnp.int64))


@op(torch.ops.aten.multinomial, needs_env=True)
def _aten_multinomial(input, num_samples, replacement=False, *, generator=None, out=None, env=None):
assert num_samples <= input.shape[-1] or replacement, "cannot take a larger sample than population when replacement=False"
assert jnp.all(input >= 0), "inputs must be non-negative"
key = env.get_and_rotate_prng_key(generator)
if input.ndim == 1:
assert jnp.sum(input) > 0, "rows of input must have non-zero sum"
return jax.random.choice(key, input.shape[-1], (num_samples,), replace=replacement, p=input)
else:
assert jnp.all(jnp.sum(input, axis=1) > 0), "rows of input must have non-zero sum"
return jnp.array([jax.random.choice(key, input.shape[-1], (num_samples,), replace=replacement, p=input[i, :]) for i in range(input.shape[0])])


@op(torch.ops.aten.narrow)
@op(torch.ops.aten.narrow_copy)
def _aten_narrow(input, dim, start, length):
Expand Down Expand Up @@ -4047,6 +4067,17 @@ def _aten_median(self, dim=None, keepdim=False):
index = _with_reduction_scalar(_get_median_index, self, dim, keepdim).astype(jnp.int64)
return output, index


@op(torch.ops.aten.nanmedian)
def _aten_nanmedian(input, dim=None, keepdim=False, *, out=None):
output = _with_reduction_scalar(functools.partial(jnp.nanquantile, q=0.5, method='lower'), input, dim=dim, keepdim=keepdim).astype(input.dtype)
if dim is None:
return output
else:
index = _with_reduction_scalar(_get_median_index, input, dim, keepdim).astype(jnp.int64)
return output, index


def _get_median_index(x, axis=None, keepdims=False):
sorted_arg = jnp.argsort(x, axis=axis)
n = x.shape[axis] if axis is not None else x.size
Expand Down