Skip to content

Commit

Permalink
Adds OUTPUT_PADDING to ConvTrans2D
Browse files Browse the repository at this point in the history
- Draft state.
- Correctness is not fully verified, but a simple, quick test gives the same
  result as PyTorch.
- Note: the TensorFlow result differs from both the dfdx and the PyTorch results.

Reference pytorch test:
```python
# Reference check for the transposed-convolution `output_padding` parameter:
# runs the same input and weights through two ConvTranspose2d layers that
# differ only in `output_padding` (0 vs 1) and prints the resulting shapes
# and values.
import numpy as np  # required by the np.array literals below
import torch

# Input literal with shape (1, 1, 2, 2).
x = np.array([[[[0.1, 0.7], [0.3, 0.4]]]])
# Weight literal with shape (1, 1, 3, 3), matching kernel_size=3 below.
w = np.array([[[[-0.1, -0.3, 0.7], [0.8, -0.2, 0.1], [0.3, 0.4, -0.5]]]])

# Identical layers except for output_padding.
a = torch.nn.ConvTranspose2d(output_padding=0, in_channels=1, out_channels=1, kernel_size=3, stride=2, padding=1, bias = False)
b = torch.nn.ConvTranspose2d(output_padding=1, in_channels=1, out_channels=1, kernel_size=3, stride=2, padding=1, bias = False)

# np.array builds float64 data; convert to float32 tensors.
x = torch.from_numpy(x).float()
w0 = torch.from_numpy(w).float()

# Replace the randomly initialised weights with the fixed reference weights.
with torch.no_grad():
    a.weight = torch.nn.Parameter(w0)
    b.weight = torch.nn.Parameter(w0)

ya = a(x)
yb = b(x)

print(ya.size()) # torch.Size([1, 1, 3, 3])
print(yb.size()) # torch.Size([1, 1, 4, 4])

print(ya)
print(yb)
```
  • Loading branch information
swfsql committed Mar 1, 2024
1 parent 1175903 commit e81228c
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 55 deletions.
91 changes: 61 additions & 30 deletions dfdx-core/src/tensor_ops/convtrans2d/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ pub(super) trait ConvTrans2DKernel<E: Dtype>: Storage<E> {
) -> Result<(), Error>;
}

pub trait TryConvTrans2D<Stride, Padding, Dilation, Groups>: Sized {
pub trait TryConvTrans2D<Stride, Padding, Dilation, Groups, OutputPadding>: Sized {
type Convolved;

/// Applies a 2D transposed convolution to the input tensor.
Expand All @@ -61,8 +61,9 @@ pub trait TryConvTrans2D<Stride, Padding, Dilation, Groups>: Sized {
padding: Padding,
dilation: Dilation,
groups: Groups,
output_padding: OutputPadding,
) -> Self::Convolved {
self.try_convtrans2d(stride, padding, dilation, groups)
self.try_convtrans2d(stride, padding, dilation, groups, output_padding)
.unwrap()
}

Expand All @@ -73,6 +74,7 @@ pub trait TryConvTrans2D<Stride, Padding, Dilation, Groups>: Sized {
padding: Padding,
dilation: Dilation,
groups: Groups,
output_padding: OutputPadding,
) -> Result<Self::Convolved, Error>;
}

Expand All @@ -82,27 +84,31 @@ impl<
const PADDING: usize,
const DILATION: usize,
Groups: Dim,
const OUTPUT_PADDING: usize,
const DIM: usize,
> TryConvTrans2D<Const<STRIDE>, Const<PADDING>, Const<DILATION>, Groups>
> TryConvTrans2D<Const<STRIDE>, Const<PADDING>, Const<DILATION>, Groups, Const<OUTPUT_PADDING>>
for (Const<DIM>, Const<KERNEL>)
where
Const<{ (DIM - 1) * STRIDE - 2 * PADDING + DILATION * (KERNEL - 1) + 1 }>: Sized,
Const<{ (DIM - 1) * STRIDE - 2 * PADDING + DILATION * (KERNEL - 1) + 1 + OUTPUT_PADDING }>:
Sized,
{
type Convolved = Const<{ (DIM - 1) * STRIDE - 2 * PADDING + DILATION * (KERNEL - 1) + 1 }>;
type Convolved =
Const<{ (DIM - 1) * STRIDE - 2 * PADDING + DILATION * (KERNEL - 1) + 1 + OUTPUT_PADDING }>;

fn try_convtrans2d(
self,
_: Const<STRIDE>,
_: Const<PADDING>,
_: Const<DILATION>,
_: Groups,
_: Const<OUTPUT_PADDING>,
) -> Result<Self::Convolved, Error> {
Ok(Const)
}
}

impl<Kernel: Dim, Stride: Dim, Padding: Dim, Dilation: Dim, Groups: Dim>
TryConvTrans2D<Stride, Padding, Dilation, Groups> for (usize, Kernel)
impl<Kernel: Dim, Stride: Dim, Padding: Dim, Dilation: Dim, Groups: Dim, OutputPadding: Dim>
TryConvTrans2D<Stride, Padding, Dilation, Groups, OutputPadding> for (usize, Kernel)
{
type Convolved = usize;

Expand All @@ -112,18 +118,33 @@ impl<Kernel: Dim, Stride: Dim, Padding: Dim, Dilation: Dim, Groups: Dim>
padding: Padding,
dilation: Dilation,
_: Groups,
output_padding: OutputPadding,
) -> Result<Self::Convolved, Error> {
let (dim, kernel) = self;
Ok(
((dim - 1) * stride.size() + dilation.size() * (kernel.size() - 1) + 1)
.checked_sub(2 * padding.size())
.unwrap(),
)
Ok(((dim - 1) * stride.size()
+ dilation.size() * (kernel.size() - 1)
+ 1
+ output_padding.size())
.checked_sub(2 * padding.size())
.unwrap())
}
}

