diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py
index 0e30d181c0d..8ddccccdcaf 100644
--- a/experimental/torch_xla2/test/test_ops.py
+++ b/experimental/torch_xla2/test/test_ops.py
@@ -93,8 +93,6 @@
     "nn.functional.avg_pool2d",
     "nn.functional.avg_pool3d",
     "nn.functional.bilinear",
-    "nn.functional.conv2d",
-    "nn.functional.conv3d",
     "nn.functional.conv_transpose1d",
     "nn.functional.conv_transpose2d",
     "nn.functional.conv_transpose3d",
diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py
index 291587329bb..54adbd30e65 100644
--- a/experimental/torch_xla2/torch_xla2/ops/jaten.py
+++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py
@@ -799,7 +799,10 @@ def _aten_convolution(
   if transposed:
     raise NotImplementedError("Transposed convolution is not implemented.")
 
-  def make_padding(padding):
+  def make_padding(padding, num_spatial_dims):
+    # Broadcast a lone padding value to every spatial dim; jax expects (lo, hi) pairs.
+    if len(padding) == 1 and num_spatial_dims > 1:
+      padding = list(padding) * num_spatial_dims
     return ((p, p) for p in padding)
 
   def create_default_conv_dimension_numbers(num_spatial_dims):
@@ -822,7 +825,7 @@ def create_default_conv_dimension_numbers(num_spatial_dims):
     input,
     weight,
     stride,
-    make_padding(padding),
+    make_padding(padding, len(stride)),
     lhs_dilation=(1,) * len(stride),
     rhs_dilation=dilation,
     dimension_numbers=create_default_conv_dimension_numbers(len(stride)),