Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[torch_xla2] Fix op_info test for scatter_reduce #8058

Merged
merged 1 commit into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion experimental/torch_xla2/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,6 @@
"resize_as_",
"rot90",
"rsub",
"scatter_reduce",
"searchsorted",
"special.airy_ai",
"special.scaled_modified_bessel_k0",
Expand Down
23 changes: 22 additions & 1 deletion experimental/torch_xla2/torch_xla2/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -1455,12 +1455,33 @@ def _aten_scatter_reduce(input, dim, index, src, reduce, *, include_self=True):
dtype = _torch_binary_scalar_type(src, input)
src = jnp.array(src, dtype=dtype)
input_indexes, source_indexes = _scatter_index(dim, index)
# "Zero out" target elements when not included
if not include_self:
if reduce in ["sum", "mean"]:
base_input = jnp.zeros_like(src)
elif reduce == "prod":
base_input = jnp.ones_like(src)
elif reduce == "amax":
base_input = jnp.full_like(src, -jnp.inf)
else: # amin
base_input = jnp.full_like(src, jnp.inf)
input = input.at[input_indexes].set(base_input[source_indexes])

if reduce == "sum" or reduce == "add":
return input.at[input_indexes].add(src[source_indexes])
elif reduce == "prod" or reduce == "multiply":
return input.at[input_indexes].multiply(src[source_indexes])
elif reduce == "mean":
return input.at[input_indexes].add(src[source_indexes])
if include_self:
count = jnp.ones_like(input)
else:
count = jnp.zeros_like(input)
count = count.at[input_indexes].add(jnp.ones_like(src)[source_indexes])
count = jnp.clip(count, min=1)
mean = input.at[input_indexes].add(src[source_indexes])
if _is_int(input):
return mean // count
return mean / count
elif reduce == "amax":
return input.at[input_indexes].max(src[source_indexes])
elif reduce == "amin":
Expand Down
Loading