diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py
index ac23121824e..f1c31594f81 100644
--- a/experimental/torch_xla2/test/test_ops.py
+++ b/experimental/torch_xla2/test/test_ops.py
@@ -127,7 +127,6 @@
     "masked.logsumexp",
     "masked.mean",
     "masked.median",
-    "masked.norm",
     "masked.normalize",
     "masked.prod",
     "masked_scatter",
diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py
index 34eee633ff1..ade10ceeca2 100644
--- a/experimental/torch_xla2/torch_xla2/ops/jaten.py
+++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py
@@ -1062,34 +1062,39 @@ def _aten_linalg_vector_norm(self, ord=2, dim=None, keepdim=False, dtype=None):
       The tensor containing the calculated vector norms.
   """
-  if ord not in {2, float("inf"), float("-inf"), "fro"}:
+  if ord not in {2, float("inf"), float("-inf"), "fro"} and not isinstance(ord, (int, float)):
     raise ValueError(
         f"Unsupported ord value: {ord}. Supported values are 2, inf, -inf, and"
         " 'fro'."
     )
-
+  # Special cases (for efficiency and clarity)
-  if ord == 2:  # Euclidean norm
-    result = jnp.sqrt(jnp.sum(jnp.abs(self) ** 2, axis=dim, keepdims=keepdim))
+  if ord == 0:
+    if self.shape == ():
+      result = jnp.array(float(self != 0))
+    else:
+      result = _with_reduction_scalar(jnp.sum, jnp.where(self != 0, 1, 0), dim, keepdim)
+
+  elif ord == 2:  # Euclidean norm
+    result = jnp.sqrt(_with_reduction_scalar(jnp.sum, jnp.abs(self) ** 2, dim, keepdim))
   elif ord == float("inf"):
-    result = jnp.max(jnp.abs(self), axis=dim, keepdims=keepdim)
+    result = _with_reduction_scalar(jnp.max, jnp.abs(self), dim, keepdim)
   elif ord == float("-inf"):
-    result = jnp.min(jnp.abs(self), axis=dim, keepdims=keepdim)
+    result = _with_reduction_scalar(jnp.min, jnp.abs(self), dim, keepdim)
   elif ord == "fro":  # Frobenius norm
-    result = jnp.sqrt(jnp.sum(jnp.abs(self) ** 2, axis=dim, keepdims=keepdim))
+    result = jnp.sqrt(_with_reduction_scalar(jnp.sum, jnp.abs(self) ** 2, dim, keepdim))
   else:  # General case (e.g., ord = 1, ord = 3)
-    result = jnp.sum(jnp.abs(self) ** ord, axis=dim, keepdims=keepdim) ** (
+    result = _with_reduction_scalar(jnp.sum, jnp.abs(self) ** ord, dim, keepdim) ** (
         1.0 / ord
     )
 
   # (Optional) dtype conversion
   if dtype is not None:
-    result = result.astype(mappings.t2j_dtype(dtype))
-
+    result = jnp.astype(result, self.dtype)
   return result