Skip to content

Commit

Permalink
Fixed Masked norm op implementation (#7985)
Browse files Browse the repository at this point in the history
  • Loading branch information
anishfish2 committed Sep 10, 2024
1 parent e3f6e9e commit 5c4bac5
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 11 deletions.
1 change: 0 additions & 1 deletion experimental/torch_xla2/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,6 @@
"masked.logsumexp",
"masked.mean",
"masked.median",
"masked.norm",
"masked.normalize",
"masked.prod",
"masked_scatter",
Expand Down
25 changes: 15 additions & 10 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -1062,34 +1062,39 @@ def _aten_linalg_vector_norm(self, ord=2, dim=None, keepdim=False, dtype=None):
The tensor containing the calculated vector norms.
"""

if ord not in {2, float("inf"), float("-inf"), "fro"}:
if ord not in {2, float("inf"), float("-inf"), "fro"} and not isinstance(ord, (int, float)):
raise ValueError(
f"Unsupported ord value: {ord}. Supported values are 2, inf, -inf, and"
" 'fro'."
)

# Special cases (for efficiency and clarity)
if ord == 2: # Euclidean norm
result = jnp.sqrt(jnp.sum(jnp.abs(self) ** 2, axis=dim, keepdims=keepdim))
if ord == 0:
if self.shape == ():
result = jnp.array(float(self != 0))
else:
result = _with_reduction_scalar(jnp.sum, jnp.where(self != 0, 1, 0), dim, keepdim)

elif ord == 2: # Euclidean norm
result = jnp.sqrt(_with_reduction_scalar(jnp.sum, jnp.abs(self) ** 2, dim, keepdim))

elif ord == float("inf"):
result = jnp.max(jnp.abs(self), axis=dim, keepdims=keepdim)
result = _with_reduction_scalar(jnp.max, jnp.abs(self), dim, keepdim)

elif ord == float("-inf"):
result = jnp.min(jnp.abs(self), axis=dim, keepdims=keepdim)
result = _with_reduction_scalar(jnp.min, jnp.abs(self), dim, keepdim)

elif ord == "fro": # Frobenius norm
result = jnp.sqrt(jnp.sum(jnp.abs(self) ** 2, axis=dim, keepdims=keepdim))
result = jnp.sqrt(_with_reduction_scalar(jnp.sum, jnp.abs(self) ** 2, dim, keepdim))

else: # General case (e.g., ord = 1, ord = 3)
result = jnp.sum(jnp.abs(self) ** ord, axis=dim, keepdims=keepdim) ** (
result = _with_reduction_scalar(jnp.sum, jnp.abs(self) ** ord, dim, keepdim) ** (
1.0 / ord
)

# (Optional) dtype conversion
if dtype is not None:
result = result.astype(mappings.t2j_dtype(dtype))

result = jnp.astype(result, self.dtype)
return result


Expand Down

0 comments on commit 5c4bac5

Please sign in to comment.