From 5b6e8cf1bf113355e286055b0a5a70bad72082f1 Mon Sep 17 00:00:00 2001 From: David Huang Date: Tue, 24 Sep 2024 18:31:44 -0700 Subject: [PATCH] [torch_xla2] Fix op_info test for scatter_reduce (#8058) --- experimental/torch_xla2/test/test_ops.py | 1 - .../torch_xla2/torch_xla2/ops/jaten.py | 23 ++++++++++++++++++- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index 3654b9f689e..3185d996b26 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -137,7 +137,6 @@ "prod", "put", "rsub", - "scatter_reduce", "searchsorted", "special.airy_ai", "special.scaled_modified_bessel_k0", diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index 6c3accd7de2..90adbb9cc13 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -1471,12 +1471,33 @@ def _aten_scatter_reduce(input, dim, index, src, reduce, *, include_self=True): dtype = _torch_binary_scalar_type(src, input) src = jnp.array(src, dtype=dtype) input_indexes, source_indexes = _scatter_index(dim, index) + # "Zero out" target elements when not included + if not include_self: + if reduce in ["sum", "mean"]: + base_input = jnp.zeros_like(src) + elif reduce == "prod": + base_input = jnp.ones_like(src) + elif reduce == "amax": + base_input = jnp.full_like(src, -jnp.inf) + else: # amin + base_input = jnp.full_like(src, jnp.inf) + input = input.at[input_indexes].set(base_input[source_indexes]) + if reduce == "sum" or reduce == "add": return input.at[input_indexes].add(src[source_indexes]) elif reduce == "prod" or reduce == "multiply": return input.at[input_indexes].multiply(src[source_indexes]) elif reduce == "mean": - return input.at[input_indexes].add(src[source_indexes]) + if include_self: + count = jnp.ones_like(input) + else: + count = jnp.zeros_like(input) + count = count.at[input_indexes].add(jnp.ones_like(src)[source_indexes]) + count = jnp.clip(count, min=1) + mean = input.at[input_indexes].add(src[source_indexes]) + if _is_int(input): + return mean // count + return mean / count elif reduce == "amax": return input.at[input_indexes].max(src[source_indexes]) elif reduce == "amin":