diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index e9ce412379e..4181f11e3cf 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -4074,8 +4074,9 @@ def _aten__fft_c2r(self, dim, normalization, last_dim_size): return jnp.fft.irfftn(self, norm=norm, axes=dim, s=s) +@op(torch.ops.aten.max_unpool2d) @op(torch.ops.aten.max_unpool3d) -def _aten_max_unpool3d(input, indices, output_size, stride=None, padding=0): +def _aten_max_unpoolxd(input, indices, output_size, stride=None, padding=0): if output_size is None: raise ValueError("output_size value is not set correctly. It cannot be None or empty.") @@ -4091,21 +4092,3 @@ def _aten_max_unpool3d(input, indices, output_size, stride=None, padding=0): return output -@op(torch.ops.aten.max_unpool2d) -def _aten_max_unpool2d(input, indices, output_size, stride=None, padding=0): - if output_size is None: - raise ValueError("output_size value is not set correctly. It cannot be None or empty.") - - output_size = [input.shape[0], input.shape[1]] + output_size - - output = jnp.zeros(output_size, dtype=input.dtype) - - for idx in np.ndindex(input.shape): - max_index = indices[idx] - spatial_dims = output_size[2:] - unpooled_spatial_idx = np.unravel_index(max_index, spatial_dims) - full_idx = idx[:2] + unpooled_spatial_idx - output = output.at[full_idx].set(input[idx]) - - return output -