diff --git a/Cargo.toml b/Cargo.toml index fb1c294b7..127efeeb9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,15 +29,13 @@ no-std-compat = { version = "0.4.1", default-features = false, features = [ "all spin = { version = "0.9.8", default-features = false, features = ["spin_mutex", "rwlock", "portable_atomic"], optional = true } rand = { version = "0.8.5", default-features = false, features = ["std_rng"] } rand_distr = { version = "0.4.3", default-features = false } -zip = { version = "0.6.2", default-features = false, optional = true } -cblas-sys = { version = "0.1.4", default-features = false, optional = true } -libc = { version = "0.2", default-features = false, optional = true } -cudarc = { git = "https://github.com/coreylowman/cudarc", branch = "dfdx-half", default-features = false, optional = true, features = ["driver", "cublas", "nvrtc"] } +zip = { version = "0.6.6", default-features = false, optional = true } +cudarc = { version = "0.9.11", default-features = false, optional = true, features = ["driver", "cublas", "nvrtc"] } num-traits = { version = "0.2.15", default-features = false } safetensors = { version = "0.3", default-features = false, optional = true } memmap2 = { version = "0.5", default-features = false, optional = true } -half = { git = "https://github.com/starkat99/half-rs.git", branch = "main", optional = true, features = ["num-traits", "rand_distr"] } -gemm = { version = "0.15.3", default-features = false, optional = true } +half = { version = "2.3.1", optional = true, features = ["num-traits", "rand_distr"] } +gemm = { version = "0.15.4", default-features = false, optional = true } rayon = { version = "1.7.0", optional = true } [dev-dependencies] diff --git a/src/nn/convtrans.rs b/src/nn/convtrans.rs index e5f746c95..da00c05c3 100644 --- a/src/nn/convtrans.rs +++ b/src/nn/convtrans.rs @@ -13,17 +13,29 @@ pub mod builder { const KERNEL_SIZE: usize, const STRIDE: usize = 1, const PADDING: usize = 0, + const DILATION: usize = 1, + const GROUPS: usize = 1, >; } -impl - BuildOnDevice for builder::ConvTrans2D +impl< + const I: usize, + const O: usize, + const K: usize, + const S: usize, + const P: usize, + const L: usize, + const G: usize, + E, + D, + > BuildOnDevice for builder::ConvTrans2D where E: Dtype, D: Device, - ConvTrans2D: BuildModule, + Const<{ O / G }>: Sized, + ConvTrans2D: BuildModule, { - type Built = ConvTrans2D; + type Built = ConvTrans2D; fn try_build_on_device(device: &D) -> Result::Err> { Self::Built::try_build(device) } @@ -45,6 +57,9 @@ where /// - `KERNEL_SIZE`: The size of the kernel applied to both width and height of the images. /// - `STRIDE`: How far to move the kernel each step. Defaults to `1` /// - `PADDING`: How much zero padding to add around the images. Defaults to `0`. +/// - `DILATION`: Controls the spacing between kernel points. Defaults to `1`. +/// - `GROUPS`: Controls the connections between inputs and outputs. +/// `IN_CHAN` and `OUT_CHAN` must both be divisible by `GROUPS`. For example, #[derive(Debug, Clone)] pub struct ConvTrans2D< const IN_CHAN: usize, @@ -52,19 +67,33 @@ pub struct ConvTrans2D< const KERNEL_SIZE: usize, const STRIDE: usize, const PADDING: usize, + const DILATION: usize, + const GROUPS: usize, E: Dtype, D: Storage, -> { - pub weight: Tensor, E, D>, +> where + Const<{ OUT_CHAN / GROUPS }>: Sized, +{ + pub weight: Tensor, E, D>, } -impl - TensorCollection for ConvTrans2D +impl< + const I: usize, + const O: usize, + const K: usize, + const S: usize, + const P: usize, + const L: usize, + const G: usize, + E, + D, + > TensorCollection for ConvTrans2D where E: Dtype + Float + SampleUniform, D: Device, + Const<{ O / G }>: Sized, { - type To> = ConvTrans2D; + type To> = ConvTrans2D; fn iter_tensors>( visitor: &mut V, @@ -85,26 +114,58 @@ where } #[cfg(feature = "nightly")] -impl - Module for ConvTrans2D +impl< + const C: usize, + const O: usize, + const K: usize, + const S: usize, + const P: usize, + const L: usize, + const G: usize, + E, + D, + Img, + > Module for ConvTrans2D where E: Dtype, D: Device, - Img: TryConvTrans2DTo, E, D>, S, P> + HasErr, + Const<{ O / G }>: Sized, + (Img, Tensor, E, D>): + TryConvTrans2D, Const

, Const, Const>, { - type Output = Img::Output; - type Error = D::Err; + type Output = <(Img, Tensor, E, D>) as TryConvTrans2D< + Const, + Const

, + Const, + Const, + >>::Convolved; + type Error = <(Img, Tensor, E, D>) as TryConvTrans2D< + Const, + Const

, + Const, + Const, + >>::Error; - fn try_forward(&self, x: Img) -> Result { - x.try_convtrans2d_to(self.weight.clone()) + fn try_forward(&self, x: Img) -> Result { + (x, self.weight.clone()).try_convtrans2d(Const, Const, Const, Const) } } -impl - NonMutableModule for ConvTrans2D +impl< + const I: usize, + const O: usize, + const K: usize, + const S: usize, + const P: usize, + const L: usize, + const G: usize, + E, + D, + > NonMutableModule for ConvTrans2D where E: Dtype, D: Storage, + Const<{ O / G }>: Sized, { } @@ -187,7 +248,7 @@ mod tests { assert_ne!( g.get(&m.weight).array(), - [[[[TestDtype::zero(); 3]; 3]; 2]; 4] + [[[[TestDtype::zero(); 3]; 3]; 4]; 2] ); opt.update(&mut m, &g).expect("unused params"); diff --git a/src/nn/mod.rs b/src/nn/mod.rs index ddf87b9ed..17ae73fd7 100644 --- a/src/nn/mod.rs +++ b/src/nn/mod.rs @@ -190,6 +190,7 @@ mod batchnorm2d; mod bias2d; #[cfg(feature = "nightly")] mod conv; +#[cfg(feature = "nightly")] mod convtrans; mod dropout; mod ema; diff --git a/src/tensor_ops/conv2d/conv2d.cu b/src/tensor_ops/conv2d/conv2d.cu index bb14e53c6..504edcb87 100644 --- a/src/tensor_ops/conv2d/conv2d.cu +++ b/src/tensor_ops/conv2d/conv2d.cu @@ -22,21 +22,21 @@ __device__ void unfold_input_into_patches( const size_t *strides, // 4d image strides T *patches // 6d (Batch, Groups * Channels, KernelSize, KernelSize, HeightOut, WidthOut) ) { - const size_t n = op.batch * op.groups * op.chan_in * op.h_out * op.w_out; + const size_t n = op.batch * op.chan_in * op.h_out * op.w_out; for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) { unsigned int idx = i; const size_t ow = idx % op.w_out; idx /= op.w_out; const size_t oh = idx % op.h_out; idx /= op.h_out; - const size_t c = idx % (op.chan_in * op.groups); - idx /= (op.chan_in * op.groups); + const size_t c = idx % op.chan_in; + idx /= op.chan_in; const size_t b = idx % op.batch; const T *image_i = image + b * strides[0] + c * strides[1]; T *patches_i = patches + oh * op.w_out + ow; patches_i += c * (op.kernel * op.kernel * op.h_out * op.w_out); - patches_i += b * (op.groups * op.chan_in * op.kernel * op.kernel * op.h_out * op.w_out); + patches_i += b * (op.chan_in * op.kernel * op.kernel * op.h_out * op.w_out); T zero = 0.0; @@ -44,7 +44,7 @@ __device__ void unfold_input_into_patches( const size_t y = oh * op.stride + op.dilation * k1 - op.padding; for (int k2 = 0;k2 < op.kernel;k2++) { const size_t x = ow * op.stride + op.dilation * k2 - op.padding; - *patches_i = (y >= op.h_in || x >= op.w_in) ? zero : image[y * strides[2] + x * strides[3]]; + *patches_i = (y >= op.h_in || x >= op.w_in) ? zero : image_i[y * strides[2] + x * strides[3]]; patches_i += op.h_out * op.w_out; } } @@ -86,7 +86,7 @@ __device__ void unfold_output_into_patches( const size_t ow = ow_s / op.stride; const bool invalid = k1_invalid || (ow_ks < op.dilation * k2 || ow_s % op.stride != 0 || ow >= op.w_out); - *patches_i = invalid ? zero : image_out[oh * op.w_out + ow]; + *patches_i = invalid ? zero : image_i[oh * op.w_out + ow]; patches_i += op.h_in * op.w_in; } } @@ -96,12 +96,13 @@ __device__ void unfold_output_into_patches( template __device__ void transpose_filters( const Conv2DOp op, - const T *filters, // 4d (ChanOut, ChanIn, KernelSize, KernelSize) + const T *filters, // 4d (ChanOut, ChanIn/Groups, KernelSize, KernelSize) const size_t *strides, // 4d filters strides - T *filters_tr // 5d (Groups, ChanIn, ChanOut/Groups, KernelSize, KernelSize) + T *filters_tr // 5d (Groups, ChanIn/Groups, ChanOut/Groups, KernelSize, KernelSize) ) { - const size_t n = op.chan_in * op.chan_out * op.kernel * op.kernel; + const size_t c_per_g = op.chan_in / op.groups; const size_t o_per_g = op.chan_out / op.groups; + const size_t n = c_per_g * op.chan_out * op.kernel * op.kernel; for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) { unsigned int idx = i; @@ -109,18 +110,18 @@ __device__ void transpose_filters( idx /= op.kernel; const size_t k1 = idx % op.kernel; idx /= op.kernel; - const size_t c = idx % op.chan_in; - idx /= op.chan_in; + const size_t cg = idx % c_per_g; + idx /= c_per_g; const size_t o = idx % op.chan_out; const size_t og = o % o_per_g; const size_t g = o / o_per_g; - auto i_no = o * strides[0] + c * strides[1] + k1 * strides[2] + k2 * strides[3]; + auto i_no = o * strides[0] + cg * strides[1] + k1 * strides[2] + k2 * strides[3]; T *filters_tr_i = filters_tr + k2; filters_tr_i += k1 * op.kernel; filters_tr_i += og * (op.kernel * op.kernel); - filters_tr_i += c * (o_per_g * op.kernel * op.kernel); - filters_tr_i += g * (op.chan_in * o_per_g * op.kernel * op.kernel); + filters_tr_i += cg * (o_per_g * op.kernel * op.kernel); + filters_tr_i += g * (c_per_g * o_per_g * op.kernel * op.kernel); *filters_tr_i = filters[i_no]; } } @@ -128,12 +129,13 @@ __device__ void transpose_filters( template __device__ void sum_transposed_filters( const Conv2DOp op, - const T *filters_tr, // 6d (Batch, Groups, ChanIn, ChanOut/Groups, KernelSize, KernelSize) - T *filters, // 4d (ChanOut, ChanIn, KernelSize, KernelSize) + const T *filters_tr, // 6d (Batch, Groups, ChanIn/Groups, ChanOut/Groups, KernelSize, KernelSize) + T *filters, // 4d (ChanOut, ChanIn/Groups, KernelSize, KernelSize) const size_t *strides // 4d filter strides ) { - const size_t n = op.chan_out * op.chan_in * op.kernel * op.kernel; const size_t o_per_g = op.chan_out / op.groups; + const size_t c_per_g = op.chan_in / op.groups; + const size_t n = op.chan_out * c_per_g * op.kernel * op.kernel; for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) { unsigned int idx = i; @@ -141,19 +143,19 @@ __device__ void sum_transposed_filters( idx /= op.kernel; const size_t k1 = idx % op.kernel; idx /= op.kernel; - const size_t c = idx % op.chan_in; - idx /= op.chan_in; + const size_t cg = idx % c_per_g; + idx /= c_per_g; const size_t o = idx % op.chan_out; const size_t og = o % o_per_g; const size_t g = o / o_per_g; - auto i_no = o * strides[0] + c * strides[1] + k1 * strides[2] + k2 * strides[3]; + auto i_no = o * strides[0] + cg * strides[1] + k1 * strides[2] + k2 * strides[3]; const T *filters_tr_i = filters_tr + k2; filters_tr_i += k1 * op.kernel; filters_tr_i += og * (op.kernel * op.kernel); - filters_tr_i += c * (o_per_g * op.kernel * op.kernel); - filters_tr_i += g * (op.chan_in * o_per_g * op.kernel * op.kernel); + filters_tr_i += cg * (o_per_g * op.kernel * op.kernel); + filters_tr_i += g * (c_per_g * o_per_g * op.kernel * op.kernel); T tmp = 0.0; for (int b = 0; b < op.batch; b++) { diff --git a/src/tensor_ops/conv2d/cpu_kernel.rs b/src/tensor_ops/conv2d/cpu_kernel.rs index 7afade5c7..8342e77ec 100644 --- a/src/tensor_ops/conv2d/cpu_kernel.rs +++ b/src/tensor_ops/conv2d/cpu_kernel.rs @@ -54,7 +54,7 @@ impl Cpu { { { let mut i = 0; - for c in 0..(op.groups * op.chan_in) { + for c in 0..op.chan_in { for k1 in 0..op.kernel { for k2 in 0..op.kernel { for oh in 0..op.h_out { @@ -73,9 +73,11 @@ impl Cpu { } } - // (G, O / G, C * K * K) * (G, C * K * K, OH * OW) = (G, O / G, OH * OW) + // filters: (G, O/G, C/G*K*K) + // buf: (G, C/G*K*K, OH*OW) + // output: (G, O/G, OH*OW) let m = op.chan_out / op.groups; - let k = op.chan_in * op.kernel * op.kernel; + let k = (op.chan_in / op.groups) * op.kernel * op.kernel; let n = op.w_out * op.h_out; for g in 0..op.groups { Self::matmul( @@ -128,8 +130,8 @@ impl Cpu { { // img_g += filters^T * unfold(grad_out) - // (G, C, H * W) += (G, C, O/G * K * K) * (G, O/G * K * K, H * W) - let m = op.chan_in; + // (G, C/G, H * W) += (G, C/G, O/G * K * K) * (G, O/G * K * K, H * W) + let m = op.chan_in / op.groups; let k = (op.chan_out / op.groups) * op.kernel * op.kernel; let n = op.h_in * op.w_in; for g in 0..op.groups { @@ -148,8 +150,8 @@ impl Cpu { { // weight_g^T += img * unfold(patches)^T - // (G, C, O/G * K * K) += (G, C, H * W) * (G, H * W, O/G * K * K) - let m = op.chan_in; + // (G, C/G, O/G * K * K) += (G, C/G, H * W) * (G, H * W, O/G * K * K) + let m = op.chan_in / op.groups; let k = op.h_in * op.w_in; let n = (op.chan_out / op.groups) * op.kernel * op.kernel; for g in 0..op.groups { @@ -184,13 +186,7 @@ where rhs: &Tensor, out: &mut Tensor, ) -> Result<(), Self::Err> { - let patches = ( - op.groups * op.chan_in, - op.kernel, - op.kernel, - op.h_out, - op.w_out, - ); + let patches = (op.chan_in, op.kernel, op.kernel, op.h_out, op.w_out); let mut patches = self.try_alloc_zeros::(patches.num_elements())?; let [lstride, ostride] = match L::NUM_DIMS { 3 => [0; 2], @@ -224,7 +220,7 @@ where ) -> Result<(), Self::Err> { let f_tr_shape = [ op.groups, - op.chan_in, + op.chan_in / op.groups, op.chan_out / op.groups, op.kernel, op.kernel, @@ -238,9 +234,9 @@ where // transpose filters in f1023 let buf = rhs.data.as_ref(); let mut f_idx = NdIndex::new(f_tr_shape, f_tr_shape.strides()); - while let Some((i, [g, c, o, k1, k2])) = f_idx.next_with_idx() { - let idx = (g * (op.chan_out / op.groups) + o) * rhs.strides[0] - + c * rhs.strides[1] + while let Some((i, [g, c_over_g, o_over_g, k1, k2])) = f_idx.next_with_idx() { + let idx = (g * (op.chan_out / op.groups) + o_over_g) * rhs.strides[0] + + c_over_g * rhs.strides[1] + k1 * rhs.strides[2] + k2 * rhs.strides[3]; f1023[i] = buf[idx]; @@ -269,9 +265,9 @@ where { // untranspose filters let mut f_idx = NdIndex::new(f_tr_shape, f_tr_shape.strides()); - while let Some((i, [g, c, o, k1, k2])) = f_idx.next_with_idx() { - let idx = (g * (op.chan_out / op.groups) + o) * rhs.strides[0] - + c * rhs.strides[1] + while let Some((i, [g, c_over_g, o_over_g, k1, k2])) = f_idx.next_with_idx() { + let idx = (g * (op.chan_out / op.groups) + o_over_g) * rhs.strides[0] + + c_over_g * rhs.strides[1] + k1 * rhs.strides[2] + k2 * rhs.strides[3]; grad_rhs[idx] += grad_f1023[i]; diff --git a/src/tensor_ops/conv2d/cuda_kernel.rs b/src/tensor_ops/conv2d/cuda_kernel.rs index d22ea08ef..759437e2d 100644 --- a/src/tensor_ops/conv2d/cuda_kernel.rs +++ b/src/tensor_ops/conv2d/cuda_kernel.rs @@ -76,8 +76,7 @@ where self.dev.load_ptx(PTX_SRC.into(), Self::MOD, Self::FNS)?; } - let patches_item_numel = - op.groups * op.chan_in * op.kernel * op.kernel * op.h_out * op.w_out; + let patches_item_numel = op.chan_in * op.kernel * op.kernel * op.h_out * op.w_out; let patches_numel = op.batch * patches_item_numel; let mut patches = unsafe { self.get_workspace::(patches_numel) }?; @@ -89,16 +88,15 @@ where unsafe { let unfold_fn = self.dev.get_func(Self::MOD, Self::FNS[0]).unwrap(); - let cfg = - launch_cfg::<128>((op.batch * op.groups * op.chan_in * op.h_out * op.w_out) as u32); + let cfg = launch_cfg::<128>((op.batch * op.chan_in * op.h_out * op.w_out) as u32); let params = (op, img.data.as_ref(), &img_strides, &mut patches); unfold_fn.launch(cfg, params)?; - // LHS (G, O/G, C*K*K) - // RHS (B, G, C*K*K, OH*OW) + // LHS (G, O/G, C/G*K*K) + // RHS (B, G, C/G*K*K, OH*OW) // OUT (B, G, O/G, OH*OW) let m = op.chan_out / op.groups; - let k = op.chan_in * op.kernel * op.kernel; + let k = (op.chan_in / op.groups) * op.kernel * op.kernel; let n = op.h_out * op.w_out; if op.groups == 1 { // optimizing here for common case @@ -145,7 +143,11 @@ where ) -> Result<(), Self::Err> { let patches_item_numel = op.chan_out * op.kernel * op.kernel * op.h_in * op.w_in; let patches_numel = op.batch * patches_item_numel; - let filters_numel = op.chan_in * op.chan_out * op.kernel * op.kernel; + let filters_numel = op.groups + * (op.chan_in / op.groups) + * (op.chan_out / op.groups) + * op.kernel + * op.kernel; let mut patches = unsafe { self.get_workspace::(patches_numel) }?; let mut patches = unsafe { patches.transmute_mut::(patches_numel).unwrap() }; @@ -177,10 +179,10 @@ where self.par_stream.wait_for_default()?; // img_g += filters * patches - // LHS = (G, C, O/G*K*K) + // LHS = (G, C/G, O/G*K*K) // RHS = (B, G, O/G*K*K, H*W) - // OUT = (B, G, C, H*W) - let m = op.chan_in; + // OUT = (B, G, C/G, H*W) + let m = op.chan_in / op.groups; let k = (op.chan_out / op.groups) * op.kernel * op.kernel; let n = op.h_in * op.w_in; self.blas.set_stream(Some(self.par_stream.as_ref()))?; @@ -217,10 +219,10 @@ where unsafe { // weight_g += img * patches^T - // LHS = (B, G, C, H*W) + // LHS = (B, G, C/G, H*W) // RHS = (B, H*W, G, O/G*K*K) - // OUT = (B, G, C, O/G*K*K) - let m = op.chan_in; + // OUT = (B, G, C/G, O/G*K*K) + let m = op.chan_in / op.groups; let k = op.h_in * op.w_in; let n = (op.chan_out / op.groups) * op.kernel * op.kernel; if op.groups == 1 { diff --git a/src/tensor_ops/conv2d/mod.rs b/src/tensor_ops/conv2d/mod.rs index 2c34b8b9c..eef6ff179 100644 --- a/src/tensor_ops/conv2d/mod.rs +++ b/src/tensor_ops/conv2d/mod.rs @@ -285,8 +285,9 @@ where let (img, filters) = self; assert_eq!(img.shape.1.size(), filters.shape.1.size() * groups.size()); assert_eq!(filters.shape.2, filters.shape.3); - let (batch, _, h, w) = img.shape; - let (out_chan, inp_chan, kernel, _) = filters.shape; + let (batch, inp_chan, h, w) = img.shape; + let (out_chan, inp_chan_over_groups, kernel, _) = filters.shape; + assert_eq!(inp_chan / groups, inp_chan_over_groups); assert!(out_chan.size() % groups.size() == 0); if img.strides != img.shape.strides() || filters.strides != filters.shape.strides() { panic!("Image & filter inputs to conv2d must be contiguous"); diff --git a/src/tensor_ops/conv2d/tests.rs b/src/tensor_ops/conv2d/tests.rs index 375705bec..85603de8b 100644 --- a/src/tensor_ops/conv2d/tests.rs +++ b/src/tensor_ops/conv2d/tests.rs @@ -160,37 +160,37 @@ fn test_conv2d_stride_3_padding_4() { .tensor([-0.07123789, -0.17244765]) .to_dtype::(); #[rustfmt::skip] - let x = dev.tensor([[[0.69103152, 0.25624934],[-0.38448590, 0.03110456],[0.83753252, 0.53786588],[1.15540242, -0.54148245]]]).to_dtype::(); + let x = dev.tensor([[[0.69103152, 0.25624934],[-0.38448590, 0.03110456],[0.83753252, 0.53786588],[1.15540242, -0.54148245]]]).to_dtype::(); let result = (x.leaky_trace(), weight.clone()) .conv2d(Const::<3>, Const::<4>, Const::<1>, Const::<1>) + bias.leaky_trace().broadcast::<_, Axes2<1, 2>>(); #[rustfmt::skip] - assert_close_to_literal!( - result, - [ - [[-0.07123789, -0.07123789, -0.07123789],[-0.07123789, -0.14481398, -0.07123789],[-0.07123789, -0.59748650, -0.07123789],[-0.07123789, -0.07123789, -0.07123789]], - [[-0.17244765, -0.17244765, -0.17244765],[-0.17244765, -0.3061839, -0.17244765],[-0.17244765, -0.42046443, -0.17244765],[-0.17244765, -0.17244765, -0.17244765]], - ] - ); + assert_close_to_literal!( + result, + [ + [[-0.07123789, -0.07123789, -0.07123789],[-0.07123789, -0.14481398, -0.07123789],[-0.07123789, -0.59748650, -0.07123789],[-0.07123789, -0.07123789, -0.07123789]], + [[-0.17244765, -0.17244765, -0.17244765],[-0.17244765, -0.3061839, -0.17244765],[-0.17244765, -0.42046443, -0.17244765],[-0.17244765, -0.17244765, -0.17244765]], + ] + ); let g = result.exp().mean().backward(); #[rustfmt::skip] - assert_close_to_literal!( - g.get(&x), - [[[-0.009780421, 0.01484663],[0.010391434, 0.0062526874],[0.00032053515, -0.009087289],[-0.0073772445, 0.0105412705]]] - ); + assert_close_to_literal!( + g.get(&x), + [[[-0.009780421, 0.01484663],[0.010391434, 0.0062526874],[0.00032053515, -0.009087289],[-0.0073772445, 0.0105412705]]] + ); #[rustfmt::skip] - assert_close_to_literal!( - g.get(&weight), - [ - [[[0.0, 0.019200183, 0.012330416],[0.0, 0.051398464, -0.003175714],[0.0, -0.013860448, 0.0011212977]]], - [[[0.0, 0.02291844, 0.01471829],[0.0, 0.05281557, -0.0069562597],[0.0, -0.011794927, 0.00095419877]]], - ] - ); + assert_close_to_literal!( + g.get(&weight), + [ + [[[0.0, 0.019200183, 0.012330416],[0.0, 0.051398464, -0.003175714],[0.0, -0.013860448, 0.0011212977]]], + [[[0.0, 0.02291844, 0.01471829],[0.0, 0.05281557, -0.0069562597],[0.0, -0.011794927, 0.00095419877]]], + ] + ); assert_close_to_literal!(g.get(&bias), [0.44699076, 0.408709]); } @@ -208,11 +208,11 @@ fn test_conv2d_s4p3k2() { let out = out + bias.broadcast::<_, Axes2<1, 2>>(); #[rustfmt::skip] - assert_close_to_literal!(out, [ - [[-0.57176435, -0.57176435, -0.57176435],[-0.57176435, 1.0759051, 1.4307989],[-0.57176435, -0.86296344, -1.8794353]], - [[0.29306656, 0.29306656, 0.29306656],[0.29306656, 0.9771965, 1.467767],[0.29306656, -6.367015, -2.3370528]], - [[-0.19717735, -0.19717735, -0.19717735],[-0.19717735, 1.3412137, 2.9476144],[-0.19717735, 4.247249, -2.1779637]], - ]); + assert_close_to_literal!(out, [ + [[-0.57176435, -0.57176435, -0.57176435],[-0.57176435, 1.0759051, 1.4307989],[-0.57176435, -0.86296344, -1.8794353]], + [[0.29306656, 0.29306656, 0.29306656],[0.29306656, 0.9771965, 1.467767],[0.29306656, -6.367015, -2.3370528]], + [[-0.19717735, -0.19717735, -0.19717735],[-0.19717735, 1.3412137, 2.9476144],[-0.19717735, 4.247249, -2.1779637]], + ]); } #[test] @@ -279,7 +279,130 @@ fn test_conv2d_dilated() { } #[test] -fn test_conv2d_grouped_forward() { +fn test_conv2d_grouped() { + let dev: TestDevice = Default::default(); + #[rustfmt::skip] + let x = dev + .tensor([ + [ + [1.15955114, 0.68945795, 2.22777081, 0.97970307, 0.90339321], + [-0.98012513,1.70133829,-1.29199386,0.21341583,-0.26468879,], + [0.13577828,-1.12634408,-0.03244355,-2.79851842,-0.74048048,], + [0.40849358, 0.29827344, -1.52881527, 0.76061243, 0.19023405], + [-0.59098929,-0.73987025,0.50599074,0.29848158,-1.34820068,], + ], + [ + [-1.46728218,-2.37748837,-1.17776859,1.11394322,-1.15377915,], + [-1.12063479, 1.21246791, 0.60054827, 0.45333079, -1.00518465], + [-0.46830899,-0.16050071,0.73001051,-0.90739632,2.07482648,], + [0.20643917,-2.07686543,-0.70319396,-0.21572231,0.32948348,], + [-0.17758289, 1.68857682, 1.51658368, 0.36873341, -1.28670764], + ], + ]) + .to_dtype::(); + let w = dev + .tensor([ + [[ + [0.12595156, 0.31638023, -0.33176154], + [-0.25272560, -0.09877023, 0.23111811], + [0.16438398, -0.22974905, -0.30995807], + ]], + [[ + [0.06859669, -0.17031185, -0.01402727], + [-0.13474676, -0.13854985, 0.06477568], + [-0.23829469, -0.08642964, 0.25663486], + ]], + [[ + [-0.24697559, 0.25657538, -0.29573485], + [0.24141768, -0.06818637, -0.16537642], + [0.02836069, 0.24494585, -0.16264121], + ]], + [[ + [-0.15155213, 0.05579081, 0.04196584], + [0.12847015, 0.14483747, -0.20391724], + [-0.09464011, 0.19359276, -0.08733428], + ]], + ]) + .to_dtype::(); + + let y = (x.leaky_trace(), w.clone()).conv2d(Const::<1>, Const::<0>, Const::<1>, Const::<2>); + assert_close_to_literal!( + y, + [ + [ + [-0.30270016, 0.90332055, 1.40224683], + [1.38538218, -0.45968884, -0.37895066], + [-0.89852279, 0.69733238, 0.14387871] + ], + [ + [-0.19980414, -0.82927543, 0.16024134], + [-0.71867949, 0.56500238, 0.57026976], + [0.34083101, 0.39758790, 0.10535710] + ], + [ + [-0.52317154, 0.45419621, 0.65919733], + [-0.20108134, -0.41395274, 0.03387201], + [0.32824901, 0.43692166, -0.89364898] + ], + [ + [-0.10102371, 0.72741956, 0.11405103], + [-0.32978091, 0.21818593, -0.57262653], + [0.17301394, -0.19611263, -0.22273052] + ] + ] + ); + + let grads = y.exp().mean().backward(); + + #[rustfmt::skip] + assert_close_to_literal!( + grads.get(&x), + [ + [ + [0.00414525, 0.01208433, 0.02895186, 0.00725342, -0.03791252], + [0.00665885, 0.01424124, -0.06918646, -0.00789665, 0.02120677], + [-0.02782234,-0.01537140,0.01253848,-0.06255514,-0.03012300], + [0.00690058, -0.06163570, -0.06612907, 0.00661747, 0.01612737], + [-0.00744827,-0.00665466,-0.01195416,-0.01671995,-0.00202148] + ], + [ + [-0.00787102,-0.01389305,-0.00736389,0.00499044,-0.01457474], + [-0.00143798,0.01763375,0.00575301,-0.01384448,-0.02306994], + [-0.00838966,0.01256749,0.01653318,-0.01032873,-0.02182827], + [0.01230815,0.02214932,-0.00205885,-0.00525379,-0.01245476], + [-0.00203156,0.01489969,0.00401321,-0.00189943,-0.00379007] + ] + ] + ); + assert_close_to_literal!( + grads.get(&w), + [ + [[ + [0.15669884, 0.34341401, -0.11008842], + [-0.08243199, -0.26589566, -0.07729618], + [-0.08831018, -0.29109383, -0.44972605] + ]], + [[ + [0.07151295, -0.08670343, -0.06787194], + [-0.11750249, -0.15240794, -0.23142532], + [-0.10467961, -0.14933982, -0.04248949] + ]], + [[ + [-0.19401677, 0.03564095, -0.02681023], + [-0.01509627, -0.05853239, 0.00262788], + [0.05381295, 0.04880310, 0.13299918] + ]], + [[ + [-0.18418193, -0.04946327, 0.05997707], + [0.00057301, -0.00255429, 0.00476278], + [-0.01050586, 0.03911699, 0.04431859] + ]] + ] + ); +} + +#[test] +fn test_conv2d_grouped_slices() { const NUM_GROUPS: usize = 3; let dev: TestDevice = Default::default(); let x: Tensor, TestDtype, _> = dev.sample_normal(); diff --git a/src/tensor_ops/convtrans2d/convtrans2d.cu b/src/tensor_ops/convtrans2d/convtrans2d.cu index d6e842704..3027e8d8d 100644 --- a/src/tensor_ops/convtrans2d/convtrans2d.cu +++ b/src/tensor_ops/convtrans2d/convtrans2d.cu @@ -1,9 +1,11 @@ #include "cuda_fp16.h" struct Conv2DOp { + size_t kernel; size_t stride; size_t padding; - size_t kernel; + size_t dilation; + size_t groups; size_t batch; size_t chan_in; size_t chan_out; @@ -16,9 +18,9 @@ struct Conv2DOp { template __device__ void unfold_input_into_patches( const Conv2DOp op, - const T *image, // 4d (Batch, Channels, Height, Width) + const T *image, // 4d (Batch, Groups * Channels, Height, Width) const size_t *strides, // 4d image strides - T *patches // 6d (Batch, Channels, KernelSize, KernelSize, HeightOut, WidthOut) + T *patches // 6d (Batch, Groups * Channels, KernelSize, KernelSize, HeightOut, WidthOut) ) { unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; if (i >= op.batch * op.chan_in * op.h_out * op.w_out) { @@ -44,15 +46,15 @@ __device__ void unfold_input_into_patches( for (int k1 = 0;k1 < op.kernel;k1++) { const size_t y_ks = oh + op.padding; - const size_t y_s = y_ks - k1; + const size_t y_s = y_ks - op.dilation * k1; const size_t y = y_s / op.stride; - const bool k1_invalid = (y_ks < k1 || y_s % op.stride != 0 || y >= op.h_in); + const bool k1_invalid = (y_ks < op.dilation * k1 || y_s % op.stride != 0 || y >= op.h_in); for (int k2 = 0;k2 < op.kernel;k2++) { const size_t x_ks = ow + op.padding; - const size_t x_s = x_ks - k2; + const size_t x_s = x_ks - op.dilation * k2; const size_t x = x_s / op.stride; - const bool invalid = k1_invalid || (x_ks < k2 || x_s % op.stride != 0 || x >= op.w_in); + const bool invalid = k1_invalid || (x_ks < op.dilation * k2 || x_s % op.stride != 0 || x >= op.w_in); *patches = invalid ? zero : image[y * strides[2] + x * strides[3]]; patches += op.h_out * op.w_out; } @@ -87,9 +89,9 @@ __device__ void unfold_output_into_patches( T zero = 0.0; for (int k1 = 0;k1 < op.kernel;k1++) { - const size_t oh = y * op.stride + k1 - op.padding; + const size_t oh = y * op.stride + op.dilation * k1 - op.padding; for (int k2 = 0;k2 < op.kernel;k2++) { - const size_t ow = x * op.stride + k2 - op.padding; + const size_t ow = x * op.stride + op.dilation * k2 - op.padding; *patches = (oh >= op.h_out || ow >= op.w_out) ? zero : image_out[oh * op.w_out + ow]; patches += op.h_in * op.w_in; } @@ -99,67 +101,39 @@ __device__ void unfold_output_into_patches( template __device__ void transpose_filters( const Conv2DOp op, - const T *filters, // 4d (ChanOut, ChanIn, KernelSize, KernelSize) + const T *filters, // 4d (ChanIn, ChanOut/Groups, KernelSize, KernelSize) const size_t *strides, // 4d filters strides - T *filters_tr // 4d (ChanIn, ChanOut, KernelSize, KernelSize) + T *filters_tr // 5d (Groups, ChanOut/Groups, ChanIn/Groups, KernelSize, KernelSize) ) { unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i >= op.chan_in * op.chan_out * op.kernel * op.kernel) { + const size_t o_per_g = op.chan_out / op.groups; + const size_t c_per_g = op.chan_in / op.groups; + if (i >= op.groups * o_per_g * c_per_g * op.kernel * op.kernel) { return; } - unsigned int idx = i; - const size_t k2 = idx % op.kernel; - idx /= op.kernel; - const size_t k1 = idx % op.kernel; - idx /= op.kernel; - const size_t o = idx % op.chan_out; - idx /= op.chan_out; - const size_t c = idx % op.chan_in; - - auto i_no = o * strides[0] + c * strides[1] + k1 * strides[2] + k2 * strides[3]; - - filters_tr[i] = filters[i_no]; -} - -template -__device__ void sum_transposed_filters( - const Conv2DOp op, - const T *filters_tr, // 5d (Batch, ChanIn, ChanOut, KernelSize, KernelSize) - T *filters, // 4d (ChanOut, ChanIn, KernelSize, KernelSize) - const size_t *strides // 4d filter strides -) { - unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; - auto numel = op.chan_out * op.chan_in * op.kernel * op.kernel; - if (i >= numel) { - return; - } unsigned int idx = i; const size_t k2 = idx % op.kernel; idx /= op.kernel; const size_t k1 = idx % op.kernel; idx /= op.kernel; + const size_t og = idx % o_per_g; + idx /= o_per_g; const size_t c = idx % op.chan_in; - idx /= op.chan_in; - const size_t o = idx % op.chan_out; - idx /= op.chan_out; - - auto i_tr = c * (op.chan_out * op.kernel * op.kernel) + o * (op.kernel * op.kernel) + k1 * (op.kernel) + k2; - auto i_no = o * strides[0] + c * strides[1] + k1 * strides[2] + k2 * strides[3]; - - filters_tr += i_tr; - - T tmp = 0.0; - for (int b = 0; b < op.batch; b++) { - tmp += *filters_tr; - filters_tr += numel; - } - - filters[i_no] += tmp; + const size_t cg = c % c_per_g; + const size_t g = c / c_per_g; + + auto i_no = c * strides[0] + og * strides[1] + k1 * strides[2] + k2 * strides[3]; + filters_tr += k2; + filters_tr += k1 * op.kernel; + filters_tr += cg * (op.kernel * op.kernel); + filters_tr += og * (c_per_g * op.kernel * op.kernel); + filters_tr += g * (o_per_g * c_per_g * op.kernel * op.kernel); + *filters_tr = filters[i_no]; } -#define CONV_OP(TYPENAME, UNFOLD_INPUT, UNFOLD_OUTPUT, TR_FILTERS, SUM_TR_FILTERS) \ +#define CONV_OP(TYPENAME, UNFOLD_INPUT, UNFOLD_OUTPUT, TR_FILTERS) \ extern "C" __global__ void UNFOLD_INPUT( \ const Conv2DOp op, \ const TYPENAME *image, \ @@ -182,34 +156,23 @@ extern "C" __global__ void TR_FILTERS( \ TYPENAME *filters_tr \ ) { \ transpose_filters(op, filters, strides, filters_tr); \ -} \ -extern "C" __global__ void SUM_TR_FILTERS( \ - const Conv2DOp op, \ - const TYPENAME *filters_tr, \ - TYPENAME *filters, \ - const size_t *strides \ -) { \ - sum_transposed_filters(op, filters_tr, filters, strides); \ } CONV_OP( __half, unfold_input_into_patches_f16, unfold_output_into_patches_f16, - transpose_filters_f16, - sum_transposed_filters_f16 + transpose_filters_f16 ); CONV_OP( float, unfold_input_into_patches_f32, unfold_output_into_patches_f32, - transpose_filters_f32, - sum_transposed_filters_f32 + transpose_filters_f32 ); CONV_OP( double, unfold_input_into_patches_f64, unfold_output_into_patches_f64, - transpose_filters_f64, - sum_transposed_filters_f64 + transpose_filters_f64 ); \ No newline at end of file diff --git a/src/tensor_ops/convtrans2d/cpu_kernel.rs b/src/tensor_ops/convtrans2d/cpu_kernel.rs index 4d86fd1b6..f61943f84 100644 --- a/src/tensor_ops/convtrans2d/cpu_kernel.rs +++ b/src/tensor_ops/convtrans2d/cpu_kernel.rs @@ -1,6 +1,6 @@ use crate::prelude::Tensorlike; use crate::shapes::{Dtype, Shape}; -use crate::tensor::{cpu::*, Tensor}; +use crate::tensor::{cpu::*, Tensor, ZerosTensor}; use crate::tensor_ops::matmul::cpu_kernel::MatMulImpl; use std::sync::Arc; @@ -10,9 +10,9 @@ use super::{ConvTrans2DKernel, ConvTrans2DOp}; impl ConvTrans2DOp { #[inline(always)] fn unfold_idx(&self, [k1, k2, y, x]: [usize; 4]) -> Option<[usize; 2]> { - (y * self.stride + k1) + (y * self.stride + self.dilation * k1) .checked_sub(self.padding) - .zip((x * self.stride + k2).checked_sub(self.padding)) + .zip((x * self.stride + self.dilation * k2).checked_sub(self.padding)) .filter(|&(oh, ow)| oh < self.h_out && ow < self.w_out) .map(|(oh, ow)| [oh, ow]) } @@ -24,7 +24,7 @@ impl Cpu { &self, op: &ConvTrans2DOp, img: &[E], - filters: &[E], + filters_tr: &[E], out: &mut [E], buf: &mut [E], ) -> Result<(), CpuError> @@ -40,10 +40,10 @@ impl Cpu { for ow in 0..op.w_out { i += 1; let mut y = oh + op.padding; - if y < k1 { + if y < op.dilation * k1 { continue; } - y -= k1; + y -= op.dilation * k1; if y % op.stride != 0 { continue; } @@ -53,10 +53,10 @@ impl Cpu { } let mut x = ow + op.padding; - if x < k2 { + if x < op.dilation * k2 { continue; } - x -= k2; + x -= op.dilation * k2; if x % op.stride != 0 { continue; } @@ -75,20 +75,24 @@ impl Cpu { } } - // (O, C * K * K) * (C * K * K, OH * OW) = (O, OH * OW) - let m = op.chan_out; - let k = op.chan_in * op.kernel * op.kernel; + // filters_tr: (G, O/G, C/G*K*K) + // patches: (G, C/G*K*K, OH*OW) + // output: (G, O/G, OH*OW) + let m = op.chan_out / op.groups; + let k = (op.chan_in / op.groups) * op.kernel * op.kernel; let n = op.w_out * op.h_out; - Self::matmul( - (m, k, n), - false, - filters.as_ptr(), - [k, 1], - buf.as_ptr(), - [n, 1], - out.as_mut_ptr(), - [n, 1], - ); + for g in 0..op.groups { + Self::matmul( + (m, k, n), + false, + filters_tr[g * m * k..].as_ptr(), + [k, 1], + buf[g * k * n..].as_ptr(), + [n, 1], + out[g * m * n..].as_mut_ptr(), + [n, 1], + ); + } Ok(()) } @@ -99,8 +103,8 @@ impl Cpu { op: &ConvTrans2DOp, img: &[E], grad_img: &mut [E], - filters_tr: &[E], - grad_filters_tr: &mut [E], + filters: &[E], + grad_filters: &mut [E], grad_out: &[E], buf: &mut [E], ) -> Result<(), CpuError> @@ -127,39 +131,45 @@ impl Cpu { } { - // img_g += filters^T * unfold(grad_out) - // (C, H * W) += (C, O * K * K) * (O * K * K, H * W) - let m = op.chan_in; - let k = op.chan_out * op.kernel * op.kernel; + // filters: (G, C/G, O/G*K*K) + // buf: (G, O/G*K*K, H*W) + // grad_img: (G, C/G, H * W) + let m = op.chan_in / op.groups; + let k = (op.chan_out / op.groups) * op.kernel * op.kernel; let n = op.h_in * op.w_in; - Self::matmul( - (m, k, n), - true, - filters_tr.as_ptr(), - [k, 1], - buf.as_ptr(), - [n, 1], - grad_img.as_mut_ptr(), - [n, 1], - ); + for g in 0..op.groups { + Self::matmul( + (m, k, n), + true, + filters[g * m * k..].as_ptr(), + [k, 1], + buf[g * k * n..].as_ptr(), + [n, 1], + grad_img[g * m * n..].as_mut_ptr(), + [n, 1], + ); + } } { - // weight_g^T += img * patches^T - // (C, O * K * K) += (C, H * W) * (H * W, O * K * K) - let m = op.chan_in; + // img: (G, C/G, H * W) + // buf: (G, H * W, O/G * K * K) + // grad_filters: (G, C/G, O/G * K * K) + let m = op.chan_in / op.groups; let k = op.h_in * op.w_in; - let n = op.chan_out * op.kernel * op.kernel; - Self::matmul( - (m, k, n), - true, - img.as_ptr(), - [k, 1], - buf.as_ptr(), - [1, k], - grad_filters_tr.as_mut_ptr(), - [n, 1], - ); + let n = (op.chan_out / op.groups) * op.kernel * op.kernel; + for g in 0..op.groups { + Self::matmul( + (m, k, n), + true, + img[g * m * k..].as_ptr(), + [k, 1], + buf[g * k * n..].as_ptr(), + [1, k], + grad_filters[g * m * n..].as_mut_ptr(), + [n, 1], + ); + } } Ok(()) } @@ -169,6 +179,10 @@ impl ConvTrans2DKernel for Cpu where Self: MatMulImpl, { + fn alloc(&self, s: S) -> Result, Self::Err> { + self.try_zeros_like(&s) + } + fn forward( &self, op: ConvTrans2DOp, @@ -176,20 +190,42 @@ where rhs: &Tensor, out: &mut Tensor, ) -> Result<(), Self::Err> { - let mut patches = self.try_alloc_zeros::(op.inp_patches_shape().num_elements())?; + let patches = (op.chan_in, op.kernel, op.kernel, op.h_out, op.w_out); + let mut patches = self.try_alloc_zeros::(patches.num_elements())?; + let f_tr_shape = [ + op.groups, + op.chan_out / op.groups, + op.chan_in / op.groups, + op.kernel, + op.kernel, + ]; + let mut f_tr = self.try_alloc_zeros::(f_tr_shape.num_elements())?; + + { + // transpose filters in f1023 + let buf = rhs.data.as_ref(); + let mut f_idx = NdIndex::new(f_tr_shape, f_tr_shape.strides()); + while let Some((i, [g, o_over_g, c_over_g, k1, k2])) = f_idx.next_with_idx() { + let idx = (g * (op.chan_in / op.groups) + c_over_g) * rhs.strides[0] + + o_over_g * rhs.strides[1] + + k1 * rhs.strides[2] + + k2 * rhs.strides[3]; + f_tr[i] = buf[idx]; + } + } + let [lstride, ostride] = match L::NUM_DIMS { 3 => [0; 2], 4 => [lhs.strides[0], out.strides[0]], _ => unreachable!(), }; let lhs = lhs.data.as_ref(); - let rhs = rhs.data.as_ref(); let out = Arc::make_mut(&mut out.data); for i_batch in 0..op.batch { self.convtrans2d_forward( &op, &lhs[i_batch * lstride..], - rhs, + &f_tr, &mut out[i_batch * ostride..], &mut patches, )?; @@ -207,23 +243,8 @@ where out: &impl Tensorlike, grad_out: &Self::Vec, ) -> Result<(), Self::Err> { - let f_tr_shape = op.filters_tr_shape(); - let mut patches = self.try_alloc_zeros::(op.out_patches_shape().num_elements())?; - let mut f1023 = self.try_alloc_zeros::(f_tr_shape.num_elements())?; - let mut grad_f1023 = self.try_alloc_zeros::(f_tr_shape.num_elements())?; - - { - // transpose filters in f1023 - let buf = rhs.data.as_ref(); - let mut f_idx = NdIndex::new(f_tr_shape, f_tr_shape.strides()); - while let Some((i, [c, o, k1, k2])) = f_idx.next_with_idx() { - let idx = o * rhs.strides[0] - + c * rhs.strides[1] - + k1 * rhs.strides[2] - + k2 * rhs.strides[3]; - f1023[i] = buf[idx]; - } - } + let patches_shape = [op.chan_out, op.kernel, op.kernel, op.h_in, op.w_in]; + let mut patches = self.try_alloc_zeros::(patches_shape.num_elements())?; let [lstride, ostride] = match L::NUM_DIMS { 3 => [0; 2], @@ -232,30 +253,19 @@ where }; let lhs = lhs.data.as_ref(); + let rhs = rhs.data.as_ref(); for i_batch in 0..op.batch { self.convtrans2d_backward( &op, &lhs[i_batch * lstride..], &mut grad_lhs[i_batch * lstride..], - &f1023, - &mut grad_f1023, + rhs, + grad_rhs, &grad_out[i_batch * ostride..], &mut patches, )?; } - { - // untranspose filters - let mut f_idx = NdIndex::new(f_tr_shape, f_tr_shape.strides()); - while let Some((i, [c, o, k1, k2])) = f_idx.next_with_idx() { - let idx = o * rhs.strides[0] - + c * rhs.strides[1] - + k1 * rhs.strides[2] - + k2 * rhs.strides[3]; - grad_rhs[idx] += grad_f1023[i]; - } - } - Ok(()) } } diff --git a/src/tensor_ops/convtrans2d/cuda_kernel.rs b/src/tensor_ops/convtrans2d/cuda_kernel.rs index bb3a5a4a8..ba619f23c 100644 --- a/src/tensor_ops/convtrans2d/cuda_kernel.rs +++ b/src/tensor_ops/convtrans2d/cuda_kernel.rs @@ -24,7 +24,6 @@ impl HasCudaKernel for Cuda { "unfold_input_into_patches_f16", "unfold_output_into_patches_f16", "transpose_filters_f16", - "sum_transposed_filters_f16", ]; } @@ -34,7 +33,6 @@ impl HasCudaKernel for Cuda { "unfold_input_into_patches_f32", "unfold_output_into_patches_f32", "transpose_filters_f32", - "sum_transposed_filters_f32", ]; } @@ -44,7 +42,6 @@ impl HasCudaKernel for Cuda { "unfold_input_into_patches_f64", "unfold_output_into_patches_f64", "transpose_filters_f64", - "sum_transposed_filters_f64", ]; } @@ -61,6 +58,11 @@ where Self: HasCudaKernel, CudaBlas: Gemm, { + fn alloc(&self, shape: S) -> Result, Self::Err> { + let data = unsafe { self.alloc_empty::(shape.num_elements()) }?; + Ok(self.build_tensor(shape, shape.strides(), data)) + } + fn forward( &self, op: super::ConvTrans2DOp, @@ -76,28 +78,63 @@ where let mut patches = unsafe { self.get_workspace::(patches_numel) }?; let mut patches = unsafe { patches.transmute_mut::(patches_numel).unwrap() }; + let ftr_numel = op.groups + * (op.chan_out / op.groups) + * (op.chan_in / op.groups) + * op.kernel + * op.kernel; + let mut ftr = unsafe { self.alloc_empty::(ftr_numel) }?; + let img_strides = self.dev.htod_copy(make_4d::(lhs.strides).into())?; - let unfold_fn = self.dev.get_func(Self::MOD, Self::FNS[0]).unwrap(); - let cfg = launch_cfg::<128>((op.batch * op.chan_in * op.h_out * op.w_out) as u32); - let params = (op, lhs.data.as_ref(), &img_strides, &mut patches); - unsafe { unfold_fn.launch(cfg, params) }?; - - // (O, C * K * K) * (B, C * K * K, OH * OW) = (B, O, OH * OW) - let m = op.chan_out; - let k = op.chan_in * op.kernel * op.kernel; + let f_strides = self.dev.htod_copy(rhs.strides.into())?; + + let out_buf = Arc::get_mut(&mut out.data).unwrap(); + + // LHS (G, O/G, C/G*K*K) + // RHS (B, G, C/G*K*K, OH*OW) + // OUT (B, G, O/G, OH*OW) + let m = op.chan_out / op.groups; + let k = (op.chan_in / op.groups) * op.kernel * op.kernel; let n = op.h_out * op.w_out; unsafe { - self.gemm_batch( - (op.batch, m, k, n), - rhs.data.as_ref(), - [0, k, 1], - &patches, - [k * n, n, 1], - Default::default(), - Arc::get_mut(&mut out.data).unwrap(), - [m * n, n, 1], - ) - .unwrap(); + // generate patches for matmul + let unfold_fn = self.dev.get_func(Self::MOD, Self::FNS[0]).unwrap(); + let cfg = launch_cfg::<128>((op.batch * op.chan_in * op.h_out * op.w_out) as u32); + unfold_fn.launch(cfg, (op, lhs.data.as_ref(), &img_strides, &mut patches))?; + + // prepare filters for backward operations by + // swapping dims 0 and 1 and adding a batch dimension + let tr_fn = self.dev.get_func(Self::MOD, Self::FNS[2]).unwrap(); + let cfg = launch_cfg::<128>(rhs.shape.num_elements() as u32); + tr_fn.launch(cfg, (op, rhs.data.as_ref(), &f_strides, &mut ftr))?; + + if op.groups == 1 { + self.gemm_batch( + (op.batch, m, k, n), + &ftr, + [0, k, 1], + &patches, + [k * n, n, 1], + Default::default(), + out_buf, + [m * n, n, 1], + ) + .unwrap(); + } else { + for i_batch in 0..op.batch { + self.gemm_batch( + (op.groups, m, k, n), + &ftr, + [m * k, k, 1], + &patches.slice(i_batch * op.groups * k * n..), + [k * n, n, 1], + Default::default(), + &mut out_buf.slice_mut(i_batch * op.groups * m * n..), + [m * n, n, 1], + ) + .unwrap(); + } + } } Ok(()) @@ -114,17 +151,10 @@ where grad_out: &Self::Vec, ) -> Result<(), Self::Err> { let patches_numel = op.batch * op.chan_out * op.kernel * op.kernel * op.h_in * op.w_in; - let filters_numel = op.chan_in * op.chan_out * op.kernel * op.kernel; let mut patches = unsafe { self.get_workspace::(patches_numel) }?; let mut patches = unsafe { patches.transmute_mut::(patches_numel).unwrap() }; - let mut f_b1023 = unsafe { self.alloc_empty::(filters_numel) }?; - let mut grad_f_b1023 = unsafe { self.alloc_empty::(op.batch * filters_numel) }?; - let f_strides = self.dev.htod_copy(rhs.strides.into())?; - - self.par_stream.wait_for_default()?; - { // unfold grad_out into patches let unfold_fn = self.dev.get_func(Self::MOD, Self::FNS[1]).unwrap(); @@ -132,31 +162,25 @@ where unsafe { unfold_fn.launch(cfg, (op, grad_out, &mut patches)) }?; } - { - // prepare filters for backward operations by - // swapping dims 0 and 1 and adding a batch dimension - let tr_fn = self.dev.get_func(Self::MOD, Self::FNS[2]).unwrap(); - let cfg = launch_cfg::<128>(rhs.shape.num_elements() as u32); - unsafe { - tr_fn.launch_on_stream( - self.par_stream.as_ref(), - cfg, - (op, rhs.data.as_ref(), &f_strides, &mut f_b1023), - ) - }?; + let rhs_buf = rhs.data.as_ref(); + let lhs_buf = lhs.data.as_ref(); + unsafe { self.par_stream.wait_for_default()?; // img_g += filters * patches - // (B, C, H * W) += (B, C, O * K * K) * (B, O * K * K, H * W) - let m = op.chan_in; - let k = op.chan_out * op.kernel * op.kernel; + // LHS = (G, C/G, O/G*K*K) + // RHS = (B, G, O/G*K*K, H*W) + // OUT = (B, G, C/G, H*W) + let m = op.chan_in / op.groups; + let k = (op.chan_out / op.groups) * op.kernel * op.kernel; let n = op.h_in * op.w_in; - unsafe { - self.blas.set_stream(Some(self.par_stream.as_ref()))?; + self.blas.set_stream(Some(self.par_stream.as_ref()))?; + if op.groups == 1 { + // optimizing here for common case self.gemm_batch( (op.batch, m, k, n), - &f_b1023, + rhs_buf, [0, k, 1], &patches, [k * n, n, 1], @@ -165,35 +189,62 @@ where [m * n, n, 1], ) .unwrap(); - self.blas.set_stream(None)?; + } else { + for i_batch in 0..op.batch { + self.gemm_batch( + (op.groups, m, k, n), + rhs_buf, + [m * k, k, 1], + &patches.slice(i_batch * op.groups * k * n..), + [k * n, n, 1], + ::ONE, + &mut grad_lhs.slice_mut(i_batch * op.groups * m * n..), + [m * n, n, 1], + ) + .unwrap(); + } } + self.blas.set_stream(None)?; } - { + unsafe { // weight_g += img * patches^T - // (B, C, O * K * K) += (B, C, H * W) * (B, H * W, O * K * K) - let m = op.chan_in; + // LHS = (B, G, C/G, H*W) + // RHS = (B, H*W, G, O/G*K*K) + // OUT = (G, C/G, O/G*K*K) + let m = op.chan_in / op.groups; let k = op.h_in * op.w_in; - let n = op.chan_out * op.kernel * op.kernel; - unsafe { - self.gemm_batch( - (op.batch, m, k, n), - lhs.data.as_ref(), - [m * k, k, 1], - &patches, - [k * n, 1, k], - Default::default(), - &mut grad_f_b1023, - [m * n, n, 1], - ) - .unwrap(); + let n = (op.chan_out / op.groups) * op.kernel * op.kernel; + if op.groups == 1 { + // optimizing here for common case + for i_batch in 0..op.batch { + self.gemm( + (m, k, n), + &lhs_buf.slice(i_batch * m * k..), + [k, 1], + &patches.slice(i_batch * k * n..), + [1, k], + E::ONE, + grad_rhs, + [n, 1], + ) + .unwrap() + } + } else { + for i_batch in 0..op.batch { + self.gemm_batch( + (op.groups, m, k, n), + &lhs_buf.slice(i_batch * op.groups * m * k..), + [m * k, k, 1], + &patches.slice(i_batch * op.groups * k * n..), + [k * n, 1, k], + E::ONE, + grad_rhs, + [m * n, n, 1], + ) + .unwrap(); + } } - - // sum all the gradients collected in our broadcasted grad_f - // into grad_rhs - let sum_fn = self.dev.get_func(Self::MOD, Self::FNS[3]).unwrap(); - let cfg = launch_cfg::<128>(rhs.shape.num_elements() as u32); - unsafe { sum_fn.launch(cfg, (op, &grad_f_b1023, grad_rhs, &f_strides)) }?; } self.dev.wait_for(self.par_stream.as_ref())?; diff --git a/src/tensor_ops/convtrans2d/mod.rs b/src/tensor_ops/convtrans2d/mod.rs index a970f4aec..dc56a2fc8 100644 --- a/src/tensor_ops/convtrans2d/mod.rs +++ b/src/tensor_ops/convtrans2d/mod.rs @@ -3,14 +3,21 @@ mod cpu_kernel; #[cfg(feature = "cuda")] mod cuda_kernel; +#[cfg(test)] +mod tests; + use crate::{shapes::*, tensor::*}; +use super::ReshapeTo; + #[repr(C)] #[derive(Debug, Copy, Clone)] pub(super) struct ConvTrans2DOp { + pub kernel: usize, pub stride: usize, pub padding: usize, - pub kernel: usize, + pub dilation: usize, + pub groups: usize, pub batch: usize, pub chan_in: usize, pub chan_out: usize, @@ -20,38 +27,9 @@ pub(super) struct ConvTrans2DOp { pub w_out: usize, } -impl ConvTrans2DOp { - fn new(s: usize, p: usize, k: usize, [b, c, h_in, w_in]: [usize; 4], o: usize) -> Self { - Self { - stride: s, - padding: p, - kernel: k, - batch: b, - chan_in: c, - chan_out: o, - h_in, - h_out: (h_in - 1) * s - 2 * p + k, - w_in, - w_out: (w_in - 1) * s - 2 * p + k, - } - } - - #[rustfmt::skip] - pub(super) fn inp_patches_shape(&self) -> (usize, usize, usize, usize, usize) { - (self.chan_in, self.kernel, self.kernel, self.h_out, self.w_out) - } - - #[rustfmt::skip] - pub(super) fn out_patches_shape(&self) -> (usize, usize, usize, usize, usize) { - (self.chan_out, self.kernel, self.kernel, self.h_in, self.w_in) - } - - pub(super) fn filters_tr_shape(&self) -> (usize, usize, usize, usize) { - (self.chan_in, self.chan_out, self.kernel, self.kernel) - } -} - pub(super) trait ConvTrans2DKernel: Storage { + fn alloc(&self, s: S) -> Result, Self::Err>; + fn forward( &self, op: ConvTrans2DOp, @@ -73,137 +51,221 @@ pub(super) trait ConvTrans2DKernel: Storage { ) -> Result<(), Self::Err>; } -pub trait ConvTransAlgebra: Dim { - type Convolved: Dim; +pub trait TryConvTrans2D: Sized { + type Convolved; + type Error: std::fmt::Debug; - fn convolve_dim(&self) -> Self::Convolved; + /// Applies a 2D convolution to the input tensor. + fn convtrans2d( + self, + stride: Stride, + padding: Padding, + dilation: Dilation, + groups: Groups, + ) -> Self::Convolved { + self.try_convtrans2d(stride, padding, dilation, groups) + .unwrap() + } + + /// Fallibly applies a 2D convolution to the input tensor. + fn try_convtrans2d( + self, + stride: Stride, + padding: Padding, + dilation: Dilation, + groups: Groups, + ) -> Result; } -impl ConvTransAlgebra - for Const +impl< + const KERNEL: usize, + const STRIDE: usize, + const PADDING: usize, + const DILATION: usize, + Groups: Dim, + const DIM: usize, + > TryConvTrans2D, Const, Const, Groups> + for (Const, Const) where - Const<{ D * S + K - S - 2 * P }>: Sized, + Const<{ (DIM - 1) * STRIDE - 2 * PADDING + DILATION * (KERNEL - 1) + 1 }>: Sized, { - type Convolved = Const<{ D * S + K - S - 2 * P }>; + type Convolved = Const<{ (DIM - 1) * STRIDE - 2 * PADDING + DILATION * (KERNEL - 1) + 1 }>; + type Error = std::convert::Infallible; - fn convolve_dim(&self) -> Self::Convolved { - Default::default() + fn try_convtrans2d( + self, + _: Const, + _: Const, + _: Const, + _: Groups, + ) -> Result { + Ok(Const) } } -impl ConvTransAlgebra for usize { +impl + TryConvTrans2D for (usize, Kernel) +{ type Convolved = usize; + type Error = std::convert::Infallible; - fn convolve_dim(&self) -> Self::Convolved { - (self * S + K).checked_sub(S + 2 * P).unwrap() - } -} - -pub trait TryConvTrans2DTo: HasErr { - type Output; - fn convtrans2d_to(self, filters: F) -> Self::Output { - self.try_convtrans2d_to(filters).unwrap() - } - fn try_convtrans2d_to(self, filters: F) -> Result; -} - -pub trait TryConvTrans2D { - fn convtrans2d(self, filters: F) -> Self::Output - where - Self: TryConvTrans2DTo, - { - self.convtrans2d_to(filters) - } - fn try_convtrans2d( + fn try_convtrans2d( self, - filters: F, - ) -> Result - where - Self: TryConvTrans2DTo, - { - self.try_convtrans2d_to(filters) + stride: Stride, + padding: Padding, + dilation: Dilation, + _: Groups, + ) -> Result { + let (dim, kernel) = self; + Ok( + ((dim - 1) * stride.size() + dilation.size() * (kernel.size() - 1) + 1) + .checked_sub(2 * padding.size()) + .unwrap(), + ) } } -impl, T, F> TryConvTrans2D for Tensor {} - -impl< - const C: usize, - H: Dim + ConvTransAlgebra, - W: Dim + ConvTransAlgebra, - const O: usize, - const K: usize, - const S: usize, - const P: usize, - E: Dtype, - D: ConvTrans2DKernel + ZerosTensor, - T: 'static + Tape, - > TryConvTrans2DTo, E, D>, S, P> - for Tensor<(Const, H, W), E, D, T> +impl + TryConvTrans2D + for ( + Tensor<(InpChan, H, W), E, D, T>, + Tensor<(InpChan, OutChanOverGroups, Kernel, Kernel), E, D>, + ) +where + InpChan: Dim, + OutChanOverGroups: Dim, + Kernel: Dim, + Stride: Dim, + Padding: Dim, + Dilation: Dim, + Groups: Dim, + H: Dim, + W: Dim, + E: Dtype, + D: ConvTrans2DKernel + crate::tensor_ops::reshape_to::ReshapeKernel, + T: Tape, + OutChanOverGroups: std::ops::Mul, + >::Output: Dim, + (H, Kernel): TryConvTrans2D, + (W, Kernel): TryConvTrans2D, + <(H, Kernel) as TryConvTrans2D>::Convolved: Dim, + <(W, Kernel) as TryConvTrans2D>::Convolved: Dim, { - type Output = Tensor<(Const, H::Convolved, W::Convolved), E, D, T>; - - fn try_convtrans2d_to( + type Convolved = Tensor< + ( + >::Output, + <(H, Kernel) as TryConvTrans2D>::Convolved, + <(W, Kernel) as TryConvTrans2D>::Convolved, + ), + E, + D, + T, + >; + type Error = D::Err; + + fn try_convtrans2d( self, - filters: Tensor, E, D>, - ) -> Result { - let h = self.shape.1; - let w = self.shape.2; - - let op = ConvTrans2DOp::new(S, P, K, [1, C, h.size(), w.size()], O); - let (lhs, ltape) = self.split_tape(); - let (rhs, rtape) = filters.split_tape(); - let mut tape = ltape.merge(rtape); - let mut out = lhs - .device - .try_zeros_like(&(Const, h.convolve_dim(), w.convolve_dim()))?; - lhs.device.forward(op, &lhs, &rhs, &mut out)?; - let lhs_ghost = lhs.ghost(); - let rhs_ghost = rhs.ghost(); - let out_ghost = out.ghost(); - tape.add_backward_op(move |grads| { - grads.try_alloc_for(&rhs_ghost)?; - grads.try_alloc_for(&lhs_ghost)?; - grads.try_alloc_for(&out_ghost)?; - let (grad_lhs, grad_rhs, grad_out) = - grads.muts_and_ref(&lhs_ghost, &rhs_ghost, &out_ghost); - lhs.device - .backward(op, &lhs, grad_lhs, &rhs, grad_rhs, &out_ghost, grad_out) - }); - Ok(out.put_tape(tape)) + stride: Stride, + padding: Padding, + dilation: Dilation, + groups: Groups, + ) -> Result { + let (img, filters) = self; + let (inp_chan, h, w) = img.shape; + let img = img.try_reshape_like(&(Const::<1>, inp_chan, h, w))?; + let out = (img, filters).try_convtrans2d(stride, padding, dilation, groups)?; + let (_, out_chan, out_h, out_w) = out.shape; + out.try_reshape_like(&(out_chan, out_h, out_w)) } } - impl< - B: Dim, - const C: usize, - H: Dim + ConvTransAlgebra, - W: Dim + ConvTransAlgebra, - const O: usize, - const K: usize, - const S: usize, - const P: usize, - E: Dtype, - D: ConvTrans2DKernel + ZerosTensor, - T: 'static + Tape, - > TryConvTrans2DTo, E, D>, S, P> - for Tensor<(B, Const, H, W), E, D, T> + InpChan, + OutChanOverGroups, + Kernel, + Stride, + Padding, + Dilation, + Groups, + Batch, + H, + W, + E, + D, + T, + > TryConvTrans2D + for ( + Tensor<(Batch, InpChan, H, W), E, D, T>, + Tensor<(InpChan, OutChanOverGroups, Kernel, Kernel), E, D>, + ) +where + InpChan: Dim, + OutChanOverGroups: Dim, + Kernel: Dim, + Stride: Dim, + Padding: Dim, + Dilation: Dim, + Groups: Dim, + Batch: Dim, + H: Dim, + W: Dim, + E: Dtype, + D: ConvTrans2DKernel, + T: Tape, + OutChanOverGroups: std::ops::Mul, + >::Output: Dim, + (H, Kernel): TryConvTrans2D, + (W, Kernel): TryConvTrans2D, + <(H, Kernel) as TryConvTrans2D>::Convolved: Dim, + <(W, Kernel) as TryConvTrans2D>::Convolved: Dim, { - type Output = Tensor<(B, Const, H::Convolved, W::Convolved), E, D, T>; - fn try_convtrans2d_to( + type Convolved = Tensor< + ( + Batch, + >::Output, + <(H, Kernel) as TryConvTrans2D>::Convolved, + <(W, Kernel) as TryConvTrans2D>::Convolved, + ), + E, + D, + T, + >; + type Error = D::Err; + + fn try_convtrans2d( self, - filters: Tensor, E, D>, - ) -> Result { - let h = self.shape.2; - let w = self.shape.3; - - let batch = self.shape().0; - let op = ConvTrans2DOp::new(S, P, K, [batch.size(), C, h.size(), w.size()], O); - let (lhs, ltape) = self.split_tape(); + stride: Stride, + padding: Padding, + dilation: Dilation, + groups: Groups, + ) -> Result { + let (img, filters) = self; + assert_eq!(img.shape.1, filters.shape.0); + assert_eq!(filters.shape.2, filters.shape.3); + let (batch, _, h, w) = img.shape; + let (inp_chan, out_chan_over_groups, kernel, _) = filters.shape; + let out_chan = out_chan_over_groups * groups; + if img.strides != img.shape.strides() || filters.strides != filters.shape.strides() { + panic!("Image & filter inputs to conv2d must be contiguous"); + } + let h_out = (h, kernel).convtrans2d(stride, padding, dilation, groups); + let w_out = (w, kernel).convtrans2d(stride, padding, dilation, groups); + let op = ConvTrans2DOp { + stride: stride.size(), + padding: padding.size(), + kernel: kernel.size(), + dilation: dilation.size(), + groups: groups.size(), + batch: batch.size(), + chan_in: inp_chan.size(), + chan_out: out_chan.size(), + h_in: h.size(), + h_out: h_out.size(), + w_in: w.size(), + w_out: w_out.size(), + }; + let (lhs, ltape) = img.split_tape(); let (rhs, rtape) = filters.split_tape(); - let mut out = - lhs.device - .try_zeros_like(&(batch, Const, h.convolve_dim(), w.convolve_dim()))?; + let mut out = lhs.device.alloc((batch, out_chan, h_out, w_out))?; let mut tape = ltape.merge(rtape); lhs.device.forward(op, &lhs, &rhs, &mut out)?; let lhs_ghost = lhs.ghost(); @@ -221,158 +283,3 @@ impl< Ok(out.put_tape(tape)) } } - -#[cfg(test)] -mod tests { - use super::*; - use crate::{tensor_ops::*, tests::*}; - - #[test] - /// TODO - /// Produced by - /// ```python - /// x = torch.rand((2, 3, 3), requires_grad=True) - /// w = torch.rand((2, 3, 3, 3), requires_grad=True) - /// print(x) - /// print(torch.swapaxes(w, 0, 1)) - /// y = torch.conv_transpose2d(x, w) - /// print(y) - /// y.exp().mean().backward() - /// print(x.grad) - /// print(torch.swapaxes(w.grad, 0, 1)) - /// ``` - fn convtrans2d_test() { - let device = TestDevice::default(); - - #[rustfmt::skip] - let x = device.tensor([ - [[0.0907329, 0.5784497, 0.1818193], [0.9867508, 0.0566732, 0.9057426], [0.3095418, 0.7836370, 0.3519793]], - [[0.0969202, 0.6929486, 0.7536632], [0.8163304, 0.4960053, 0.4525285], [0.2100813, 0.2163504, 0.3710884]], - ]); - #[rustfmt::skip] - let w = device.tensor([ - [[[0.3401043, 0.4955706, 0.2441649], [0.4701799, 0.6495003, 0.2529446], [0.4017784, 0.1768293, 0.2353096]], [[0.6621207, 0.7709801, 0.5986264], [0.2237560, 0.6466616, 0.7321741], [0.1578425, 0.9478191, 0.9861543]]], - [[[0.8291070, 0.0848221, 0.3680936], [0.4642293, 0.1073243, 0.1073309], [0.7863810, 0.3699800, 0.4956312]], [[0.0681600, 0.5616951, 0.4053129], [0.1850831, 0.8223089, 0.0667553], [0.8905262, 0.6328429, 0.8180532]]], - [[[0.7582999, 0.9763424, 0.5727801], [0.3743349, 0.4793805, 0.6885015], [0.8183323, 0.1882774, 0.9794642]], [[0.5606869, 0.7552301, 0.6572021], [0.8761331, 0.2401637, 0.1778120], [0.2065960, 0.4133974, 0.8821540]]], - ]); - let y = x.leaky_trace().convtrans2d::<1, 0>(w.clone()); - #[rustfmt::skip] - assert_close_to_literal!( - y, - [ - [[0.0950315, 0.7752370, 1.4619386, 1.2272182, 0.4955567], [0.9404547, 2.0147018, 2.9196219, 2.3676410, 1.0898490], [0.9427375, 2.4812443, 3.9218845, 3.6057489, 1.6545289], [0.7178541, 1.8030399, 3.0822182, 1.9167527, 1.0201255], [0.1575270, 0.6028528, 0.8236330, 0.8117172, 0.4487746]], - [[0.0818333, 0.5889638, 0.7130897, 0.9325359, 0.3723960], [0.9338222, 1.1092455, 2.6313026, 1.3005096, 0.5866395], [1.0377907, 2.8708880, 3.3737209, 2.5207422, 1.1140431], [1.6855066, 1.9777625, 3.1483138, 1.4968101, 0.8816571], [0.4305007, 1.0563757, 1.3593760, 0.9304471, 0.4780220]], - [[0.1231446, 0.9889490, 1.7642095, 1.5334388, 0.5994515], [1.3248383, 2.7914243, 3.7239599, 2.3741565, 1.0753872], [1.5313777, 2.9749527, 4.2994099, 3.4086916, 1.9924896], [1.2760720, 1.3538387, 3.8719988, 1.6865263, 1.5946647], [0.2967100, 0.8310994, 1.0901904, 1.1780756, 0.6721083]], - ] - ); - - let g = y.exp().mean().backward(); - - #[rustfmt::skip] - assert_close_to_literal!( - g.get(&x), - [ - [[2.4066830, 2.4581399, 2.5645943], [3.0028410, 3.2547507, 3.4216807], [2.1464431, 3.0581608, 2.8662176]], - [[2.9864695, 3.5932014, 2.5797451], [3.3677268, 3.8909531, 3.2242548], [2.4629762, 3.3527191, 2.9590628]], - ] - ); - #[rustfmt::skip] - assert_close_to_literal!( - g.get(&w), - [ - [[[0.6642238, 1.0354118, 0.9408946], [0.9326871, 1.1026800, 1.0413336], [0.5472590, 0.7134937, 0.7234858]], [[0.5456561, 0.7068147, 0.6173539], [0.8016681, 1.0878984, 1.0714644], [0.8350498, 1.1260045, 0.7775879]]], - [[[0.5567597, 0.5549879, 0.4975571], [0.6702054, 0.8184303, 0.6338357], [0.6227797, 0.5077031, 0.5278049]], [[0.3733614, 0.3889205, 0.3597363], [0.6457027, 0.7389204, 0.5783513], [0.7389930, 0.7089815, 0.5071381]]], - [[[1.1678052, 1.4273405, 1.2900156], [1.4765850, 1.5869446, 1.4983673], [1.0089380, 0.8733283, 1.0910161]], [[0.9175905, 1.0371233, 0.9381008], [1.4550014, 1.5706275, 1.4026034], [1.3066854, 1.4330946, 1.0638479]]], - ] - ); - } - - #[test] - /// torch.set_printoptions(precision=7) - /// x = torch.rand((2, 3, 3), requires_grad=True) - /// w = torch.rand((2, 3, 3, 3), requires_grad=True) - /// print(x) - /// print(torch.swapaxes(w, 0, 1)) - /// y = torch.conv_transpose2d(x, w, stride=2) - /// print(y) - /// y.exp().mean().backward() - /// print(x.grad) - /// print(torch.swapaxes(w.grad, 0, 1)) - fn convtrans2d_s2() { - let device = TestDevice::default(); - - #[rustfmt::skip] - let x = device.tensor([ - [[0.0357635, 0.0225288, 0.3642959],[0.6850907, 0.2586224, 0.2234361],[0.9315249, 0.7850553, 0.7588840]], - [[0.1240514, 0.4712945, 0.8732865],[0.1023245, 0.1211519, 0.0407664],[0.5106173, 0.9263544, 0.3101138]], - ]); - #[rustfmt::skip] - let w = device.tensor([ - [[[0.7007528, 0.1896583, 0.9991148],[0.6587640, 0.9383754, 0.3999129],[0.0035173, 0.4376699, 0.3985791],],[[0.2180834, 0.4829719, 0.0272914],[0.2712103, 0.8577049, 0.8002768],[0.7074867, 0.4011419, 0.8835942],],], - [[[0.0067961, 0.4006048, 0.3549793],[0.5392876, 0.3803764, 0.6090584],[0.4874769, 0.5006863, 0.8963661],],[[0.3751084, 0.5425243, 0.5102475],[0.6024926, 0.2719866, 0.9794098],[0.2236674, 0.1083973, 0.4948432],],], - [[[0.9486710, 0.9823384, 0.5994584],[0.2740490, 0.2620903, 0.2716798],[0.3620688, 0.9108542, 0.9017550],],[[0.6089512, 0.4252676, 0.2729263],[0.8855131, 0.3937372, 0.3419960],[0.8216078, 0.6664743, 0.5395248],],], - ]); - let y = x.leaky_trace().convtrans2d::<2, 0>(w.clone()); - #[rustfmt::skip] - assert_close_to_literal!( - y, - [ - [[0.0521149, 0.0666962, 0.1576860, 0.2318948, 0.4811018, 0.4908646, 0.3878066,],[0.0572037, 0.1399591, 0.2562388, 0.4253721, 0.8630048, 1.0908685, 0.8445575,],[0.5902850, 0.2447678, 1.3523079, 0.3064790, 1.4716963, 0.5718186, 1.1411824,],[0.4790645, 0.7306364, 0.5590932, 0.3465975, 0.3586293, 0.2446325, 0.1219794,],[0.8389287, 0.7641755, 2.1468871, 0.7580858, 1.6488208, 0.4078493, 0.8917535,],[0.7521397, 1.3120791, 1.5495670, 1.5312154, 1.6393251, 0.9781042, 0.5516644,],[0.3645314, 0.6125304, 1.4806095, 0.7151946, 1.3534986, 0.4565403, 0.5764900,],], - [[0.0467758, 0.0816279, 0.2529318, 0.2647139, 0.5785270, 0.6197178, 0.5749097,],[0.0940269, 0.0473439, 0.4393802, 0.1367552, 1.1979207, 0.3760918, 1.0771828,],[0.0882189, 0.3613173, 0.5524452, 0.2317001, 0.7967559, 0.3886862, 0.8587984,],[0.4311106, 0.2884232, 0.7299427, 0.1313255, 0.4212312, 0.0960777, 0.1760126,],[0.5547202, 1.0043029, 1.7619288, 0.9596879, 1.2826418, 0.5885472, 0.6480764,],[0.8100029, 0.4932112, 2.0489488, 0.5505725, 1.9815230, 0.3730083, 0.7659331,],[0.5683053, 0.5217513, 1.6775545, 0.4934808, 1.6013979, 0.4135783, 0.8336955,],], - [[0.1094690, 0.0878869, 0.3636634, 0.2225572, 1.0195196, 0.7292423, 0.4567231,],[0.1196501, 0.0582169, 0.4756527, 0.1914707, 1.0404429, 0.4393238, 0.3976323,],[0.8271067, 0.8317585, 1.2522885, 0.6402028, 1.5488806, 1.1506698, 0.9447322,],[0.2783581, 0.2198445, 0.3992766, 0.1154844, 0.2090276, 0.0746117, 0.0746450,],[1.5267723, 1.8244361, 2.8728042, 1.4814504, 2.0451818, 1.1080496, 0.7630367,],[0.7074417, 0.4451926, 1.4631481, 0.5704955, 1.0126743, 0.3209994, 0.3122311,],[0.7568032, 1.1887968, 2.1608419, 1.3324623, 1.7372788, 0.8979155, 0.8516415,],], - ] - ); - - let g = y.exp().mean().backward(); - - #[rustfmt::skip] - assert_close_to_literal!( - g.get(&x), - [ - [[0.1513395, 0.1986136, 0.2298895],[0.3779295, 0.3469064, 0.2452929],[0.4825282, 0.5639746, 0.3148936],], - [[0.1527605, 0.2144486, 0.2480491],[0.3177541, 0.3597765, 0.2327209],[0.3772507, 0.5048490, 0.2865718],], - ] - ); - #[rustfmt::skip] - assert_close_to_literal!( - g.get(&w), - [ - [[[0.1134962, 0.0483045, 0.1292279],[0.0842974, 0.0839551, 0.0851499],[0.0981565, 0.0517171, 0.1198249],],[[0.0928453, 0.0412302, 0.0897282],[0.0699945, 0.0741750, 0.0777097],[0.0907740, 0.0422520, 0.0901646],],], - [[[0.0771443, 0.0567608, 0.0866109],[0.1149883, 0.0410739, 0.1213422],[0.0952494, 0.0514846, 0.1152932],],[[0.0687711, 0.0483170, 0.0680406],[0.1007998, 0.0350681, 0.1096105],[0.0770078, 0.0379331, 0.0848609],],], - [[[0.1948610, 0.1028565, 0.1976164],[0.0683245, 0.0401628, 0.0643963],[0.1662624, 0.1036348, 0.2046718],],[[0.1715550, 0.0769297, 0.1411840],[0.0654903, 0.0355645, 0.0568618],[0.1351082, 0.0760437, 0.1234931],],], - ] - ); - } - - #[test] - fn test_batched_convtrans2d() { - let dev: TestDevice = Default::default(); - let x: Tensor, TestDtype, _> = dev.sample_normal(); - let w: Tensor, TestDtype, _> = dev.sample_normal(); - - let y: Tensor, _, _, _> = x.leaky_trace().convtrans2d::<3, 2>(w.clone()); - let y0 = y.retaped::(); - let grads0 = y.square().mean().backward(); - let x0 = grads0.get(&x); - let w0 = grads0.get(&w); - - let x = x - .broadcast::, _>() - .reshape::>(); - - let y: Tensor, _, _, _> = - x.leaky_trace().convtrans2d::<3, 2>(w.clone()); - for i in 0..10 { - assert_close_to_tensor!(y0, y.retaped::().select(dev.tensor(i)), 1e-5); - } - - let grads = y.square().mean().backward(); - - assert_close_to_tensor!(w0, grads.get(&w)); - - let x_grad = grads.get(&x) * 10.0; - for i in 0..10 { - assert_close_to_tensor!(x0, x_grad.clone().select(dev.tensor(i))); - } - } -} diff --git a/src/tensor_ops/convtrans2d/tests.rs b/src/tensor_ops/convtrans2d/tests.rs new file mode 100644 index 000000000..3d64acbf0 --- /dev/null +++ b/src/tensor_ops/convtrans2d/tests.rs @@ -0,0 +1,516 @@ +use super::*; +use crate::{tensor_ops::*, tests::*}; + +#[test] +fn test_convtrans2d_default() { + let device = TestDevice::default(); + + let x = device + .tensor([ + [ + [-1.44135797, 0.23273671, 0.55838293, 1.09627271], + [0.05642751, 0.96609902, 0.24707083, -1.48412001], + [0.40077326, 1.16620362, 0.48770329, -1.01286852], + ], + [ + [0.32589695, -0.91695106, -0.13670059, 0.23979346], + [0.72553939, -0.38209674, 0.35545620, -0.37955058], + [-0.43962145, -0.69825196, -1.05400932, 1.22050178], + ], + ]) + .to_dtype::(); + let w = device + .tensor([ + [ + [[-0.46923456, -0.05046696], [0.30552489, 0.68838942]], + [[-0.45128539, 0.28081128], [-1.64214528, -0.67228431]], + [[0.82240576, 0.06074515], [-0.19825625, -0.28086993]], + ], + [ + [[-1.61626124, 0.44280547], [0.03486229, 0.80007958]], + [[-0.69502026, 0.22409980], [0.58573663, -0.36501792]], + [[1.57198501, 0.84847665], [0.52653229, -0.59273601]], + ], + ]) + .to_dtype::(); + let y = + (x.leaky_trace(), w.clone()).convtrans2d(Const::<1>, Const::<0>, Const::<1>, Const::<1>); + #[rustfmt::skip] + assert_close_to_literal!( + y, + [ + [ + [0.14960037, 1.58987427, -0.45884517, -0.99068964, 0.05085630], + [-1.62814808,-0.20966601,-1.31598091,2.07309437,0.85334837], + [0.56502044, 1.26762176, 1.55388546, -2.00090408, -0.73376185], + [0.10711999, 0.25611749, 0.35640538, -0.77446860, 0.27925056] + ], + [ + [0.42395884, 0.20055276, -0.29711384, -0.53522754, 0.36158341], + [2.02807951,-0.06121790,-0.99166316,-0.90268230,-1.32635069], + [0.45699552, -2.14002204, -0.02407408, 1.42854285, 1.12538254], + [-0.91563028,-2.43303156,-1.94739747,2.43502688,0.23543060] + ], + [ + [-0.67307591,-1.06106889,-0.51954788,1.19646454,0.27005240], + [1.64429677, 0.49562341, 0.79191101, -1.66748166, -0.86223716], + [0.00935271, -1.32583618, -1.68409586, 0.03524891, 1.61585832], + [-0.31093070,-0.45084506,-0.56533259,1.33120918,-0.43895102] + ] + ] + ); + + let g = y.exp().mean().backward(); + + assert_close_to_literal!( + g.get(&x), + [ + [ + [-0.24465635, -0.07769803, 0.07309557, 0.09593978], + [0.01824830, 0.06927966, -0.03252371, -0.23244604], + [-0.01082366, -0.01004899, -0.16018088, -0.32376897] + ], + [ + [0.11863519, -0.11774766, 0.16648589, 0.13729726], + [0.14436747, 0.10569359, 0.08925933, -0.20175745], + [0.01255126, -0.02915864, -0.19611855, 0.20987590] + ] + ] + ); + assert_close_to_literal!( + g.get(&w), + [ + [ + [[-0.06153677, -0.00425007], [0.25752968, 0.18442926]], + [[-0.05023063, 0.00882470], [-0.45103958, 0.01904970]], + [[0.08672050, 0.00499410], [-0.15565623, -0.10318264]] + ], + [ + [[-0.25463796, -0.01300730], [0.00675005, -0.00511352]], + [[0.13059653, -0.01341964], [0.25117764, -0.17709662]], + [[0.08071060, 0.07478670], [0.05806271, -0.11189780]] + ] + ] + ); +} + +#[test] +fn test_convtrans2d_stride_2() { + let device = TestDevice::default(); + let x = device + .tensor([ + [ + [-1.44135797, 0.23273671, 0.55838293, 1.09627271], + [0.05642751, 0.96609902, 0.24707083, -1.48412001], + [0.40077326, 1.16620362, 0.48770329, -1.01286852], + ], + [ + [0.32589695, -0.91695106, -0.13670059, 0.23979346], + [0.72553939, -0.38209674, 0.35545620, -0.37955058], + [-0.43962145, -0.69825196, -1.05400932, 1.22050178], + ], + ]) + .to_dtype::(); + let w = device + .tensor([ + [ + [[-0.46923456, -0.05046696], [0.30552489, 0.68838942]], + [[-0.45128539, 0.28081128], [-1.64214528, -0.67228431]], + [[0.82240576, 0.06074515], [-0.19825625, -0.28086993]], + ], + [ + [[-1.61626124, 0.44280547], [0.03486229, 0.80007958]], + [[-0.69502026, 0.22409980], [0.58573663, -0.36501792]], + [[1.57198501, 0.84847665], [0.52653229, -0.59273601]], + ], + ]) + .to_dtype::(); + let y = + (x.leaky_trace(), w.clone()).convtrans2d(Const::<2>, Const::<0>, Const::<1>, Const::<1>); + #[rustfmt::skip] + assert_close_to_literal!( + y, + [ + [ + [0.14960037,0.21704991,1.37282431,-0.41777647,-0.04106871,-0.08871166,-0.90197796,0.05085630], + [-0.42900923,-0.73147207,0.03913984,-0.57342035,0.16583419,0.27501357,0.34329835,0.94651639], + [-1.19913888,0.31842509,0.16424109,-0.21795061,-0.69044423,0.14492904,1.30985332,-0.09316804], + [0.04253398,0.61933333,0.28184652,0.35934454,0.08787831,0.45447421,-0.46666759,-1.32532310], + [0.52248645,-0.21489260,0.58133453,-0.36804456,1.47470713,-0.49133399,-1.49737680,0.59156126], + [0.10711999,-0.07584409,0.33196157,0.24414507,0.11226030,-0.50756156,-0.26690704,0.27925056] + ], + [ + [0.42395884,-0.33171612,0.53226888,-0.14013346,-0.15698038,0.12616566,-0.66139317,0.36158341], + [2.55780911,0.85004413,-0.91927934,0.17823833,-0.99701643,-0.32549390,-1.65978324,-0.82453585], + [-0.52972949,0.17843871,-0.17042139,0.18566369,-0.35854873,0.14903794,0.93355697,-0.50181484], + [0.33231282,-0.30277020,-1.81028295,-0.51002109,-0.19752248,-0.29584974,2.21482396,1.13629329], + [0.12468270,0.01402257,-0.04099140,0.17100500,0.51246446,-0.09925070,-0.39118069,-0.01091070], + [-0.91563028,-0.10896385,-2.32406759,-0.52914590,-1.41825151,0.05685703,2.37816978,0.23543060] + ], + [ + [-0.67307591,0.18896043,-1.25002933,-0.76387393,0.24432607,-0.08206820,1.27853274,0.27005240], + [0.45735350,0.21166326,-0.52894586,0.47814116,-0.18268019,-0.07580562,-0.09108391,-0.45004424], + [1.18694329,0.61903095,0.19387504,-0.26551431,0.76196432,0.31660464,-1.81719673,-0.41219291], + [0.37083280,-0.44590211,-0.39272144,-0.04486568,0.13817583,-0.28008646,0.09439043,0.64181799], + [-0.36148009,-0.34866351,-0.13854906,-0.52160925,-1.25579679,-0.86467665,1.08562160,0.97404039], + [-0.31093070,0.14801431,-0.59885937,0.08632754,-0.65166014,0.48776808,0.84344113,-0.43895102] + ] + ] + ); + + let g = y.exp().mean().backward(); + + assert_close_to_literal!( + g.get(&x), + [ + [ + [-0.16320729, -0.02408358, 0.00202971, 0.02914695], + [0.00702363, 0.00503576, 0.00480992, -0.13959591], + [-0.00830228, 0.00076824, -0.02246094, -0.10893497] + ], + [ + [0.04730723, -0.04577319, 0.01554872, 0.06038214], + [0.06361489, 0.01077521, 0.03939442, -0.01726136], + [-0.00589108, -0.00387334, -0.05110299, 0.10329218] + ] + ] + ); + assert_close_to_literal!( + g.get(&w), + [ + [ + [[0.00466055, -0.00407176], [0.02797192, 0.03677537]], + [[-0.01260271, 0.02238824], [-0.29015976, -0.03991098]], + [[0.02820409, -0.00682320], [-0.01703458, 0.00010475]] + ], + [ + [[-0.07678317, 0.00976532], [-0.01299408, 0.00706887]], + [[-0.02618737, -0.00559338], [0.09905507, -0.00913252]], + [[0.03891469, 0.02437293], [0.01698377, -0.02502952]] + ] + ] + ); +} + +#[test] +fn test_convtrans2d_padded() { + let device = TestDevice::default(); + let x = device + .tensor([ + [ + [-1.44135797, 0.23273671, 0.55838293, 1.09627271], + [0.05642751, 0.96609902, 0.24707083, -1.48412001], + [0.40077326, 1.16620362, 0.48770329, -1.01286852], + ], + [ + [0.32589695, -0.91695106, -0.13670059, 0.23979346], + [0.72553939, -0.38209674, 0.35545620, -0.37955058], + [-0.43962145, -0.69825196, -1.05400932, 1.22050178], + ], + ]) + .to_dtype::(); + let w = device + .tensor([ + [ + [[-0.46923456, -0.05046696], [0.30552489, 0.68838942]], + [[-0.45128539, 0.28081128], [-1.64214528, -0.67228431]], + [[0.82240576, 0.06074515], [-0.19825625, -0.28086993]], + ], + [ + [[-1.61626124, 0.44280547], [0.03486229, 0.80007958]], + [[-0.69502026, 0.22409980], [0.58573663, -0.36501792]], + [[1.57198501, 0.84847665], [0.52653229, -0.59273601]], + ], + ]) + .to_dtype::(); + let y = + (x.leaky_trace(), w.clone()).convtrans2d(Const::<1>, Const::<1>, Const::<1>, Const::<1>); + assert_close_to_literal!( + y, + [ + [ + [-0.20966601, -1.31598091, 2.07309437], + [1.26762176, 1.55388546, -2.00090408] + ], + [ + [-0.06121790, -0.99166316, -0.90268230], + [-2.14002204, -0.02407408, 1.42854285] + ], + [ + [0.49562341, 0.79191101, -1.66748166], + [-1.32583618, -1.68409586, 0.03524891] + ] + ] + ); + + let g = y.exp().mean().backward(); + + assert_close_to_literal!( + g.get(&x), + [ + [ + [-0.02973516, -0.12817474, 0.23232058, 0.09585897], + [0.14525947, 0.23093222, -0.10841238, -0.59855586], + [-0.00722821, -0.08082656, -0.07108198, -0.06080972] + ], + [ + [-0.03708792, 0.01189943, 0.41607130, 0.03411148], + [0.25580385, 0.35231194, 0.29753110, -0.54662442], + [0.10137358, -0.16306859, -0.34208044, -0.08278930] + ] + ] + ); + assert_close_to_literal!( + g.get(&w), + [ + [ + [[-0.25753322, 0.51525033], [0.74739242, 0.45198986]], + [[-0.17857774, 0.20734756], [-0.27595735, 0.05209281]], + [[0.06679222, 0.17222919], [0.03259417, -0.07203376]] + ], + [ + [[-0.58513254, -0.09418570], [0.07769967, -0.01389713]], + [[0.20000815, -0.24702756], [-0.11653619, 0.06147216]], + [[0.05383727, -0.05131300], [-0.12168837, -0.05695146]] + ] + ] + ); +} + +#[test] +fn test_convtrans2d_batched() { + let dev: TestDevice = Default::default(); + let x: Tensor, TestDtype, _> = dev.sample_normal(); + let w: Tensor, TestDtype, _> = dev.sample_normal(); + + let y: Tensor, _, _, _> = + (x.leaky_trace(), w.clone()).convtrans2d(Const::<3>, Const::<2>, Const::<1>, Const::<1>); + let y0 = y.retaped::(); + let grads0 = y.square().mean().backward(); + let x0 = grads0.get(&x); + let w0 = grads0.get(&w); + + let x = x + .broadcast::, _>() + .reshape::>(); + + let y: Tensor, _, _, _> = + (x.leaky_trace(), w.clone()).convtrans2d(Const::<3>, Const::<2>, Const::<1>, Const::<1>); + for i in 0..10 { + assert_close_to_tensor!(y0, y.retaped::().select(dev.tensor(i)), 1e-5); + } + + let grads = y.square().mean().backward(); + + assert_close_to_tensor!(w0, grads.get(&w)); + + let x_grad = grads.get(&x) * 10.0; + for i in 0..10 { + assert_close_to_tensor!(x0, x_grad.clone().select(dev.tensor(i))); + } +} + +#[test] +fn test_convtrans2d_grouped() { + let device = TestDevice::default(); + let x = device + .tensor([ + [ + [-1.44135797, 0.23273671, 0.55838293, 1.09627271], + [0.05642751, 0.96609902, 0.24707083, -1.48412001], + [0.40077326, 1.16620362, 0.48770329, -1.01286852], + ], + [ + [0.32589695, -0.91695106, -0.13670059, 0.23979346], + [0.72553939, -0.38209674, 0.35545620, -0.37955058], + [-0.43962145, -0.69825196, -1.05400932, 1.22050178], + ], + ]) + .to_dtype::(); + let w = device + .tensor([ + [ + [[-0.46923456, -0.05046696], [0.30552489, 0.68838942]], + [[-0.45128539, 0.28081128], [-1.64214528, -0.67228431]], + [[0.82240576, 0.06074515], [-0.19825625, -0.28086993]], + ], + [ + [[-1.61626124, 0.44280547], [0.03486229, 0.80007958]], + [[-0.69502026, 0.22409980], [0.58573663, -0.36501792]], + [[1.57198501, 0.84847665], [0.52653229, -0.59273601]], + ], + ]) + .to_dtype::(); + let y = + (x.leaky_trace(), w.clone()).convtrans2d(Const::<1>, Const::<0>, Const::<1>, Const::<2>); + #[rustfmt::skip] + assert_close_to_literal!( + y, + [ + [ + [0.67633498,-0.03646715,-0.27375808,-0.54258895,-0.05532555], + [-0.46684846, -1.37728357, 0.16612311, 1.40325499, 0.82956153], + [-0.17081666,-0.23343745,0.45283663,0.16730537,-0.97053605], + [0.12244620, 0.63219225, 0.95180768, 0.02627325, -0.69724798] + ], + [ + [0.65046382,-0.50978023,-0.18663496,-0.33793163,0.30784574], + [2.34145427,0.16667402,-0.91361910,-1.43648922,-1.15376461], + [-0.27352527,-2.03815913,-0.94782966,2.86508417,0.71332568], + [-0.65812790,-2.18450928,-1.58490002,1.33540201,0.68093562] + ], + [ + [-1.18538105, 0.10384849, 0.47335497, 0.93550003, 0.06659325], + [0.33216453, 1.15664566, 0.08580667, -1.57971644, -0.39806312], + [0.31841114, 0.77605361, 0.15159971, -0.57852203, 0.35531783], + [-0.07945581,-0.34377232,-0.42424178,0.06382634,0.28448433] + ], + [ + [-0.52673459,1.62634134,-0.18508710,-0.44810066,0.10618186], + [-1.16129971, 1.16761744, -1.48210406, 0.66983926, 0.02378678], + [0.73583704, 1.50105929, 1.10104883, -2.16820955, 0.23677418], + [-0.01532621,-0.37607479,-0.59540236,-0.80074185,0.97649854] + ], + [ + [-0.22650498,0.71033299,-0.11047887,-0.19729590,0.05373767], + [-0.31337482,-0.22789186,-0.07804406,0.53380698,-0.17258611], + [0.73052084, -0.10186276, 0.92375565, -1.43654132, 0.41205698], + [-0.25750238,-0.24852204,-0.36249739,1.09962487,-0.44550502] + ], + [ + [0.51230514, -1.16491735, -0.99290282, 0.26096445, 0.20345916], + [1.31213224,-0.66102237,0.70610434,-0.08776516,-0.46417403], + [-0.30905840,-2.10188961,-1.83569562,0.61377084,1.26054060], + [-0.23147489,-0.10707274,-0.14109090,1.26738286,-0.72343534] + ] + ] + ); + + let g = y.exp().mean().backward(); + + assert_close_to_literal!( + g.get(&x), + [ + [ + [-0.16682972, -0.01479191, 0.02486472, 0.03241257], + [-0.03953645, 0.01812358, -0.09578598, -0.26596320], + [0.00710792, 0.02639988, 0.02603059, -0.12393593] + ], + [ + [0.07047811, -0.07305276, 0.01832999, 0.03246089], + [0.09772546, -0.00640025, 0.03977848, -0.01866792], + [-0.00610692, -0.04190996, -0.05539178, 0.09744944] + ] + ] + ); + assert_close_to_literal!( + g.get(&w), + [ + [ + [[-0.05153870, 0.01042829], [0.05493409, 0.08134348]], + [[-0.14194693, 0.06882061], [-0.36182982, 0.00579517]], + [[0.08088457, 0.02259952], [0.01361402, -0.03323809]] + ], + [ + [[-0.10977639, 0.00816289], [-0.02106360, 0.03803319]], + [[-0.04609783, -0.00057952], [0.03170185, -0.03856671]], + [[0.04222549, 0.01465542], [0.02470277, -0.05392574]] + ] + ] + ); +} + +#[test] +fn test_convtrans2d_dilated() { + let device = TestDevice::default(); + let x = device + .tensor([ + [ + [-1.44135797, 0.23273671, 0.55838293, 1.09627271], + [0.05642751, 0.96609902, 0.24707083, -1.48412001], + [0.40077326, 1.16620362, 0.48770329, -1.01286852], + ], + [ + [0.32589695, -0.91695106, -0.13670059, 0.23979346], + [0.72553939, -0.38209674, 0.35545620, -0.37955058], + [-0.43962145, -0.69825196, -1.05400932, 1.22050178], + ], + ]) + .to_dtype::(); + let w = device + .tensor([ + [ + [[-0.46923456, -0.05046696], [0.30552489, 0.68838942]], + [[-0.45128539, 0.28081128], [-1.64214528, -0.67228431]], + [[0.82240576, 0.06074515], [-0.19825625, -0.28086993]], + ], + [ + [[-1.61626124, 0.44280547], [0.03486229, 0.80007958]], + [[-0.69502026, 0.22409980], [0.58573663, -0.36501792]], + [[1.57198501, 0.84847665], [0.52653229, -0.59273601]], + ], + ]) + .to_dtype::(); + let y = + (x.leaky_trace(), w.clone()).convtrans2d(Const::<1>, Const::<0>, Const::<2>, Const::<1>); + #[rustfmt::skip] + assert_close_to_literal!( + y, + [ + [ + [0.14960037,1.37282431,0.17598119,-1.31975436,-0.08871166,0.05085630], + [-1.19913888,0.16424109,-0.37201914,1.09190273,0.14492904,-0.09316804], + [0.09347722,0.62047440,0.69417661,-2.09554338,-0.21632043,1.53807759], + [0.04253398,0.28184652,0.70721161,-0.10732305,0.45447421,-1.32532310], + [0.10711999,0.33196157,0.03641622,-0.02276197,-0.50756156,0.27925056] + ], + [ + [0.42395884,0.53226888,-0.48869652,-0.80152661,0.12616566,0.36158341], + [-0.52972949,-0.17042139,-0.18011002,1.11922061,0.14903794,-0.50181484], + [2.68249178,-0.96027076,0.37951475,-1.70172071,-0.42474461,-0.83544654], + [0.33231282,-1.81028295,-0.50029266,1.70480287,-0.29584974,1.13629329], + [-0.91563028,-2.32406759,-1.52721536,1.84902382,0.05685703,0.23543060] + ], + [ + [-0.67307591,-1.25002933,0.43328649,0.51465881,-0.08206820,0.27005240], + [1.18694329,0.19387504,1.38099527,-2.08271098,0.31660464,-0.41219291], + [0.09587342,-0.66749489,-1.57547712,0.95106959,-0.94048226,0.52399611], + [0.37083280,-0.39272144,-0.30772626,0.04952475,-0.28008646,0.64181799], + [-0.31093070,-0.59885937,-0.50364584,0.92976868,0.48776808,-0.43895102] + ] + ] + ); + + let g = y.exp().mean().backward(); + + assert_close_to_literal!( + g.get(&x), + [ + [ + [-0.26850072, -0.03441733, -0.01182073, 0.03490706], + [0.01392877, -0.02877949, 0.03081299, -0.15472011], + [-0.06612058, -0.05083767, -0.02418187, -0.09756426] + ], + [ + [0.11071943, -0.07129460, 0.03023484, 0.08694579], + [0.11894555, 0.00127397, 0.08722699, -0.04324878], + [-0.08628573, -0.03130906, -0.03867241, 0.14042965] + ] + ] + ); + assert_close_to_literal!( + g.get(&w), + [ + [ + [[0.00630402, -0.01686253], [0.02441918, 0.04648526]], + [[0.02635447, 0.04999616], [-0.37556821, 0.07128908]], + [[0.03001181, 0.01555148], [-0.00549905, 0.04595280]] + ], + [ + [[-0.09109048, 0.03960687], [-0.02581818, 0.03304695]], + [[-0.10903731, -0.01481715], [0.11795966, -0.07065595]], + [[0.06082013, 0.02204099], [0.03207079, -0.05822099]] + ] + ] + ); +} diff --git a/src/tensor_ops/mod.rs b/src/tensor_ops/mod.rs index 4f35eb737..e5c9bd16a 100644 --- a/src/tensor_ops/mod.rs +++ b/src/tensor_ops/mod.rs @@ -273,7 +273,7 @@ pub use conv2d::TryConv2D; #[cfg(feature = "nightly")] mod convtrans2d; #[cfg(feature = "nightly")] -pub use convtrans2d::{ConvTransAlgebra, TryConvTrans2D, TryConvTrans2DTo}; +pub use convtrans2d::TryConvTrans2D; #[cfg(feature = "nightly")] mod pool2d;