From 24bc4b2a530526f994ca0e20efa0fb2ca9c31a7b Mon Sep 17 00:00:00 2001 From: Manfei <41607353+ManfeiBai@users.noreply.github.com> Date: Fri, 13 Sep 2024 12:04:26 -0700 Subject: [PATCH] add `histogram` .. `linalg.eigvals` (#8002) --- experimental/torch_xla2/test/test_ops.py | 9 ++------- experimental/torch_xla2/torch_xla2/ops/jaten.py | 17 ++++++++++++++++- 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index 68d31136bba..01400c53b39 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -28,7 +28,6 @@ "complex", "cummax", "cummin", - "cumsum", "diag_embed", "diagflat", "diagonal_copy", @@ -64,10 +63,8 @@ "geometric", "geqrf", "grid_sampler_2d", - "histogram", - "histogramdd", - "hypot", - "igamma", + "histogram", # hard op: AssertionError: Tensor-likes are not close! + "histogramdd", # TypeError: histogram requires ndarray or scalar arguments, got at position 1. "igammac", "index_reduce", "kthvalue", @@ -75,11 +72,9 @@ "linalg.cholesky", "linalg.cholesky_ex", "linalg.cond", - "linalg.cross", "linalg.det", "linalg.eig", "linalg.eigh", - "linalg.eigvals", "linalg.eigvalsh", "linalg.householder_product", "linalg.inv", diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index e36c60e4767..03a2da7b380 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -1989,6 +1989,7 @@ def _aten_hardtanh(input, min_val=-1, max_val=1, inplace=False): # aten.histc @op(torch.ops.aten.histc) def _aten_histc(input, bins=100, min=0, max=0): + # TODO(@manfei): this function might cause some uncertainty if min==0 and max==0: if isinstance(input, jnp.ndarray) and input.size == 0: min = 0 @@ -1997,11 +1998,25 @@ def _aten_histc(input, bins=100, min=0, max=0): min = jnp.min(input) max = jnp.max(input) range_value = (min, max) - print("range_value: ", range_value) hist, bin_edges = jnp.histogram(input, bins=bins, range=range_value, weights=None, density=None) return hist +@op(torch.ops.aten.hypot) +def _aten_hypot(input, other): + return jnp.hypot(input, other) + + +@op(torch.ops.aten.igamma) +def _aten_igamma(input, other): + return jax.scipy.special.gammainc(input, other) + + +@op(torch.ops.aten.linalg_eig) +def _aten_linalg_eig(A): + return jax.numpy.linalg.eig(A) + + # aten.lcm @op(torch.ops.aten.lcm) def _aten_lcm(input, other):