diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index 2b88ae324fa..6d07c46ab2b 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -31,7 +31,6 @@ "diagonal_scatter", "diff", "digamma", - "dist", "erfc", "erfinv", "expand", diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index 5e094d39927..a3dd369afd9 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -331,6 +331,10 @@ def _aten_div(x, y, rounding_mode=""): def _aten_true_divide(x, y): return x / y +@op(torch.ops.aten.dist) +def _aten_dist(input, other, p=2): + diff = jnp.abs(jnp.subtract(input, other)) + return _aten_linalg_vector_norm(diff, ord=p) @op(torch.ops.aten.bmm) def _aten_bmm(x, y):