diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index 8ddccccdcaf..858523a0ff3 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -113,9 +113,6 @@ "nn.functional.max_pool1d", "nn.functional.max_pool2d", "nn.functional.max_pool3d", - "nn.functional.max_unpool1d", - "nn.functional.max_unpool2d", - "nn.functional.max_unpool3d", "nn.functional.multi_head_attention_forward", "nn.functional.multi_margin_loss", "nn.functional.multilabel_margin_loss", diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index 54adbd30e65..4948519cbfa 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -4026,3 +4026,40 @@ def _aten__fft_c2r(self, dim, normalization, last_dim_size): else: s = None return jnp.fft.irfftn(self, norm=norm, axes=dim, s=s) + + +@op(torch.ops.aten.max_unpool3d) +def _aten_max_unpool3d(input, indices, output_size, stride=None, padding=0): + if output_size is None: + raise ValueError("output_size value is not set correctly. It cannot be None or empty.") + + output_size = [input.shape[0], input.shape[1]] + output_size + output = jnp.zeros(output_size, dtype=input.dtype) + + for idx in np.ndindex(input.shape): + max_index = indices[idx] + spatial_dims = output_size[2:] # (D, H, W) + unpooled_spatial_idx = np.unravel_index(max_index, spatial_dims) + full_idx = idx[:2] + unpooled_spatial_idx + output = output.at[full_idx].set(input[idx]) + + return output + +@op(torch.ops.aten.max_unpool2d) +def _aten_max_unpool2d(input, indices, output_size, stride=None, padding=0): + if output_size is None: + raise ValueError("output_size value is not set correctly. It cannot be None or empty.") + + output_size = [input.shape[0], input.shape[1]] + output_size + + output = jnp.zeros(output_size, dtype=input.dtype) + + for idx in np.ndindex(input.shape): + max_index = indices[idx] + spatial_dims = output_size[2:] + unpooled_spatial_idx = np.unravel_index(max_index, spatial_dims) + full_idx = idx[:2] + unpooled_spatial_idx + output = output.at[full_idx].set(input[idx]) + + return output +