diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index c9fb53e28cc..9195cb23759 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -2197,6 +2197,13 @@ def _aten_hypot(input, other): def _aten_igamma(input, other): return jax.scipy.special.gammainc(input, other) +@op(torch.ops.aten.lgamma) +def _aten_lgamma(input, *, out=None): + return jax.scipy.special.gammaln(input).astype(jnp.float32) + +@op(torch.ops.aten.mvlgamma) +def _aten_mvlgamma(input, p, *, out=None): + return jax.scipy.special.multigammaln(input, d) @op(torch.ops.aten.linalg_eig) def _aten_linalg_eig(A):