From e81228c300a8a48c4e257bdaeb71c46fcc8b18be Mon Sep 17 00:00:00 2001 From: Thiago Machado Date: Thu, 16 Nov 2023 00:23:19 -0500 Subject: [PATCH] Adds `OUTPUT_PADDING` to `ConvTrans2D` - Draft state. - Unsure if correct, but a very simple and quick test gives the same result as PyTorch. - Note: the TensorFlow result differs from both dfdx and PyTorch. Reference PyTorch test: ```python import numpy as np import torch x = np.array([[[[0.1, 0.7], [0.3, 0.4]]]]) w = np.array([[[[-0.1, -0.3, 0.7], [0.8, -0.2, 0.1], [0.3, 0.4, -0.5]]]]) a = torch.nn.ConvTranspose2d(output_padding=0, in_channels=1, out_channels=1, kernel_size=3, stride=2, padding=1, bias = False) b = torch.nn.ConvTranspose2d(output_padding=1, in_channels=1, out_channels=1, kernel_size=3, stride=2, padding=1, bias = False) x = torch.from_numpy(x).float() w0 = torch.from_numpy(w).float() with torch.no_grad(): a.weight = torch.nn.Parameter(w0) b.weight = torch.nn.Parameter(w0) ya = a(x) yb = b(x) print(ya.size()) # torch.Size([1, 1, 3, 3]) print(yb.size()) # torch.Size([1, 1, 4, 4]) print(ya) print(yb) ``` --- dfdx-core/src/tensor_ops/convtrans2d/mod.rs | 91 +++++++++++++------ dfdx-core/src/tensor_ops/convtrans2d/tests.rs | 28 +++--- dfdx/src/nn/layers/conv_trans2d.rs | 68 +++++++++++--- 3 files changed, 132 insertions(+), 55 deletions(-) diff --git a/dfdx-core/src/tensor_ops/convtrans2d/mod.rs b/dfdx-core/src/tensor_ops/convtrans2d/mod.rs index 761ab4915..b26456208 100644 --- a/dfdx-core/src/tensor_ops/convtrans2d/mod.rs +++ b/dfdx-core/src/tensor_ops/convtrans2d/mod.rs @@ -51,7 +51,7 @@ pub(super) trait ConvTrans2DKernel: Storage { ) -> Result<(), Error>; } -pub trait TryConvTrans2D: Sized { +pub trait TryConvTrans2D: Sized { type Convolved; /// Applies a 2D convolution to the input tensor. 
@@ -61,8 +61,9 @@ pub trait TryConvTrans2D: Sized { padding: Padding, dilation: Dilation, groups: Groups, + output_padding: OutputPadding, ) -> Self::Convolved { - self.try_convtrans2d(stride, padding, dilation, groups) + self.try_convtrans2d(stride, padding, dilation, groups, output_padding) .unwrap() } @@ -73,6 +74,7 @@ pub trait TryConvTrans2D: Sized { padding: Padding, dilation: Dilation, groups: Groups, + output_padding: OutputPadding, ) -> Result; } @@ -82,13 +84,16 @@ impl< const PADDING: usize, const DILATION: usize, Groups: Dim, + const OUTPUT_PADDING: usize, const DIM: usize, - > TryConvTrans2D, Const, Const, Groups> + > TryConvTrans2D, Const, Const, Groups, Const> for (Const, Const) where - Const<{ (DIM - 1) * STRIDE - 2 * PADDING + DILATION * (KERNEL - 1) + 1 }>: Sized, + Const<{ (DIM - 1) * STRIDE - 2 * PADDING + DILATION * (KERNEL - 1) + 1 + OUTPUT_PADDING }>: + Sized, { - type Convolved = Const<{ (DIM - 1) * STRIDE - 2 * PADDING + DILATION * (KERNEL - 1) + 1 }>; + type Convolved = + Const<{ (DIM - 1) * STRIDE - 2 * PADDING + DILATION * (KERNEL - 1) + 1 + OUTPUT_PADDING }>; fn try_convtrans2d( self, @@ -96,13 +101,14 @@ where _: Const, _: Const, _: Groups, + _: Const, ) -> Result { Ok(Const) } } -impl - TryConvTrans2D for (usize, Kernel) +impl + TryConvTrans2D for (usize, Kernel) { type Convolved = usize; @@ -112,18 +118,33 @@ impl padding: Padding, dilation: Dilation, _: Groups, + output_padding: OutputPadding, ) -> Result { let (dim, kernel) = self; - Ok( - ((dim - 1) * stride.size() + dilation.size() * (kernel.size() - 1) + 1) - .checked_sub(2 * padding.size()) - .unwrap(), - ) + Ok(((dim - 1) * stride.size() + + dilation.size() * (kernel.size() - 1) + + 1 + + output_padding.size()) + .checked_sub(2 * padding.size()) + .unwrap()) } } -impl - TryConvTrans2D +impl< + InpChan, + OutChanOverGroups, + Kernel, + Stride, + Padding, + Dilation, + Groups, + OutputPadding, + H, + W, + E, + D, + T, + > TryConvTrans2D for ( Tensor<(InpChan, H, W), E, D, T>, 
Tensor<(InpChan, OutChanOverGroups, Kernel, Kernel), E, D>, @@ -136,6 +157,7 @@ where Padding: Dim, Dilation: Dim, Groups: Dim, + OutputPadding: Dim, H: Dim, W: Dim, E: Dtype, @@ -143,16 +165,18 @@ where T: Tape, OutChanOverGroups: std::ops::Mul, >::Output: Dim, - (H, Kernel): TryConvTrans2D, - (W, Kernel): TryConvTrans2D, - <(H, Kernel) as TryConvTrans2D>::Convolved: Dim, - <(W, Kernel) as TryConvTrans2D>::Convolved: Dim, + (H, Kernel): TryConvTrans2D, + (W, Kernel): TryConvTrans2D, + <(H, Kernel) as TryConvTrans2D>::Convolved: + Dim, + <(W, Kernel) as TryConvTrans2D>::Convolved: + Dim, { type Convolved = Tensor< ( >::Output, - <(H, Kernel) as TryConvTrans2D>::Convolved, - <(W, Kernel) as TryConvTrans2D>::Convolved, + <(H, Kernel) as TryConvTrans2D>::Convolved, + <(W, Kernel) as TryConvTrans2D>::Convolved, ), E, D, @@ -165,11 +189,13 @@ where padding: Padding, dilation: Dilation, groups: Groups, + output_padding: OutputPadding, ) -> Result { let (img, filters) = self; let (inp_chan, h, w) = img.shape; let img = img.try_reshape_like(&(Const::<1>, inp_chan, h, w))?; - let out = (img, filters).try_convtrans2d(stride, padding, dilation, groups)?; + let out = + (img, filters).try_convtrans2d(stride, padding, dilation, groups, output_padding)?; let (_, out_chan, out_h, out_w) = out.shape; out.try_reshape_like(&(out_chan, out_h, out_w)) } @@ -182,13 +208,14 @@ impl< Padding, Dilation, Groups, + OutputPadding, Batch, H, W, E, D, T, - > TryConvTrans2D + > TryConvTrans2D for ( Tensor<(Batch, InpChan, H, W), E, D, T>, Tensor<(InpChan, OutChanOverGroups, Kernel, Kernel), E, D>, @@ -201,6 +228,7 @@ where Padding: Dim, Dilation: Dim, Groups: Dim, + OutputPadding: Dim, Batch: Dim, H: Dim, W: Dim, @@ -209,17 +237,19 @@ where T: Tape, OutChanOverGroups: std::ops::Mul, >::Output: Dim, - (H, Kernel): TryConvTrans2D, - (W, Kernel): TryConvTrans2D, - <(H, Kernel) as TryConvTrans2D>::Convolved: Dim, - <(W, Kernel) as TryConvTrans2D>::Convolved: Dim, + (H, Kernel): TryConvTrans2D, + (W, 
Kernel): TryConvTrans2D, + <(H, Kernel) as TryConvTrans2D>::Convolved: + Dim, + <(W, Kernel) as TryConvTrans2D>::Convolved: + Dim, { type Convolved = Tensor< ( Batch, >::Output, - <(H, Kernel) as TryConvTrans2D>::Convolved, - <(W, Kernel) as TryConvTrans2D>::Convolved, + <(H, Kernel) as TryConvTrans2D>::Convolved, + <(W, Kernel) as TryConvTrans2D>::Convolved, ), E, D, @@ -232,6 +262,7 @@ where padding: Padding, dilation: Dilation, groups: Groups, + output_padding: OutputPadding, ) -> Result { let (img, filters) = self; assert_eq!(img.shape.1, filters.shape.0); @@ -242,8 +273,8 @@ where if img.strides != img.shape.strides() || filters.strides != filters.shape.strides() { panic!("Image & filter inputs to conv2d must be contiguous"); } - let h_out = (h, kernel).convtrans2d(stride, padding, dilation, groups); - let w_out = (w, kernel).convtrans2d(stride, padding, dilation, groups); + let h_out = (h, kernel).convtrans2d(stride, padding, dilation, groups, output_padding); + let w_out = (w, kernel).convtrans2d(stride, padding, dilation, groups, output_padding); let op = ConvTrans2DOp { stride: stride.size(), padding: padding.size(), diff --git a/dfdx-core/src/tensor_ops/convtrans2d/tests.rs b/dfdx-core/src/tensor_ops/convtrans2d/tests.rs index 3d64acbf0..16002b7c7 100644 --- a/dfdx-core/src/tensor_ops/convtrans2d/tests.rs +++ b/dfdx-core/src/tensor_ops/convtrans2d/tests.rs @@ -33,8 +33,8 @@ fn test_convtrans2d_default() { ], ]) .to_dtype::(); - let y = - (x.leaky_trace(), w.clone()).convtrans2d(Const::<1>, Const::<0>, Const::<1>, Const::<1>); + let y = (x.leaky_trace(), w.clone()) + .convtrans2d(Const::<1>, Const::<0>, Const::<1>, Const::<1>, Const::<0>); #[rustfmt::skip] assert_close_to_literal!( y, @@ -125,8 +125,8 @@ fn test_convtrans2d_stride_2() { ], ]) .to_dtype::(); - let y = - (x.leaky_trace(), w.clone()).convtrans2d(Const::<2>, Const::<0>, Const::<1>, Const::<1>); + let y = (x.leaky_trace(), w.clone()) + .convtrans2d(Const::<2>, Const::<0>, Const::<1>, 
Const::<1>, Const::<0>); #[rustfmt::skip] assert_close_to_literal!( y, @@ -223,8 +223,8 @@ fn test_convtrans2d_padded() { ], ]) .to_dtype::(); - let y = - (x.leaky_trace(), w.clone()).convtrans2d(Const::<1>, Const::<1>, Const::<1>, Const::<1>); + let y = (x.leaky_trace(), w.clone()) + .convtrans2d(Const::<1>, Const::<1>, Const::<1>, Const::<1>, Const::<0>); assert_close_to_literal!( y, [ @@ -283,8 +283,8 @@ fn test_convtrans2d_batched() { let x: Tensor, TestDtype, _> = dev.sample_normal(); let w: Tensor, TestDtype, _> = dev.sample_normal(); - let y: Tensor, _, _, _> = - (x.leaky_trace(), w.clone()).convtrans2d(Const::<3>, Const::<2>, Const::<1>, Const::<1>); + let y: Tensor, _, _, _> = (x.leaky_trace(), w.clone()) + .convtrans2d(Const::<3>, Const::<2>, Const::<1>, Const::<1>, Const::<0>); let y0 = y.retaped::(); let grads0 = y.square().mean().backward(); let x0 = grads0.get(&x); @@ -294,8 +294,8 @@ fn test_convtrans2d_batched() { .broadcast::, _>() .reshape::>(); - let y: Tensor, _, _, _> = - (x.leaky_trace(), w.clone()).convtrans2d(Const::<3>, Const::<2>, Const::<1>, Const::<1>); + let y: Tensor, _, _, _> = (x.leaky_trace(), w.clone()) + .convtrans2d(Const::<3>, Const::<2>, Const::<1>, Const::<1>, Const::<0>); for i in 0..10 { assert_close_to_tensor!(y0, y.retaped::().select(dev.tensor(i)), 1e-5); } @@ -341,8 +341,8 @@ fn test_convtrans2d_grouped() { ], ]) .to_dtype::(); - let y = - (x.leaky_trace(), w.clone()).convtrans2d(Const::<1>, Const::<0>, Const::<1>, Const::<2>); + let y = (x.leaky_trace(), w.clone()) + .convtrans2d(Const::<1>, Const::<0>, Const::<1>, Const::<2>, Const::<0>); #[rustfmt::skip] assert_close_to_literal!( y, @@ -451,8 +451,8 @@ fn test_convtrans2d_dilated() { ], ]) .to_dtype::(); - let y = - (x.leaky_trace(), w.clone()).convtrans2d(Const::<1>, Const::<0>, Const::<2>, Const::<1>); + let y = (x.leaky_trace(), w.clone()) + .convtrans2d(Const::<1>, Const::<0>, Const::<2>, Const::<1>, Const::<0>); #[rustfmt::skip] assert_close_to_literal!( y, diff 
--git a/dfdx/src/nn/layers/conv_trans2d.rs b/dfdx/src/nn/layers/conv_trans2d.rs index b76836767..092f9e65b 100644 --- a/dfdx/src/nn/layers/conv_trans2d.rs +++ b/dfdx/src/nn/layers/conv_trans2d.rs @@ -15,6 +15,7 @@ use crate::prelude::*; /// - `Dilation`: Controls the spacing between kernel points. Defaults to `Const<1>`. /// - `Groups`: Controls the connections between inputs and outputs. Defaults to `Const<1>`. /// `InChan` and `OutChan` must both be divisible by `Groups`. +/// - `OutputPadding`: Controls the additional size added to one side of the output shape. Defaults to `Const<0>`. #[derive(Debug, Default, Clone, Copy)] pub struct ConvTrans2DConfig< InChan: Dim, @@ -24,6 +25,7 @@ pub struct ConvTrans2DConfig< Padding: Dim = Const<0>, Dilation: Dim = Const<1>, Groups: Dim = Const<1>, + OutputPadding: Dim = Const<0>, > { pub in_chan: InChan, pub out_chan: OutChan, @@ -32,6 +34,7 @@ pub struct ConvTrans2DConfig< pub padding: Padding, pub dilation: Dilation, pub groups: Groups, + pub output_padding: OutputPadding, } /// Compile time sugar alias around [ConvTrans2DConfig]. 
@@ -43,6 +46,7 @@ pub type ConvTrans2DConstConfig< const PADDING: usize = 0, const DILATION: usize = 1, const GROUPS: usize = 1, + const OUTPUT_PADDING: usize = 0, > = ConvTrans2DConfig< Const, Const, @@ -51,18 +55,20 @@ pub type ConvTrans2DConstConfig< Const, Const, Const, + Const, >; -impl> - BuildOnDevice for ConvTrans2DConfig +impl> + BuildOnDevice for ConvTrans2DConfig where O: std::ops::Div, >::Output: Dim, { - type Built = ConvTrans2D; + type Built = ConvTrans2D; fn try_build_on_device(&self, device: &D) -> Result { assert_eq!(self.in_chan.size() % self.groups.size(), 0); assert_eq!(self.out_chan.size() % self.groups.size(), 0); + assert!(self.output_padding.size() < self.stride.size()); let o_over_g = self.out_chan / self.groups; let weight = device.try_zeros_like(&(self.in_chan, o_over_g, self.kernel_size, self.kernel_size))?; @@ -72,6 +78,7 @@ where padding: self.padding, dilation: self.dilation, groups: self.groups, + output_padding: self.output_padding, }) } } @@ -79,8 +86,18 @@ where /// See [ConvTrans2DConfig]. 
#[derive(Debug, Clone, UpdateParams, ZeroGrads)] #[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))] -pub struct ConvTrans2D -where +pub struct ConvTrans2D< + InChan, + OutChan, + KernelSize, + Stride, + Padding, + Dilation, + Groups, + OutputPadding, + Elem, + Dev, +> where OutChan: std::ops::Div, >::Output: Dim, InChan: Dim, @@ -90,6 +107,7 @@ where Padding: Dim, Dilation: Dim, Groups: Dim, + OutputPadding: Dim, Elem: Dtype, Dev: Device, { @@ -110,10 +128,11 @@ where pub padding: Padding, pub dilation: Dilation, pub groups: Groups, + pub output_padding: OutputPadding, } -impl ResetParams - for ConvTrans2D +impl ResetParams + for ConvTrans2D where O: std::ops::Div, >::Output: Dim, @@ -129,8 +148,8 @@ where } } -impl Module - for ConvTrans2D +impl Module + for ConvTrans2D where O: std::ops::Div, >::Output: Dim, @@ -139,18 +158,19 @@ where ( Img, Tensor<(I, >::Output, K, K), E, D>, - ): TryConvTrans2D, + ): TryConvTrans2D, { type Output = <( Img, Tensor<(I, >::Output, K, K), E, D>, - ) as TryConvTrans2D>::Convolved; + ) as TryConvTrans2D>::Convolved; fn try_forward(&self, x: Img) -> Result { (x, self.weight.clone()).try_convtrans2d( self.stride, self.padding, self.dilation, self.groups, + self.output_padding, ) } } @@ -237,4 +257,30 @@ mod tests { assert_ne!(weight_init.array(), m.weight.array()); } + + #[rustfmt::skip] + #[test] + fn test_forward_output_padding() { + let dev: TestDevice = Default::default(); + let x = dev.tensor([[[[0.1, 0.7], [0.3, 0.4]]]]); + let w = dev.tensor([[[[-0.1, -0.3, 0.7], [0.8, -0.2, 0.1], [0.3, 0.4, -0.5]]]]); + let mut m = dev + .build_module::(>::default()); + m.weight = w.clone(); + let y: Tensor, _, _, _> = m.forward(x.clone()); + assert_close_to_literal!(y,[[[[-0.02, 0.57, -0.14], [-0.05, 0.33, 0.16,], [-0.06, 0.35000002, -0.08000001]]]]); + + let mut m = dev + .build_module::(>::default()); + m.weight = w.clone(); + let y: Tensor, _, _, _> = m.forward(x.clone()); + assert_close_to_literal!( + y, [[[ + 
[-0.0200, 0.5700, -0.1400, 0.0700], + [-0.0500, 0.3300, 0.1600, -0.0700], + [-0.0600, 0.3500, -0.0800, 0.0400], + [0.1200, -0.0300, 0.1600, -0.2000], + ]]] + ); + } }