From 19cd2b0f2aedbc58ba147a29e506ce2283e950ab Mon Sep 17 00:00:00 2001 From: David Huang Date: Tue, 24 Sep 2024 10:30:24 +0000 Subject: [PATCH] Support include_self in scatter_reduce --- experimental/torch_xla2/test/test_ops.py | 1 - .../torch_xla2/torch_xla2/ops/jaten.py | 23 ++++++++++++++++++- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index 8ddccccdcaf..415ce0413b6 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -143,7 +143,6 @@ "resize_as_", "rot90", "rsub", - "scatter_reduce", "searchsorted", "special.airy_ai", "special.scaled_modified_bessel_k0", diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index 54adbd30e65..ce5e146298f 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -1455,12 +1455,33 @@ def _aten_scatter_reduce(input, dim, index, src, reduce, *, include_self=True): dtype = _torch_binary_scalar_type(src, input) src = jnp.array(src, dtype=dtype) input_indexes, source_indexes = _scatter_index(dim, index) + # "Zero out" target elements when not included + if not include_self: + if reduce in ["sum", "mean"]: + base_input = jnp.zeros_like(src) + elif reduce == "prod": + base_input = jnp.ones_like(src) + elif reduce == "amax": + base_input = jnp.full_like(src, -jnp.inf) + else: # amin + base_input = jnp.full_like(src, jnp.inf) + input = input.at[input_indexes].set(base_input[source_indexes]) + if reduce == "sum" or reduce == "add": return input.at[input_indexes].add(src[source_indexes]) elif reduce == "prod" or reduce == "multiply": return input.at[input_indexes].multiply(src[source_indexes]) elif reduce == "mean": - return input.at[input_indexes].add(src[source_indexes]) + if include_self: + count = jnp.ones_like(input) + else: + count = jnp.zeros_like(input) + count = count.at[input_indexes].add(jnp.ones_like(src)[source_indexes]) + count = jnp.clip(count, min=1) + mean = input.at[input_indexes].add(src[source_indexes]) + if _is_int(input): + return mean // count + return mean / count elif reduce == "amax": return input.at[input_indexes].max(src[source_indexes]) elif reduce == "amin":