From 820558e179f171b54f0cf2ea782cc2fca512ae4d Mon Sep 17 00:00:00 2001 From: Hossein Sarshar Date: Tue, 24 Sep 2024 20:06:11 -0400 Subject: [PATCH 1/6] added the max_pool2d function --- experimental/torch_xla2/test/test_ops.py | 2 +- .../torch_xla2/torch_xla2/ops/jaten.py | 25 +++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index 8ddccccdcaf..806ce751943 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -111,7 +111,7 @@ "nn.functional.interpolate", "nn.functional.margin_ranking_loss", "nn.functional.max_pool1d", - "nn.functional.max_pool2d", + # "nn.functional.max_pool2d", "nn.functional.max_pool3d", "nn.functional.max_unpool1d", "nn.functional.max_unpool2d", diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index 54adbd30e65..db3e958893b 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -4026,3 +4026,28 @@ def _aten__fft_c2r(self, dim, normalization, last_dim_size): else: s = None return jnp.fft.irfftn(self, norm=norm, axes=dim, s=s) + +@op(torch.ops.aten.max_pool2d) +def _aten_max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False): + if stride is None: + stride = kernel_size + + # Handle padding modes ('SAME' or 'VALID') conversion for PyTorch-like behavior. + if padding == 0: + padding_mode = 'VALID' + else: + padding_mode = 'SAME' + + # JAX doesn't support dilation in reduce_window, so we ignore it. + # JAX doesn't support ceil_mode, so we ignore it. + + window_shape = (kernel_size, kernel_size) + strides = (stride, stride) + + # Apply max pooling using jax.lax.reduce_window + pooled = jax.lax.reduce_window(input, -jnp.inf, jax.lax.max, + window_dimensions=window_shape, + window_strides=strides, + padding=padding_mode) + + return pooled From 2a6e0aa28d208ddddfb179e5f4c3edcb5f35814e Mon Sep 17 00:00:00 2001 From: Hossein Sarshar Date: Fri, 27 Sep 2024 14:29:22 -0400 Subject: [PATCH 2/6] added the lowering of max_unpool1d, max_unpool2d, and max_unpool3d #7524 --- experimental/torch_xla2/test/test_ops.py | 7 +- .../torch_xla2/torch_xla2/ops/jaten.py | 107 ++++++++++++++---- 2 files changed, 88 insertions(+), 26 deletions(-) diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index 806ce751943..79bef799931 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -111,11 +111,8 @@ "nn.functional.interpolate", "nn.functional.margin_ranking_loss", "nn.functional.max_pool1d", - # "nn.functional.max_pool2d", + "nn.functional.max_pool2d", "nn.functional.max_pool3d", - "nn.functional.max_unpool1d", - "nn.functional.max_unpool2d", - "nn.functional.max_unpool3d", "nn.functional.multi_head_attention_forward", "nn.functional.multi_margin_loss", "nn.functional.multilabel_margin_loss", @@ -282,7 +279,9 @@ def replace_values_below_threshold(self, torch_tensor, threshold): @ops(ops_to_test, allowed_dtypes=(torch.float32, torch.long)) def test_reference_eager(self, device, dtype, op): sample_inputs = op.sample_inputs(device, dtype) + print(f"{op=}") for sample_input in sample_inputs: + print("sample_input: ", sample_input) t = sample_input.input if isinstance(t, torch.Tensor) and t.is_sparse: continue diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index db3e958893b..80068b49af3 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -14,6 +14,10 @@ from torch_xla2.ops import op_base, mappings from torch_xla2 import interop +import collections +from itertools import repeat + + # Keys are OpOverload, value is a callable that takes # XLATensor2 all_ops = {} @@ -4027,27 +4031,86 @@ def _aten__fft_c2r(self, dim, normalization, last_dim_size): s = None return jnp.fft.irfftn(self, norm=norm, axes=dim, s=s) -@op(torch.ops.aten.max_pool2d) -def _aten_max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False): - if stride is None: - stride = kernel_size - - # Handle padding modes ('SAME' or 'VALID') conversion for PyTorch-like behavior. - if padding == 0: - padding_mode = 'VALID' - else: - padding_mode = 'SAME' - - # JAX doesn't support dilation in reduce_window, so we ignore it. - # JAX doesn't support ceil_mode, so we ignore it. +# @op(torch.nn.functional.max_pool2d) +# def _aten_max_pool2d(self, input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False): +# print('================== Hello World [_aten_max_pool2d] =======================') +# if stride is None: +# stride = kernel_size +# +# # Handle padding modes ('SAME' or 'VALID') conversion for PyTorch-like behavior. +# if padding == 0: +# padding_mode = 'VALID' +# else: +# padding_mode = 'SAME' +# +# # JAX doesn't support dilation in reduce_window, so we ignore it. +# # JAX doesn't support ceil_mode, so we ignore it. +# +# window_shape = (kernel_size, kernel_size) +# strides = (stride, stride) +# +# # Apply max pooling using jax.lax.reduce_window +# pooled = jax.lax.reduce_window(input, -jnp.inf, jax.lax.max, +# window_dimensions=window_shape, +# window_strides=strides, +# padding=padding_mode) +# +# return pooled + + + +@op(torch.ops.aten.max_unpool3d) +def _aten_max_unpool3d(input, indices, output_size, stride=None, padding=0): + if output_size is None: + raise ValueError("output_size value is not set correctly. It cannot be None or empty.") + + output_size = [input.shape[0], input.shape[1]] + output_size + + # Initialize an output array of zeros with the provided output_size + output = jnp.zeros(output_size, dtype=input.dtype) - window_shape = (kernel_size, kernel_size) - strides = (stride, stride) + # Use numpy.ndindex to iterate over all indices of the input tensor + for idx in np.ndindex(input.shape): + max_index = indices[idx] + + # Get the spatial dimensions of the output + spatial_dims = output_size[2:] # (D, H, W) + + # Unravel the flat index to multi-dimensional index + unpooled_spatial_idx = np.unravel_index(max_index, spatial_dims) + + # Combine batch, channel, and spatial indices + full_idx = idx[:2] + unpooled_spatial_idx + + output = output.at[full_idx].set(input[idx]) + + return output + +@op(torch.ops.aten.max_unpool2d) +def _aten_max_unpool2d(input, indices, output_size, stride=None, padding=0): + if output_size is None: + raise ValueError("output_size value is not set correctly. It cannot be None or empty.") + + output_size = [input.shape[0], input.shape[1]] + output_size + + # Initialize the output array with zeros + output = jnp.zeros(output_size, dtype=input.dtype) + + # Use numpy.ndindex to iterate over all indices of the input tensor + for idx in np.ndindex(input.shape): + max_index = indices[idx] + + # Get the spatial dimensions of the output (H, W) + spatial_dims = output_size[2:] + + # Unravel the flat index to multi-dimensional index for 2D + unpooled_spatial_idx = np.unravel_index(max_index, spatial_dims) + + # Combine batch, channel, and spatial indices + full_idx = idx[:2] + unpooled_spatial_idx + + # Set the value in the output array at the corresponding location + output = output.at[full_idx].set(input[idx]) + + return output - # Apply max pooling using jax.lax.reduce_window - pooled = jax.lax.reduce_window(input, -jnp.inf, jax.lax.max, - window_dimensions=window_shape, - window_strides=strides, - padding=padding_mode) - - return pooled From 8f65719d070c74e210fe9965c7026fc09d2f040c Mon Sep 17 00:00:00 2001 From: Hossein Sarshar Date: Fri, 27 Sep 2024 15:43:26 -0400 Subject: [PATCH 3/6] remove debug info #7524 --- experimental/torch_xla2/test/test_ops.py | 2 -- experimental/torch_xla2/torch_xla2/ops/jaten.py | 4 ---- 2 files changed, 6 deletions(-) diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index 79bef799931..858523a0ff3 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -279,9 +279,7 @@ def replace_values_below_threshold(self, torch_tensor, threshold): @ops(ops_to_test, allowed_dtypes=(torch.float32, torch.long)) def test_reference_eager(self, device, dtype, op): sample_inputs = op.sample_inputs(device, dtype) - print(f"{op=}") for sample_input in sample_inputs: - print("sample_input: ", sample_input) t = sample_input.input if isinstance(t, torch.Tensor) and t.is_sparse: continue diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index 80068b49af3..27c79e0629b 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -14,10 +14,6 @@ from torch_xla2.ops import op_base, mappings from torch_xla2 import interop -import collections -from itertools import repeat - - # Keys are OpOverload, value is a callable that takes # XLATensor2 all_ops = {} From 3b71c6c3ab2ce1be16d006f504e52c8c66e553b7 Mon Sep 17 00:00:00 2001 From: Hossein Sarshar Date: Fri, 27 Sep 2024 15:45:41 -0400 Subject: [PATCH 4/6] removed the max_pool2d draft code --- .../torch_xla2/torch_xla2/ops/jaten.py | 27 ------------------- 1 file changed, 27 deletions(-) diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index 27c79e0629b..4d8c55c155c 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -4027,33 +4027,6 @@ def _aten__fft_c2r(self, dim, normalization, last_dim_size): s = None return jnp.fft.irfftn(self, norm=norm, axes=dim, s=s) -# @op(torch.nn.functional.max_pool2d) -# def _aten_max_pool2d(self, input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False): -# print('================== Hello World [_aten_max_pool2d] =======================') -# if stride is None: -# stride = kernel_size -# -# # Handle padding modes ('SAME' or 'VALID') conversion for PyTorch-like behavior. -# if padding == 0: -# padding_mode = 'VALID' -# else: -# padding_mode = 'SAME' -# -# # JAX doesn't support dilation in reduce_window, so we ignore it. -# # JAX doesn't support ceil_mode, so we ignore it. -# -# window_shape = (kernel_size, kernel_size) -# strides = (stride, stride) -# -# # Apply max pooling using jax.lax.reduce_window -# pooled = jax.lax.reduce_window(input, -jnp.inf, jax.lax.max, -# window_dimensions=window_shape, -# window_strides=strides, -# padding=padding_mode) -# -# return pooled - - @op(torch.ops.aten.max_unpool3d) def _aten_max_unpool3d(input, indices, output_size, stride=None, padding=0): From e5278ea430bea1ee1ad63f62e08bed7ab74a7d06 Mon Sep 17 00:00:00 2001 From: Hossein Sarshar Date: Fri, 27 Sep 2024 16:05:24 -0400 Subject: [PATCH 5/6] polished the code and the un-helpful comments #7524 --- .../torch_xla2/torch_xla2/ops/jaten.py | 20 ------------------- 1 file changed, 20 deletions(-) diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index 4d8c55c155c..4948519cbfa 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -4034,23 +4034,13 @@ def _aten_max_unpool3d(input, indices, output_size, stride=None, padding=0): raise ValueError("output_size value is not set correctly. It cannot be None or empty.") output_size = [input.shape[0], input.shape[1]] + output_size - - # Initialize an output array of zeros with the provided output_size output = jnp.zeros(output_size, dtype=input.dtype) - # Use numpy.ndindex to iterate over all indices of the input tensor for idx in np.ndindex(input.shape): max_index = indices[idx] - - # Get the spatial dimensions of the output spatial_dims = output_size[2:] # (D, H, W) - - # Unravel the flat index to multi-dimensional index unpooled_spatial_idx = np.unravel_index(max_index, spatial_dims) - - # Combine batch, channel, and spatial indices full_idx = idx[:2] + unpooled_spatial_idx - output = output.at[full_idx].set(input[idx]) return output @@ -4062,23 +4052,13 @@ def _aten_max_unpool2d(input, indices, output_size, stride=None, padding=0): output_size = [input.shape[0], input.shape[1]] + output_size - # Initialize the output array with zeros output = jnp.zeros(output_size, dtype=input.dtype) - # Use numpy.ndindex to iterate over all indices of the input tensor for idx in np.ndindex(input.shape): max_index = indices[idx] - - # Get the spatial dimensions of the output (H, W) spatial_dims = output_size[2:] - - # Unravel the flat index to multi-dimensional index for 2D unpooled_spatial_idx = np.unravel_index(max_index, spatial_dims) - - # Combine batch, channel, and spatial indices full_idx = idx[:2] + unpooled_spatial_idx - - # Set the value in the output array at the corresponding location output = output.at[full_idx].set(input[idx]) return output From f338df4111664d535e1bcdc8ba804fab51879298 Mon Sep 17 00:00:00 2001 From: Hossein Sarshar Date: Fri, 27 Sep 2024 16:18:54 -0400 Subject: [PATCH 6/6] reduced the max_unpool logic into one function #7524 --- .../torch_xla2/torch_xla2/ops/jaten.py | 22 ++----------------- 1 file changed, 2 insertions(+), 20 deletions(-) diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index 4948519cbfa..eb4a1abc24c 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -4027,9 +4027,9 @@ def _aten__fft_c2r(self, dim, normalization, last_dim_size): s = None return jnp.fft.irfftn(self, norm=norm, axes=dim, s=s) - +@op(torch.ops.aten.max_unpool2d) @op(torch.ops.aten.max_unpool3d) -def _aten_max_unpool3d(input, indices, output_size, stride=None, padding=0): +def _aten_max_unpoolxd(input, indices, output_size, stride=None, padding=0): if output_size is None: raise ValueError("output_size value is not set correctly. It cannot be None or empty.") @@ -4045,21 +4045,3 @@ def _aten_max_unpool3d(input, indices, output_size, stride=None, padding=0): return output -@op(torch.ops.aten.max_unpool2d) -def _aten_max_unpool2d(input, indices, output_size, stride=None, padding=0): - if output_size is None: - raise ValueError("output_size value is not set correctly. It cannot be None or empty.") - - output_size = [input.shape[0], input.shape[1]] + output_size - - output = jnp.zeros(output_size, dtype=input.dtype) - - for idx in np.ndindex(input.shape): - max_index = indices[idx] - spatial_dims = output_size[2:] - unpooled_spatial_idx = np.unravel_index(max_index, spatial_dims) - full_idx = idx[:2] + unpooled_spatial_idx - output = output.at[full_idx].set(input[idx]) - - return output -