From c6ba3109e45997a46f6c09aecae85bf284f7c8f4 Mon Sep 17 00:00:00 2001 From: Corey Lowman Date: Fri, 23 Jun 2023 10:06:17 -0400 Subject: [PATCH] Fixing weight shape for grouped Conv2D --- src/nn/conv.rs | 52 +++++++++++++++++++++++++++++------- src/nn/mod.rs | 1 + src/shapes/shape.rs | 44 +++++++++++++++++++++++------- src/tensor_ops/conv2d/mod.rs | 34 +++++++++++++++++------ 4 files changed, 104 insertions(+), 27 deletions(-) diff --git a/src/nn/conv.rs b/src/nn/conv.rs index af21a6712..c53fe2d43 100644 --- a/src/nn/conv.rs +++ b/src/nn/conv.rs @@ -32,6 +32,7 @@ impl< where E: Dtype, D: Device, + Const<{ I / G }>: Sized, Conv2D: BuildModule, { type Built = Conv2D; @@ -62,6 +63,7 @@ where /// /// See [conv animations](https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md) for helpful /// visualization of all of these parameters. + #[derive(Debug, Clone)] pub struct Conv2D< const IN_CHAN: usize, @@ -73,8 +75,10 @@ pub struct Conv2D< const GROUPS: usize, E: Dtype, D: Storage, -> { - pub weight: Tensor, E, D>, +> where + Const<{ IN_CHAN / GROUPS }>: Sized, +{ + pub weight: Tensor, E, D>, } impl< @@ -89,6 +93,7 @@ impl< D, > TensorCollection for Conv2D where + Const<{ I / G }>: Sized, E: Dtype + Float + SampleUniform, D: Device, { @@ -112,9 +117,8 @@ where } } -#[cfg(feature = "nightly")] impl< - const C: usize, + const I: usize, const O: usize, const K: usize, const S: usize, @@ -124,19 +128,21 @@ impl< E, D, Img, - > Module for Conv2D + > Module for Conv2D where + Const<{ I / G }>: Sized, E: Dtype, D: Device, - (Img, Tensor, E, D>): TryConv2D, Const

, Const, Const>, + (Img, Tensor, E, D>): + TryConv2D, Const

, Const, Const>, { - type Output = <(Img, Tensor, E, D>) as TryConv2D< + type Output = <(Img, Tensor, E, D>) as TryConv2D< Const, Const

, Const, Const, >>::Convolved; - type Error = <(Img, Tensor, E, D>) as TryConv2D< + type Error = <(Img, Tensor, E, D>) as TryConv2D< Const, Const

, Const, @@ -159,10 +165,11 @@ impl< E: Dtype, D: Storage, > NonMutableModule for Conv2D +where + Const<{ I / G }>: Sized, { } -#[cfg(feature = "nightly")] #[cfg(test)] mod tests { use crate::{ @@ -189,6 +196,33 @@ mod tests { let _: Tensor, _, _, _> = dev.build_module::, TestDtype>().forward(x.clone()); } + #[test] + fn test_grouped_forward_sizes() { + let dev: TestDevice = Default::default(); + + let x = dev.zeros::>(); + + let m = dev.build_module::, TestDtype>(); + let _: Tensor, _, _> = m.weight; + let _: Tensor, _, _> = m.forward(x.clone()); + + let m = dev.build_module::, TestDtype>(); + let _: Tensor, _, _> = m.weight; + let _: Tensor, _, _> = m.forward(x.clone()); + + let m = dev.build_module::, TestDtype>(); + let _: Tensor, _, _> = m.weight; + let _: Tensor, _, _> = m.forward(x.clone()); + + let m = dev.build_module::, TestDtype>(); + let _: Tensor, _, _> = m.weight; + let _: Tensor, _, _> = m.forward(x.clone()); + + let m = dev.build_module::, TestDtype>(); + let _: Tensor, _, _> = m.weight; + let _: Tensor, _, _> = m.forward(x); + } + #[rustfmt::skip] #[test] fn test_forward_4d_sizes() { diff --git a/src/nn/mod.rs b/src/nn/mod.rs index 33e4ebdc5..ddf87b9ed 100644 --- a/src/nn/mod.rs +++ b/src/nn/mod.rs @@ -188,6 +188,7 @@ mod add_into; mod batchnorm1d; mod batchnorm2d; mod bias2d; +#[cfg(feature = "nightly")] mod conv; mod convtrans; mod dropout; diff --git a/src/shapes/shape.rs b/src/shapes/shape.rs index e15915b65..dce825cab 100644 --- a/src/shapes/shape.rs +++ b/src/shapes/shape.rs @@ -141,23 +141,23 @@ impl ConstDim for Const { impl core::ops::Add> for usize { type Output = usize; - fn add(self, rhs: Const) -> Self::Output { - self.size() + rhs.size() + fn add(self, _: Const) -> Self::Output { + self.size() + N } } impl core::ops::Add for Const { type Output = usize; fn add(self, rhs: usize) -> Self::Output { - self.size() + rhs.size() + N + rhs.size() } } #[cfg(feature = "nightly")] impl core::ops::Add> for Const where - Const<{ N + M }>: Sized, + 
Const<{ M + N }>: Sized, { - type Output = Const<{ N + M }>; + type Output = Const<{ M + N }>; fn add(self, _: Const) -> Self::Output { Const } } @@ -165,28 +165,52 @@ where impl core::ops::Mul> for usize { type Output = usize; - fn mul(self, rhs: Const) -> Self::Output { - self.size() * rhs.size() + fn mul(self, _: Const) -> Self::Output { + self.size() * N } } impl core::ops::Mul for Const { type Output = usize; fn mul(self, rhs: usize) -> Self::Output { - self.size() * rhs.size() + N * rhs.size() } } #[cfg(feature = "nightly")] impl core::ops::Mul> for Const where - Const<{ N * M }>: Sized, + Const<{ M * N }>: Sized, { - type Output = Const<{ N * M }>; + type Output = Const<{ M * N }>; fn mul(self, _: Const) -> Self::Output { Const } } +impl core::ops::Div> for usize { + type Output = usize; + fn div(self, _: Const) -> Self::Output { + self.size() / N + } +} +impl core::ops::Div for Const { + type Output = usize; + fn div(self, rhs: usize) -> Self::Output { + N / rhs.size() + } +} + +#[cfg(feature = "nightly")] +impl core::ops::Div> for Const +where + Const<{ M / N }>: Sized, +{ + type Output = Const<{ M / N }>; + fn div(self, _: Const) -> Self::Output { + Const + } +} + /// Represents either `[T; N]` or `Vec` pub trait Array: IntoIterator { type Dim: Dim; diff --git a/src/tensor_ops/conv2d/mod.rs b/src/tensor_ops/conv2d/mod.rs index 8f00c99bf..2c34b8b9c 100644 --- a/src/tensor_ops/conv2d/mod.rs +++ b/src/tensor_ops/conv2d/mod.rs @@ -166,8 +166,17 @@ impl impl TryConv2D for ( - Tensor<(>::Output, H, W), E, D, T>, - Tensor<(OutChan, InpChan, Kernel, Kernel), E, D>, + Tensor<(InpChan, H, W), E, D, T>, + Tensor< + ( + OutChan, + >::Output, + Kernel, + Kernel, + ), + E, + D, + >, ) where InpChan: Dim, @@ -182,8 +191,8 @@ where E: Dtype, D: Conv2DKernel + crate::tensor_ops::reshape_to::ReshapeKernel, T: Tape, - InpChan: std::ops::Mul, - >::Output: Dim, + InpChan: std::ops::Div, + >::Output: Dim, (H, Kernel): TryConv2D, (W, Kernel): TryConv2D, <(H, Kernel) as 
TryConv2D>::Convolved: Dim, @@ -220,8 +229,17 @@ where impl TryConv2D for ( - Tensor<(Batch, >::Output, H, W), E, D, T>, - Tensor<(OutChan, InpChan, Kernel, Kernel), E, D>, + Tensor<(Batch, InpChan, H, W), E, D, T>, + Tensor< + ( + OutChan, + >::Output, + Kernel, + Kernel, + ), + E, + D, + >, ) where InpChan: Dim, @@ -237,8 +255,8 @@ where E: Dtype, D: Conv2DKernel, T: Tape, - InpChan: std::ops::Mul, - >::Output: Dim, + InpChan: std::ops::Div, + >::Output: Dim, (H, Kernel): TryConv2D, (W, Kernel): TryConv2D, <(H, Kernel) as TryConv2D>::Convolved: Dim,