Skip to content

Commit

Permalink
add dist
Browse files Browse the repository at this point in the history
  • Loading branch information
nupurbaghel committed Sep 16, 2024
1 parent 2a5897d commit 1119cbb
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 1 deletion.
1 change: 0 additions & 1 deletion experimental/torch_xla2/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
"diagonal_scatter",
"diff",
"digamma",
"dist",
"erfc",
"erfinv",
"expand",
Expand Down
4 changes: 4 additions & 0 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,10 @@ def _aten_div(x, y, rounding_mode=""):
def _aten_true_divide(x, y):
return x / y

@op(torch.ops.aten.dist)
def _aten_dist(input, other, p=2):
diff = jnp.abs(jnp.subtract(input, other))
return _aten_linalg_vector_norm(diff, ord=p)

@op(torch.ops.aten.bmm)
def _aten_bmm(x, y):
Expand Down

0 comments on commit 1119cbb

Please sign in to comment.