From 203680600a92564ac86097a6fd65d88676f67bd5 Mon Sep 17 00:00:00 2001 From: Simon Teo Date: Sat, 28 Sep 2024 23:28:08 +0800 Subject: [PATCH] Fix torch.mvlgamma and torch.lgamma --- experimental/torch_xla2/torch_xla2/ops/jaten.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index c9fb53e28cc..9195cb23759 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -2197,6 +2197,13 @@ def _aten_hypot(input, other): def _aten_igamma(input, other): return jax.scipy.special.gammainc(input, other) +@op(torch.ops.aten.lgamma) +def _aten_lgamma(input, *, out=None): + return jax.scipy.special.gammaln(input).astype(jnp.float32) + +@op(torch.ops.aten.mvlgamma) +def _aten_mvlgamma(input, p, *, out=None): + return jax.scipy.special.multigammaln(input, d) @op(torch.ops.aten.linalg_eig) def _aten_linalg_eig(A):