Skip to content

Commit

Permalink
Fixing weight shape for grouped Conv2D (#797)
Browse files Browse the repository at this point in the history
* Fixing weight shape for grouped Conv2D

* FIxing div impl
  • Loading branch information
coreylowman committed Jun 23, 2023
1 parent 4748054 commit 780b347
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 27 deletions.
52 changes: 43 additions & 9 deletions src/nn/conv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ impl<
where
E: Dtype,
D: Device<E>,
Const<{ I / G }>: Sized,
Conv2D<I, O, K, S, P, L, G, E, D>: BuildModule<D, E>,
{
type Built = Conv2D<I, O, K, S, P, L, G, E, D>;
Expand Down Expand Up @@ -62,6 +63,7 @@ where
///
/// See [conv animations](https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md) for helpful
/// visualization of all of these parameters.

#[derive(Debug, Clone)]
pub struct Conv2D<
const IN_CHAN: usize,
Expand All @@ -73,8 +75,10 @@ pub struct Conv2D<
const GROUPS: usize,
E: Dtype,
D: Storage<E>,
> {
pub weight: Tensor<Rank4<OUT_CHAN, IN_CHAN, KERNEL_SIZE, KERNEL_SIZE>, E, D>,
> where
Const<{ IN_CHAN / GROUPS }>: Sized,
{
pub weight: Tensor<Rank4<OUT_CHAN, { IN_CHAN / GROUPS }, KERNEL_SIZE, KERNEL_SIZE>, E, D>,
}

impl<
Expand All @@ -89,6 +93,7 @@ impl<
D,
> TensorCollection<E, D> for Conv2D<I, O, K, S, P, L, G, E, D>
where
Const<{ I / G }>: Sized,
E: Dtype + Float + SampleUniform,
D: Device<E>,
{
Expand All @@ -112,9 +117,8 @@ where
}
}

#[cfg(feature = "nightly")]
impl<
const C: usize,
const I: usize,
const O: usize,
const K: usize,
const S: usize,
Expand All @@ -124,19 +128,21 @@ impl<
E,
D,
Img,
> Module<Img> for Conv2D<C, O, K, S, P, L, G, E, D>
> Module<Img> for Conv2D<I, O, K, S, P, L, G, E, D>
where
Const<{ I / G }>: Sized,
E: Dtype,
D: Device<E>,
(Img, Tensor<Rank4<O, C, K, K>, E, D>): TryConv2D<Const<S>, Const<P>, Const<L>, Const<G>>,
(Img, Tensor<Rank4<O, { I / G }, K, K>, E, D>):
TryConv2D<Const<S>, Const<P>, Const<L>, Const<G>>,
{
type Output = <(Img, Tensor<Rank4<O, C, K, K>, E, D>) as TryConv2D<
type Output = <(Img, Tensor<Rank4<O, { I / G }, K, K>, E, D>) as TryConv2D<
Const<S>,
Const<P>,
Const<L>,
Const<G>,
>>::Convolved;
type Error = <(Img, Tensor<Rank4<O, C, K, K>, E, D>) as TryConv2D<
type Error = <(Img, Tensor<Rank4<O, { I / G }, K, K>, E, D>) as TryConv2D<
Const<S>,
Const<P>,
Const<L>,
Expand All @@ -159,10 +165,11 @@ impl<
E: Dtype,
D: Storage<E>,
> NonMutableModule for Conv2D<I, O, K, S, P, L, G, E, D>
where
Const<{ I / G }>: Sized,
{
}

#[cfg(feature = "nightly")]
#[cfg(test)]
mod tests {
use crate::{
Expand All @@ -189,6 +196,33 @@ mod tests {
let _: Tensor<Rank3<2, 6, 6>, _, _, _> = dev.build_module::<Conv2D<3, 2, 3, 2, 2>, TestDtype>().forward(x.clone());
}

#[test]
fn test_grouped_forward_sizes() {
let dev: TestDevice = Default::default();

let x = dev.zeros::<Rank3<16, 10, 10>>();

let m = dev.build_module::<Conv2D<16, 32, 3, 1, 0, 1, 1>, TestDtype>();
let _: Tensor<Rank4<32, 16, 3, 3>, _, _> = m.weight;
let _: Tensor<Rank3<32, 8, 8>, _, _> = m.forward(x.clone());

let m = dev.build_module::<Conv2D<16, 32, 3, 1, 0, 1, 2>, TestDtype>();
let _: Tensor<Rank4<32, 8, 3, 3>, _, _> = m.weight;
let _: Tensor<Rank3<32, 8, 8>, _, _> = m.forward(x.clone());

let m = dev.build_module::<Conv2D<16, 32, 3, 1, 0, 1, 4>, TestDtype>();
let _: Tensor<Rank4<32, 4, 3, 3>, _, _> = m.weight;
let _: Tensor<Rank3<32, 8, 8>, _, _> = m.forward(x.clone());

let m = dev.build_module::<Conv2D<16, 32, 3, 1, 0, 1, 8>, TestDtype>();
let _: Tensor<Rank4<32, 2, 3, 3>, _, _> = m.weight;
let _: Tensor<Rank3<32, 8, 8>, _, _> = m.forward(x.clone());

let m = dev.build_module::<Conv2D<16, 32, 3, 1, 0, 1, 16>, TestDtype>();
let _: Tensor<Rank4<32, 1, 3, 3>, _, _> = m.weight;
let _: Tensor<Rank3<32, 8, 8>, _, _> = m.forward(x);
}

#[rustfmt::skip]
#[test]
fn test_forward_4d_sizes() {
Expand Down
1 change: 1 addition & 0 deletions src/nn/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ mod add_into;
mod batchnorm1d;
mod batchnorm2d;
mod bias2d;
#[cfg(feature = "nightly")]
mod conv;
mod convtrans;
mod dropout;
Expand Down
44 changes: 34 additions & 10 deletions src/shapes/shape.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,52 +141,76 @@ impl<const M: usize> ConstDim for Const<M> {

impl<const N: usize> core::ops::Add<Const<N>> for usize {
type Output = usize;
fn add(self, rhs: Const<N>) -> Self::Output {
self.size() + rhs.size()
fn add(self, _: Const<N>) -> Self::Output {
self.size() + N
}
}
impl<const N: usize> core::ops::Add<usize> for Const<N> {
type Output = usize;
fn add(self, rhs: usize) -> Self::Output {
self.size() + rhs.size()
N + rhs.size()
}
}

#[cfg(feature = "nightly")]
impl<const N: usize, const M: usize> core::ops::Add<Const<N>> for Const<M>
where
Const<{ N + M }>: Sized,
Const<{ M + N }>: Sized,
{
type Output = Const<{ N + M }>;
type Output = Const<{ M + N }>;
fn add(self, _: Const<N>) -> Self::Output {
Const
}
}

impl<const N: usize> core::ops::Mul<Const<N>> for usize {
type Output = usize;
fn mul(self, rhs: Const<N>) -> Self::Output {
self.size() * rhs.size()
fn mul(self, _: Const<N>) -> Self::Output {
self.size() * N
}
}
impl<const N: usize> core::ops::Mul<usize> for Const<N> {
type Output = usize;
fn mul(self, rhs: usize) -> Self::Output {
self.size() * rhs.size()
N * rhs.size()
}
}

#[cfg(feature = "nightly")]
impl<const N: usize, const M: usize> core::ops::Mul<Const<N>> for Const<M>
where
Const<{ N * M }>: Sized,
Const<{ M * N }>: Sized,
{
type Output = Const<{ N * M }>;
type Output = Const<{ M * N }>;
fn mul(self, _: Const<N>) -> Self::Output {
Const
}
}

impl<const N: usize> core::ops::Div<Const<N>> for usize {
type Output = usize;
fn div(self, _: Const<N>) -> Self::Output {
self.size() / N
}
}
impl<const N: usize> core::ops::Div<usize> for Const<N> {
type Output = usize;
fn div(self, rhs: usize) -> Self::Output {
N / rhs.size()
}
}

#[cfg(feature = "nightly")]
impl<const N: usize, const M: usize> core::ops::Div<Const<N>> for Const<M>
where
Const<{ M / N }>: Sized,
{
type Output = Const<{ M / N }>;
fn div(self, _: Const<N>) -> Self::Output {
Const
}
}

/// Represents either `[T; N]` or `Vec<T>`
pub trait Array<T>: IntoIterator<Item = T> {
type Dim: Dim;
Expand Down
34 changes: 26 additions & 8 deletions src/tensor_ops/conv2d/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,17 @@ impl<Kernel: Dim, Stride: Dim, Padding: Dim, Dilation: Dim, Groups: Dim>
impl<InpChan, OutChan, Kernel, Stride, Padding, Dilation, Groups, H, W, E, D, T>
TryConv2D<Stride, Padding, Dilation, Groups>
for (
Tensor<(<InpChan as std::ops::Mul<Groups>>::Output, H, W), E, D, T>,
Tensor<(OutChan, InpChan, Kernel, Kernel), E, D>,
Tensor<(InpChan, H, W), E, D, T>,
Tensor<
(
OutChan,
<InpChan as std::ops::Div<Groups>>::Output,
Kernel,
Kernel,
),
E,
D,
>,
)
where
InpChan: Dim,
Expand All @@ -182,8 +191,8 @@ where
E: Dtype,
D: Conv2DKernel<E> + crate::tensor_ops::reshape_to::ReshapeKernel<E>,
T: Tape<E, D>,
InpChan: std::ops::Mul<Groups>,
<InpChan as std::ops::Mul<Groups>>::Output: Dim,
InpChan: std::ops::Div<Groups>,
<InpChan as std::ops::Div<Groups>>::Output: Dim,
(H, Kernel): TryConv2D<Stride, Padding, Dilation, Groups>,
(W, Kernel): TryConv2D<Stride, Padding, Dilation, Groups>,
<(H, Kernel) as TryConv2D<Stride, Padding, Dilation, Groups>>::Convolved: Dim,
Expand Down Expand Up @@ -220,8 +229,17 @@ where
impl<InpChan, OutChan, Kernel, Stride, Padding, Dilation, Groups, Batch, H, W, E, D, T>
TryConv2D<Stride, Padding, Dilation, Groups>
for (
Tensor<(Batch, <InpChan as std::ops::Mul<Groups>>::Output, H, W), E, D, T>,
Tensor<(OutChan, InpChan, Kernel, Kernel), E, D>,
Tensor<(Batch, InpChan, H, W), E, D, T>,
Tensor<
(
OutChan,
<InpChan as std::ops::Div<Groups>>::Output,
Kernel,
Kernel,
),
E,
D,
>,
)
where
InpChan: Dim,
Expand All @@ -237,8 +255,8 @@ where
E: Dtype,
D: Conv2DKernel<E>,
T: Tape<E, D>,
InpChan: std::ops::Mul<Groups>,
<InpChan as std::ops::Mul<Groups>>::Output: Dim,
InpChan: std::ops::Div<Groups>,
<InpChan as std::ops::Div<Groups>>::Output: Dim,
(H, Kernel): TryConv2D<Stride, Padding, Dilation, Groups>,
(W, Kernel): TryConv2D<Stride, Padding, Dilation, Groups>,
<(H, Kernel) as TryConv2D<Stride, Padding, Dilation, Groups>>::Convolved: Dim,
Expand Down

0 comments on commit 780b347

Please sign in to comment.