Skip to content

Commit

Permalink
Update jaten.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ManfeiBai committed Sep 15, 2024
1 parent 711cc55 commit 56a7f84
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -3884,6 +3884,12 @@ def _get_median_index(x, axis=None, keepdims=False):
median_index = jnp.expand_dims(median_index, axis)
return median_index


@op(torch.ops.aten.nanmedian.dim)
def _aten_nanmedian(mask_input, dim_=None, keepdim=None):
return jnp.nanmedian(mask_input, axis=dim_, keepdims=keepdim)


@op(torch.ops.aten.triangular_solve)
def _aten_triangular_solve(b, a, upper=True, transpose=False, unittriangular=False):
return (jax.lax.linalg.triangular_solve(a, b, left_side=True, lower=not upper, transpose_a=transpose, unit_diagonal=unittriangular), a)

0 comments on commit 56a7f84

Please sign in to comment.