[torch_xla2] Support aten.avg_pool2d and aten.avg_pool3d with `ceil_mode=True` (#7806)
chunnienc committed Aug 5, 2024
1 parent 6250c02 commit dd04c58
Showing 2 changed files with 65 additions and 7 deletions.
13 changes: 13 additions & 0 deletions experimental/torch_xla2/test/test_core_aten_ops.py
@@ -207,6 +207,19 @@ def test_aten__adaptive_avg_pool2d_1(self):
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten._adaptive_avg_pool2d, args,
                            kwargs)
 
+  def test_aten_avg_pool2d_2(self):
+    args = (
+        torch.randn((1, 3, 6, 6)).to(torch.float32),
+        [3, 3],
+        [1, 1],
+        [1, 1],
+        True,
+        True,
+        None,
+    )
+    kwargs = dict()
+    run_export_and_compare(self, torch.ops.aten.avg_pool2d, args, kwargs)
+
   def test_aten_squeeze_dim_0(self):
     args = (
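Note: the positional args in the new test follow the aten.avg_pool2d schema:
input, kernel_size, stride, padding, ceil_mode, count_include_pad,
divisor_override. A minimal sketch of the eager-mode reference the test is
compared against (plain PyTorch, not part of this change):

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 6, 6)
# kernel 3x3, stride 1, symmetric padding 1, ceil_mode=True,
# count_include_pad=True, no divisor override -- as in test_aten_avg_pool2d_2.
y = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
                 ceil_mode=True, count_include_pad=True)
# With stride 1 the remainder (6 + 2*1 - 3) % 1 is always 0, so ceil mode
# adds no extra high padding and the output stays 6x6.
print(y.shape)  # torch.Size([1, 3, 6, 6])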
59 changes: 52 additions & 7 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
@@ -1387,6 +1387,40 @@ def adaptive_kernel_size(input_shape, output_shape):
   return y
 
 
+def _ceil_mode_padding(
+    padding: list[int],
+    input_shape: list[int],
+    kernel_size: list[int],
+    stride: list[int],
+    ceil_mode: bool,
+):
+  """Creates low and high padding specification for the given padding (which is symmetric) and ceil mode.
+
+  Additional high padding could be required when ceil mode is set.
+  """
+  ceil_mode_padding = []
+  for i in range(len(padding)):
+    left_padding = padding[i]
+    right_padding = left_padding
+
+    input_size = input_shape[2 + i]
+    output_size_rem = (input_size + 2 * left_padding -
+                       kernel_size[i]) % stride[i]
+    if ceil_mode and output_size_rem != 0:
+      extra_padding = stride[i] - output_size_rem
+      new_output_size = (input_size + left_padding + right_padding +
+                         extra_padding - kernel_size[i] + stride[i] -
+                         1) // stride[i] + 1
+      # Ensure that the last pooling starts inside the image.
+      size_to_compare = input_size + left_padding
+
+      if (new_output_size - 1) * stride[i] < size_to_compare:
+        right_padding += extra_padding
+
+    ceil_mode_padding.append((left_padding, right_padding))
+  return ceil_mode_padding
+
+
 # aten.avg_pool2d
 @op(torch.ops.aten.avg_pool2d)
 @op(torch.ops.aten.avg_pool3d)
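A worked example of the helper above (illustrative shapes, not from the diff):
a 7-wide input pooled with kernel 2, stride 2, padding 0 leaves remainder
(7 - 2) % 2 = 1, so ceil mode must pad one extra element on the high side to
reach PyTorch's ceil((7 - 2) / 2) + 1 = 4 outputs per spatial dim:

padding = _ceil_mode_padding(
    padding=[0, 0],
    input_shape=[1, 1, 7, 7],  # NCHW
    kernel_size=[2, 2],
    stride=[2, 2],
    ceil_mode=True,
)
# Per dim: extra_padding = 2 - 1 = 1, and the last window at offset
# (4 - 1) * 2 = 6 still starts inside the input (6 < 7), so the high
# padding grows by 1.
print(padding)  # [(0, 1), (0, 1)]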
@@ -1401,23 +1435,34 @@ def _aten_avg_pool(
 ):
   num_batch_dims = len(inputs.shape) - len(kernel_size) - 1
   kernel_size = tuple(kernel_size)
-  strides = tuple(strides)
+  strides = tuple(strides) if strides else kernel_size
   if isinstance(padding, int):
-    padding = tuple((padding, padding) for _ in range(len(kernel_size)))
-  elif isinstance(padding, list):
-    padding = tuple((p, p) for p in padding)
+    padding = [padding for _ in range(len(kernel_size))]
+
+  input_shape = inputs.shape
+  if num_batch_dims == 0:
+    input_shape = [1, *input_shape]
+  padding = _ceil_mode_padding(padding, input_shape, kernel_size, strides,
+                               ceil_mode)
 
   y = pool(inputs, 0.0, jax.lax.add, kernel_size, strides, padding)
-  if count_include_pad:
-    y = y / np.prod(kernel_size)
+  if divisor_override is not None:
+    y = y / jnp.array(divisor_override, y.dtype)
+  elif count_include_pad:
+    y = y / jnp.array(np.prod(kernel_size), y.dtype)
   else:
     div_shape = list(inputs.shape)
     div_shape[num_batch_dims] = 1
     div_shape = tuple(div_shape)
     if len(div_shape) - 2 == len(kernel_size):
       div_shape = (1,) + div_shape[1:]
     y = y / pool(
-        jnp.ones(div_shape), 0.0, jax.lax.add, kernel_size, strides, padding
+        jnp.ones(div_shape, y.dtype),
+        jnp.array(0.0, y.dtype),
+        jax.lax.add,
+        kernel_size,
+        strides,
+        padding,
     )
   return y
 
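The divisor logic above mirrors PyTorch's semantics: divisor_override wins
outright, count_include_pad=True divides every window by the full kernel
area, and otherwise the divisor is the per-window count of in-bounds
elements, computed by pooling a ones tensor with the same padding. A small
eager-mode illustration (plain PyTorch, shown only to pin down the expected
numbers):

import torch
import torch.nn.functional as F

x = torch.ones(1, 1, 4, 4)
# A corner window of a 3x3 kernel with padding 1 covers only 4 in-bounds
# ones, so each corner sum is 4.
a = F.avg_pool2d(x, 3, stride=1, padding=1, count_include_pad=True)
print(a[0, 0, 0, 0].item())  # 4 / 9 ~= 0.444 (pad zeros count in divisor)
b = F.avg_pool2d(x, 3, stride=1, padding=1, count_include_pad=False)
print(b[0, 0, 0, 0].item())  # 4 / 4 = 1.0 (only in-bounds elements count)
c = F.avg_pool2d(x, 3, stride=1, padding=1, divisor_override=2)
print(c[0, 0, 0, 0].item())  # 4 / 2 = 2.0 (override beats both modes)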
