Skip to content

Commit

Permalink
Reduced the max_unpoolxd logic into one function #7524 (#8085)
Browse files Browse the repository at this point in the history
Co-authored-by: Hossein Sarshar <hosseins@google.com>
  • Loading branch information
hosseinsarshar and hosseinsarshar committed Sep 27, 2024
1 parent 5f3c983 commit 45d0e22
Showing 1 changed file with 2 additions and 19 deletions.
21 changes: 2 additions & 19 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -4074,8 +4074,9 @@ def _aten__fft_c2r(self, dim, normalization, last_dim_size):
return jnp.fft.irfftn(self, norm=norm, axes=dim, s=s)


@op(torch.ops.aten.max_unpool2d)
@op(torch.ops.aten.max_unpool3d)
def _aten_max_unpool3d(input, indices, output_size, stride=None, padding=0):
def _aten_max_unpoolxd(input, indices, output_size, stride=None, padding=0):
if output_size is None:
raise ValueError("output_size value is not set correctly. It cannot be None or empty.")

Expand All @@ -4091,21 +4092,3 @@ def _aten_max_unpool3d(input, indices, output_size, stride=None, padding=0):

return output

@op(torch.ops.aten.max_unpool2d)
def _aten_max_unpool2d(input, indices, output_size, stride=None, padding=0):
if output_size is None:
raise ValueError("output_size value is not set correctly. It cannot be None or empty.")

output_size = [input.shape[0], input.shape[1]] + output_size

output = jnp.zeros(output_size, dtype=input.dtype)

for idx in np.ndindex(input.shape):
max_index = indices[idx]
spatial_dims = output_size[2:]
unpooled_spatial_idx = np.unravel_index(max_index, spatial_dims)
full_idx = idx[:2] + unpooled_spatial_idx
output = output.at[full_idx].set(input[idx])

return output

0 comments on commit 45d0e22

Please sign in to comment.