Skip to content

Commit

Permalink
[torch_xla2] Fix nn.functional.conv2d and conv3d (#8048)
Browse files Browse the repository at this point in the history
  • Loading branch information
dvhg committed Sep 23, 2024
1 parent 2aea97a commit 6af550a
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
2 changes: 0 additions & 2 deletions experimental/torch_xla2/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,6 @@
"nn.functional.avg_pool2d",
"nn.functional.avg_pool3d",
"nn.functional.bilinear",
"nn.functional.conv2d",
"nn.functional.conv3d",
"nn.functional.conv_transpose1d",
"nn.functional.conv_transpose2d",
"nn.functional.conv_transpose3d",
Expand Down
7 changes: 5 additions & 2 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -799,7 +799,10 @@ def _aten_convolution(
if transposed:
raise NotImplementedError("Transposed convolution is not implemented.")

def make_padding(padding):
def make_padding(padding, num_spatial_dims):
# Expand single padding to pairs expected by jax
if len(padding) == 1 and len(padding) < num_spatial_dims:
padding *= num_spatial_dims
return ((p, p) for p in padding)

def create_default_conv_dimension_numbers(num_spatial_dims):
Expand All @@ -822,7 +825,7 @@ def create_default_conv_dimension_numbers(num_spatial_dims):
input,
weight,
stride,
make_padding(padding),
make_padding(padding, len(stride)),
lhs_dilation=(1,) * len(stride),
rhs_dilation=dilation,
dimension_numbers=create_default_conv_dimension_numbers(len(stride)),
Expand Down

0 comments on commit 6af550a

Please sign in to comment.