Skip to content

Commit

Permalink
[torch_xla2] Fix op_info test for scatter_reduce (#8058)
Browse files Browse the repository at this point in the history
  • Loading branch information
dvhg committed Sep 25, 2024
1 parent 22cf197 commit 5b6e8cf
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 2 deletions.
1 change: 0 additions & 1 deletion experimental/torch_xla2/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,6 @@
"prod",
"put",
"rsub",
"scatter_reduce",
"searchsorted",
"special.airy_ai",
"special.scaled_modified_bessel_k0",
Expand Down
23 changes: 22 additions & 1 deletion experimental/torch_xla2/torch_xla2/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -1471,12 +1471,33 @@ def _aten_scatter_reduce(input, dim, index, src, reduce, *, include_self=True):
dtype = _torch_binary_scalar_type(src, input)
src = jnp.array(src, dtype=dtype)
input_indexes, source_indexes = _scatter_index(dim, index)
# "Zero out" target elements when not included
if not include_self:
if reduce in ["sum", "mean"]:
base_input = jnp.zeros_like(src)
elif reduce == "prod":
base_input = jnp.ones_like(src)
elif reduce == "amax":
base_input = jnp.full_like(src, -jnp.inf)
else: # amin
base_input = jnp.full_like(src, jnp.inf)
input = input.at[input_indexes].set(base_input[source_indexes])

if reduce == "sum" or reduce == "add":
return input.at[input_indexes].add(src[source_indexes])
elif reduce == "prod" or reduce == "multiply":
return input.at[input_indexes].multiply(src[source_indexes])
elif reduce == "mean":
return input.at[input_indexes].add(src[source_indexes])
if include_self:
count = jnp.ones_like(input)
else:
count = jnp.zeros_like(input)
count = count.at[input_indexes].add(jnp.ones_like(src)[source_indexes])
count = jnp.clip(count, min=1)
mean = input.at[input_indexes].add(src[source_indexes])
if _is_int(input):
return mean // count
return mean / count
elif reduce == "amax":
return input.at[input_indexes].max(src[source_indexes])
elif reduce == "amin":
Expand Down

0 comments on commit 5b6e8cf

Please sign in to comment.