Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduced the max_unpoolxd logic into one function #7524 #8085

Merged
merged 7 commits into from
Sep 27, 2024
21 changes: 2 additions & 19 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -4074,8 +4074,9 @@ def _aten__fft_c2r(self, dim, normalization, last_dim_size):
return jnp.fft.irfftn(self, norm=norm, axes=dim, s=s)


@op(torch.ops.aten.max_unpool2d)
@op(torch.ops.aten.max_unpool3d)
def _aten_max_unpool3d(input, indices, output_size, stride=None, padding=0):
def _aten_max_unpoolxd(input, indices, output_size, stride=None, padding=0):
if output_size is None:
raise ValueError("output_size value is not set correctly. It cannot be None or empty.")

Expand All @@ -4091,21 +4092,3 @@ def _aten_max_unpool3d(input, indices, output_size, stride=None, padding=0):

return output

@op(torch.ops.aten.max_unpool2d)
def _aten_max_unpool2d(input, indices, output_size, stride=None, padding=0):
if output_size is None:
raise ValueError("output_size value is not set correctly. It cannot be None or empty.")

output_size = [input.shape[0], input.shape[1]] + output_size

output = jnp.zeros(output_size, dtype=input.dtype)

for idx in np.ndindex(input.shape):
max_index = indices[idx]
spatial_dims = output_size[2:]
unpooled_spatial_idx = np.unravel_index(max_index, spatial_dims)
full_idx = idx[:2] + unpooled_spatial_idx
output = output.at[full_idx].set(input[idx])

return output

Loading