Skip to content

Commit

Permalink
add histogram .. linalg.eigvals (#8002)
Browse files Browse the repository at this point in the history
  • Loading branch information
ManfeiBai committed Sep 13, 2024
1 parent 3a9f53f commit 24bc4b2
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 8 deletions.
9 changes: 2 additions & 7 deletions experimental/torch_xla2/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
"complex",
"cummax",
"cummin",
"cumsum",
"diag_embed",
"diagflat",
"diagonal_copy",
Expand Down Expand Up @@ -64,22 +63,18 @@
"geometric",
"geqrf",
"grid_sampler_2d",
"histogram",
"histogramdd",
"hypot",
"igamma",
"histogram", # hard op: AssertionError: Tensor-likes are not close!
"histogramdd", # TypeError: histogram requires ndarray or scalar arguments, got <class 'list'> at position 1.
"igammac",
"index_reduce",
"kthvalue",
"lgamma",
"linalg.cholesky",
"linalg.cholesky_ex",
"linalg.cond",
"linalg.cross",
"linalg.det",
"linalg.eig",
"linalg.eigh",
"linalg.eigvals",
"linalg.eigvalsh",
"linalg.householder_product",
"linalg.inv",
Expand Down
17 changes: 16 additions & 1 deletion experimental/torch_xla2/torch_xla2/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -1989,6 +1989,7 @@ def _aten_hardtanh(input, min_val=-1, max_val=1, inplace=False):
# aten.histc
@op(torch.ops.aten.histc)
def _aten_histc(input, bins=100, min=0, max=0):
# TODO(@manfei): this function might cause some uncertainty
if min==0 and max==0:
if isinstance(input, jnp.ndarray) and input.size == 0:
min = 0
Expand All @@ -1997,11 +1998,25 @@ def _aten_histc(input, bins=100, min=0, max=0):
min = jnp.min(input)
max = jnp.max(input)
range_value = (min, max)
print("range_value: ", range_value)
hist, bin_edges = jnp.histogram(input, bins=bins, range=range_value, weights=None, density=None)
return hist


@op(torch.ops.aten.hypot)
def _aten_hypot(input, other):
return jnp.hypot(input, other)


@op(torch.ops.aten.igamma)
def _aten_igamma(input, other):
return jax.scipy.special.gammainc(input, other)


@op(torch.ops.aten.linalg_eig)
def _aten_linalg_eig(A):
return jax.numpy.linalg.eig(A)


# aten.lcm
@op(torch.ops.aten.lcm)
def _aten_lcm(input, other):
Expand Down

0 comments on commit 24bc4b2

Please sign in to comment.