Skip to content

Commit

Permalink
Added the support for max_unpool1d, max_unpool2d, and max_unpool3d #7524
Browse files Browse the repository at this point in the history
 (#8084)

Co-authored-by: Hossein Sarshar <hosseins@google.com>
  • Loading branch information
hosseinsarshar and hosseinsarshar committed Sep 27, 2024
1 parent 597bb29 commit 5f3c983
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 3 deletions.
3 changes: 0 additions & 3 deletions experimental/torch_xla2/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,6 @@
"nn.functional.max_pool1d",
"nn.functional.max_pool2d",
"nn.functional.max_pool3d",
"nn.functional.max_unpool1d",
"nn.functional.max_unpool2d",
"nn.functional.max_unpool3d",
"nn.functional.multi_head_attention_forward",
"nn.functional.multi_margin_loss",
"nn.functional.multilabel_margin_loss",
Expand Down
37 changes: 37 additions & 0 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -4072,3 +4072,40 @@ def _aten__fft_c2r(self, dim, normalization, last_dim_size):
else:
s = None
return jnp.fft.irfftn(self, norm=norm, axes=dim, s=s)


@op(torch.ops.aten.max_unpool3d)
def _aten_max_unpool3d(input, indices, output_size, stride=None, padding=0):
if output_size is None:
raise ValueError("output_size value is not set correctly. It cannot be None or empty.")

output_size = [input.shape[0], input.shape[1]] + output_size
output = jnp.zeros(output_size, dtype=input.dtype)

for idx in np.ndindex(input.shape):
max_index = indices[idx]
spatial_dims = output_size[2:] # (D, H, W)
unpooled_spatial_idx = np.unravel_index(max_index, spatial_dims)
full_idx = idx[:2] + unpooled_spatial_idx
output = output.at[full_idx].set(input[idx])

return output

@op(torch.ops.aten.max_unpool2d)
def _aten_max_unpool2d(input, indices, output_size, stride=None, padding=0):
if output_size is None:
raise ValueError("output_size value is not set correctly. It cannot be None or empty.")

output_size = [input.shape[0], input.shape[1]] + output_size

output = jnp.zeros(output_size, dtype=input.dtype)

for idx in np.ndindex(input.shape):
max_index = indices[idx]
spatial_dims = output_size[2:]
unpooled_spatial_idx = np.unravel_index(max_index, spatial_dims)
full_idx = idx[:2] + unpooled_spatial_idx
output = output.at[full_idx].set(input[idx])

return output

0 comments on commit 5f3c983

Please sign in to comment.