impl<InpChan, OutChanOverGroups, Kernel, Stride, Padding, Dilation, Groups, H, W, E, D, T>
TryConvTrans2D<Stride, Padding, Dilation, Groups>
impl<
InpChan,
OutChanOverGroups,
Kernel,
Stride,
Padding,
Dilation,
Groups,
OutputPadding,
H,
W,
E,
D,
T,
> TryConvTrans2D<Stride, Padding, Dilation, Groups, OutputPadding>
for (
Tensor<(InpChan, H, W), E, D, T>,
Tensor<(InpChan, OutChanOverGroups, Kernel, Kernel), E, D>,
Expand All @@ -136,23 +157,26 @@ where
Padding: Dim,
Dilation: Dim,
Groups: Dim,
OutputPadding: Dim,
H: Dim,
W: Dim,
E: Dtype,
D: ConvTrans2DKernel<E> + crate::tensor_ops::reshape_to::ReshapeKernel<E>,
T: Tape<E, D>,
OutChanOverGroups: std::ops::Mul<Groups>,
<OutChanOverGroups as std::ops::Mul<Groups>>::Output: Dim,
(H, Kernel): TryConvTrans2D<Stride, Padding, Dilation, Groups>,
(W, Kernel): TryConvTrans2D<Stride, Padding, Dilation, Groups>,
<(H, Kernel) as TryConvTrans2D<Stride, Padding, Dilation, Groups>>::Convolved: Dim,
<(W, Kernel) as TryConvTrans2D<Stride, Padding, Dilation, Groups>>::Convolved: Dim,
(H, Kernel): TryConvTrans2D<Stride, Padding, Dilation, Groups, OutputPadding>,
(W, Kernel): TryConvTrans2D<Stride, Padding, Dilation, Groups, OutputPadding>,
<(H, Kernel) as TryConvTrans2D<Stride, Padding, Dilation, Groups, OutputPadding>>::Convolved:
Dim,
<(W, Kernel) as TryConvTrans2D<Stride, Padding, Dilation, Groups, OutputPadding>>::Convolved:
Dim,
{
type Convolved = Tensor<
(
<OutChanOverGroups as std::ops::Mul<Groups>>::Output,
<(H, Kernel) as TryConvTrans2D<Stride, Padding, Dilation, Groups>>::Convolved,
<(W, Kernel) as TryConvTrans2D<Stride, Padding, Dilation, Groups>>::Convolved,
<(H, Kernel) as TryConvTrans2D<Stride, Padding, Dilation, Groups, OutputPadding>>::Convolved,
<(W, Kernel) as TryConvTrans2D<Stride, Padding, Dilation, Groups, OutputPadding>>::Convolved,
),
E,
D,
Expand All @@ -165,11 +189,13 @@ where
padding: Padding,
dilation: Dilation,
groups: Groups,
output_padding: OutputPadding,
) -> Result<Self::Convolved, Error> {
let (img, filters) = self;
let (inp_chan, h, w) = img.shape;
let img = img.try_reshape_like(&(Const::<1>, inp_chan, h, w))?;
let out = (img, filters).try_convtrans2d(stride, padding, dilation, groups)?;
let out =
(img, filters).try_convtrans2d(stride, padding, dilation, groups, output_padding)?;
let (_, out_chan, out_h, out_w) = out.shape;
out.try_reshape_like(&(out_chan, out_h, out_w))
}
Expand All @@ -182,13 +208,14 @@ impl<
Padding,
Dilation,
Groups,
OutputPadding,
Batch,
H,
W,
E,
D,
T,
> TryConvTrans2D<Stride, Padding, Dilation, Groups>
> TryConvTrans2D<Stride, Padding, Dilation, Groups, OutputPadding>
for (
Tensor<(Batch, InpChan, H, W), E, D, T>,
Tensor<(InpChan, OutChanOverGroups, Kernel, Kernel), E, D>,
Expand All @@ -201,6 +228,7 @@ where
Padding: Dim,
Dilation: Dim,
Groups: Dim,
OutputPadding: Dim,
Batch: Dim,
H: Dim,
W: Dim,
Expand All @@ -209,17 +237,19 @@ where
T: Tape<E, D>,
OutChanOverGroups: std::ops::Mul<Groups>,
<OutChanOverGroups as std::ops::Mul<Groups>>::Output: Dim,
(H, Kernel): TryConvTrans2D<Stride, Padding, Dilation, Groups>,
(W, Kernel): TryConvTrans2D<Stride, Padding, Dilation, Groups>,
<(H, Kernel) as TryConvTrans2D<Stride, Padding, Dilation, Groups>>::Convolved: Dim,
<(W, Kernel) as TryConvTrans2D<Stride, Padding, Dilation, Groups>>::Convolved: Dim,
(H, Kernel): TryConvTrans2D<Stride, Padding, Dilation, Groups, OutputPadding>,
(W, Kernel): TryConvTrans2D<Stride, Padding, Dilation, Groups, OutputPadding>,
<(H, Kernel) as TryConvTrans2D<Stride, Padding, Dilation, Groups, OutputPadding>>::Convolved:
Dim,
<(W, Kernel) as TryConvTrans2D<Stride, Padding, Dilation, Groups, OutputPadding>>::Convolved:
Dim,
{
type Convolved = Tensor<
(
Batch,
<OutChanOverGroups as std::ops::Mul<Groups>>::Output,
<(H, Kernel) as TryConvTrans2D<Stride, Padding, Dilation, Groups>>::Convolved,
<(W, Kernel) as TryConvTrans2D<Stride, Padding, Dilation, Groups>>::Convolved,
<(H, Kernel) as TryConvTrans2D<Stride, Padding, Dilation, Groups, OutputPadding>>::Convolved,
<(W, Kernel) as TryConvTrans2D<Stride, Padding, Dilation, Groups, OutputPadding>>::Convolved,
),
E,
D,
Expand All @@ -232,6 +262,7 @@ where
padding: Padding,
dilation: Dilation,
groups: Groups,
output_padding: OutputPadding,
) -> Result<Self::Convolved, Error> {
let (img, filters) = self;
assert_eq!(img.shape.1, filters.shape.0);
Expand All @@ -242,8 +273,8 @@ where
if img.strides != img.shape.strides() || filters.strides != filters.shape.strides() {
panic!("Image & filter inputs to conv2d must be contiguous");
}
let h_out = (h, kernel).convtrans2d(stride, padding, dilation, groups);
let w_out = (w, kernel).convtrans2d(stride, padding, dilation, groups);
let h_out = (h, kernel).convtrans2d(stride, padding, dilation, groups, output_padding);
let w_out = (w, kernel).convtrans2d(stride, padding, dilation, groups, output_padding);
let op = ConvTrans2DOp {
stride: stride.size(),
padding: padding.size(),
Expand Down
28 changes: 14 additions & 14 deletions dfdx-core/src/tensor_ops/convtrans2d/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ fn test_convtrans2d_default() {
],
])
.to_dtype::<TestDtype>();
let y =
(x.leaky_trace(), w.clone()).convtrans2d(Const::<1>, Const::<0>, Const::<1>, Const::<1>);
let y = (x.leaky_trace(), w.clone())
.convtrans2d(Const::<1>, Const::<0>, Const::<1>, Const::<1>, Const::<0>);
#[rustfmt::skip]
assert_close_to_literal!(
y,
Expand Down Expand Up @@ -125,8 +125,8 @@ fn test_convtrans2d_stride_2() {
],
])
.to_dtype::<TestDtype>();
let y =
(x.leaky_trace(), w.clone()).convtrans2d(Const::<2>, Const::<0>, Const::<1>, Const::<1>);
let y = (x.leaky_trace(), w.clone())
.convtrans2d(Const::<2>, Const::<0>, Const::<1>, Const::<1>, Const::<0>);
#[rustfmt::skip]
assert_close_to_literal!(
y,
Expand Down Expand Up @@ -223,8 +223,8 @@ fn test_convtrans2d_padded() {
],
])
.to_dtype::<TestDtype>();
let y =
(x.leaky_trace(), w.clone()).convtrans2d(Const::<1>, Const::<1>, Const::<1>, Const::<1>);
let y = (x.leaky_trace(), w.clone())
.convtrans2d(Const::<1>, Const::<1>, Const::<1>, Const::<1>, Const::<0>);
assert_close_to_literal!(
y,
[
Expand Down Expand Up @@ -283,8 +283,8 @@ fn test_convtrans2d_batched() {
let x: Tensor<Rank3<3, 28, 28>, TestDtype, _> = dev.sample_normal();
let w: Tensor<Rank4<3, 5, 6, 6>, TestDtype, _> = dev.sample_normal();

let y: Tensor<Rank3<5, 83, 83>, _, _, _> =
(x.leaky_trace(), w.clone()).convtrans2d(Const::<3>, Const::<2>, Const::<1>, Const::<1>);
let y: Tensor<Rank3<5, 83, 83>, _, _, _> = (x.leaky_trace(), w.clone())
.convtrans2d(Const::<3>, Const::<2>, Const::<1>, Const::<1>, Const::<0>);
let y0 = y.retaped::<NoneTape>();
let grads0 = y.square().mean().backward();
let x0 = grads0.get(&x);
Expand All @@ -294,8 +294,8 @@ fn test_convtrans2d_batched() {
.broadcast::<Rank4<10, 3, 28, 28>, _>()
.reshape::<Rank4<10, 3, 28, 28>>();

let y: Tensor<Rank4<10, 5, 83, 83>, _, _, _> =
(x.leaky_trace(), w.clone()).convtrans2d(Const::<3>, Const::<2>, Const::<1>, Const::<1>);
let y: Tensor<Rank4<10, 5, 83, 83>, _, _, _> = (x.leaky_trace(), w.clone())
.convtrans2d(Const::<3>, Const::<2>, Const::<1>, Const::<1>, Const::<0>);
for i in 0..10 {
assert_close_to_tensor!(y0, y.retaped::<NoneTape>().select(dev.tensor(i)), 1e-5);
}
Expand Down Expand Up @@ -341,8 +341,8 @@ fn test_convtrans2d_grouped() {
],
])
.to_dtype::<TestDtype>();
let y =
(x.leaky_trace(), w.clone()).convtrans2d(Const::<1>, Const::<0>, Const::<1>, Const::<2>);
let y = (x.leaky_trace(), w.clone())
.convtrans2d(Const::<1>, Const::<0>, Const::<1>, Const::<2>, Const::<0>);
#[rustfmt::skip]
assert_close_to_literal!(
y,
Expand Down Expand Up @@ -451,8 +451,8 @@ fn test_convtrans2d_dilated() {
],
])
.to_dtype::<TestDtype>();
let y =
(x.leaky_trace(), w.clone()).convtrans2d(Const::<1>, Const::<0>, Const::<2>, Const::<1>);
let y = (x.leaky_trace(), w.clone())
.convtrans2d(Const::<1>, Const::<0>, Const::<2>, Const::<1>, Const::<0>);
#[rustfmt::skip]
assert_close_to_literal!(
y,
Expand Down
Loading

0 comments on commit e81228c

Please sign in to comment.