Skip to content

Commit

Permalink
Fix torch.multinomial
Browse files Browse the repository at this point in the history
  • Loading branch information
Simon Teo committed Sep 28, 2024
1 parent ecc0f5a commit a836357
Showing 1 changed file with 13 additions and 0 deletions.
13 changes: 13 additions & 0 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -3900,6 +3900,19 @@ def f(k, carry):
return vectorized(self, n.astype(jnp.int64))


@op(torch.ops.aten.multinomial, needs_env=True)
def _aten_multinomial(input, num_samples, replacement=False, *, generator=None, out=None, env=None):
assert num_samples <= input.shape[-1] or replacement, "cannot take a larger sample than population when replacement=False"
assert jnp.all(input >= 0), "inputs must be non-negative"
key = env.get_and_rotate_prng_key(generator)
if input.ndim == 1:
assert jnp.sum(input) > 0, "rows of input must have non-zero sum"
return jax.random.choice(key, input.shape[-1], (num_samples,), replace=replacement, p=input)
else:
assert jnp.all(jnp.sum(input, axis=1) > 0), "rows of input must have non-zero sum"
return jnp.array([jax.random.choice(key, input.shape[-1], (num_samples,), replace=replacement, p=input[i, :]) for i in range(input.shape[0])])


@op(torch.ops.aten.narrow)
@op(torch.ops.aten.narrow_copy)
def _aten_narrow(input, dim, start, length):
Expand Down

0 comments on commit a836357

Please sign in to comment.