Skip to content

Commit

Permalink
Adding dilation/groups to conv transpose
Browse files Browse the repository at this point in the history
  • Loading branch information
coreylowman committed May 12, 2023
1 parent 37a711a commit 3958d6e
Show file tree
Hide file tree
Showing 4 changed files with 349 additions and 210 deletions.
88 changes: 71 additions & 17 deletions src/nn/convtrans.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,28 @@ pub mod builder {
const KERNEL_SIZE: usize,
const STRIDE: usize = 1,
const PADDING: usize = 0,
const DILATION: usize = 1,
const GROUPS: usize = 1,
>;
}

impl<const I: usize, const O: usize, const K: usize, const S: usize, const P: usize, E, D>
BuildOnDevice<D, E> for builder::ConvTrans2D<I, O, K, S, P>
impl<
const I: usize,
const O: usize,
const K: usize,
const S: usize,
const P: usize,
const L: usize,
const G: usize,
E,
D,
> BuildOnDevice<D, E> for builder::ConvTrans2D<I, O, K, S, P, L, G>
where
E: Dtype,
D: Device<E>,
ConvTrans2D<I, O, K, S, P, E, D>: BuildModule<D, E>,
ConvTrans2D<I, O, K, S, P, L, G, E, D>: BuildModule<D, E>,
{
type Built = ConvTrans2D<I, O, K, S, P, E, D>;
type Built = ConvTrans2D<I, O, K, S, P, L, G, E, D>;
fn try_build_on_device(device: &D) -> Result<Self::Built, <D>::Err> {
Self::Built::try_build(device)
}
Expand All @@ -45,26 +56,40 @@ where
/// - `KERNEL_SIZE`: The size of the kernel applied to both width and height of the images.
/// - `STRIDE`: How far to move the kernel each step. Defaults to `1`
/// - `PADDING`: How much zero padding to add around the images. Defaults to `0`.
/// - `DILATION`: Controls the spacing between kernel points. Defaults to `1`.
/// - `GROUPS`: Controls the connections between inputs and outputs.
/// `IN_CHAN` and `OUT_CHAN` must both be divisible by `GROUPS`. For example,
#[derive(Debug, Clone)]
pub struct ConvTrans2D<
const IN_CHAN: usize,
const OUT_CHAN: usize,
const KERNEL_SIZE: usize,
const STRIDE: usize,
const PADDING: usize,
const DILATION: usize,
const GROUPS: usize,
E: Dtype,
D: Storage<E>,
> {
pub weight: Tensor<Rank4<OUT_CHAN, IN_CHAN, KERNEL_SIZE, KERNEL_SIZE>, E, D>,
}

impl<const I: usize, const O: usize, const K: usize, const S: usize, const P: usize, E, D>
TensorCollection<E, D> for ConvTrans2D<I, O, K, S, P, E, D>
impl<
const I: usize,
const O: usize,
const K: usize,
const S: usize,
const P: usize,
const L: usize,
const G: usize,
E,
D,
> TensorCollection<E, D> for ConvTrans2D<I, O, K, S, P, L, G, E, D>
where
E: Dtype + Float + SampleUniform,
D: Device<E>,
{
type To<E2: Dtype, D2: Device<E2>> = ConvTrans2D<I, O, K, S, P, E2, D2>;
type To<E2: Dtype, D2: Device<E2>> = ConvTrans2D<I, O, K, S, P, L, G, E2, D2>;

fn iter_tensors<V: ModuleVisitor<Self, E, D>>(
visitor: &mut V,
Expand All @@ -85,23 +110,52 @@ where
}

#[cfg(feature = "nightly")]
impl<const C: usize, const O: usize, const K: usize, const S: usize, const P: usize, E, D, Img>
Module<Img> for ConvTrans2D<C, O, K, S, P, E, D>
impl<
const C: usize,
const O: usize,
const K: usize,
const S: usize,
const P: usize,
const L: usize,
const G: usize,
E,
D,
Img,
> Module<Img> for ConvTrans2D<C, O, K, S, P, L, G, E, D>
where
E: Dtype,
D: Device<E>,
Img: TryConvTrans2DTo<Tensor<Rank4<O, C, K, K>, E, D>, S, P> + HasErr<Err = D::Err>,
(Img, Tensor<Rank4<O, C, K, K>, E, D>): TryConvTrans2D<Const<S>, Const<P>, Const<L>, Const<G>>,
{
type Output = Img::Output;
type Error = D::Err;

fn try_forward(&self, x: Img) -> Result<Self::Output, D::Err> {
x.try_convtrans2d_to(self.weight.clone())
type Output = <(Img, Tensor<Rank4<O, C, K, K>, E, D>) as TryConvTrans2D<
Const<S>,
Const<P>,
Const<L>,
Const<G>,
>>::Convolved;
type Error = <(Img, Tensor<Rank4<O, C, K, K>, E, D>) as TryConvTrans2D<
Const<S>,
Const<P>,
Const<L>,
Const<G>,
>>::Error;

fn try_forward(&self, x: Img) -> Result<Self::Output, Self::Error> {
(x, self.weight.clone()).try_convtrans2d(Const, Const, Const, Const)
}
}

impl<const I: usize, const O: usize, const K: usize, const S: usize, const P: usize, E, D>
NonMutableModule for ConvTrans2D<I, O, K, S, P, E, D>
impl<
const I: usize,
const O: usize,
const K: usize,
const S: usize,
const P: usize,
const L: usize,
const G: usize,
E,
D,
> NonMutableModule for ConvTrans2D<I, O, K, S, P, L, G, E, D>
where
E: Dtype,
D: Storage<E>,
Expand Down
134 changes: 81 additions & 53 deletions src/tensor_ops/convtrans2d/cpu_kernel.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::prelude::Tensorlike;
use crate::shapes::{Dtype, Shape};
use crate::tensor::{cpu::*, Tensor};
use crate::tensor::{cpu::*, Tensor, ZerosTensor};
use crate::tensor_ops::matmul::cpu_kernel::MatMulImpl;

use std::sync::Arc;
Expand All @@ -10,9 +10,9 @@ use super::{ConvTrans2DKernel, ConvTrans2DOp};
impl ConvTrans2DOp {
#[inline(always)]
fn unfold_idx(&self, [k1, k2, y, x]: [usize; 4]) -> Option<[usize; 2]> {
(y * self.stride + k1)
(y * self.stride + self.dilation * k1)
.checked_sub(self.padding)
.zip((x * self.stride + k2).checked_sub(self.padding))
.zip((x * self.stride + self.dilation * k2).checked_sub(self.padding))
.filter(|&(oh, ow)| oh < self.h_out && ow < self.w_out)
.map(|(oh, ow)| [oh, ow])
}
Expand All @@ -33,17 +33,17 @@ impl Cpu {
{
{
let mut i = 0;
for c in 0..op.chan_in {
for c in 0..(op.groups * op.chan_in) {
for k1 in 0..op.kernel {
for k2 in 0..op.kernel {
for oh in 0..op.h_out {
for ow in 0..op.w_out {
i += 1;
let mut y = oh + op.padding;
if y < k1 {
if y < op.dilation * k1 {
continue;
}
y -= k1;
y -= op.dilation * k1;
if y % op.stride != 0 {
continue;
}
Expand All @@ -53,10 +53,10 @@ impl Cpu {
}

let mut x = ow + op.padding;
if x < k2 {
if x < op.dilation * k2 {
continue;
}
x -= k2;
x -= op.dilation * k2;
if x % op.stride != 0 {
continue;
}
Expand All @@ -75,20 +75,24 @@ impl Cpu {
}
}

// (O, C * K * K) * (C * K * K, OH * OW) = (O, OH * OW)
let m = op.chan_out;
// LHS: (G, O/G, C*K*K)
// RHS: (G, C*K*K, OH * OW)
// OUT: (G, O/G, OH * OW)
let m = op.chan_out / op.groups;
let k = op.chan_in * op.kernel * op.kernel;
let n = op.w_out * op.h_out;
Self::matmul(
(m, k, n),
false,
filters.as_ptr(),
[k, 1],
buf.as_ptr(),
[n, 1],
out.as_mut_ptr(),
[n, 1],
);
for g in 0..op.groups {
Self::matmul(
(m, k, n),
false,
filters[g * m * k..].as_ptr(),
[k, 1],
buf[g * k * n..].as_ptr(),
[n, 1],
out[g * m * n..].as_mut_ptr(),
[n, 1],
);
}
Ok(())
}

Expand Down Expand Up @@ -127,39 +131,45 @@ impl Cpu {
}

{
// img_g += filters^T * unfold(grad_out)
// (C, H * W) += (C, O * K * K) * (O * K * K, H * W)
// LHS: (G, C, O/G * K * K)
// RHS: (G, O/G * K * K, H * W)
// OUT: (G, C, H * W)
let m = op.chan_in;
let k = op.chan_out * op.kernel * op.kernel;
let k = (op.chan_out / op.groups) * op.kernel * op.kernel;
let n = op.h_in * op.w_in;
Self::matmul(
(m, k, n),
true,
filters_tr.as_ptr(),
[k, 1],
buf.as_ptr(),
[n, 1],
grad_img.as_mut_ptr(),
[n, 1],
);
for g in 0..op.groups {
Self::matmul(
(m, k, n),
true,
filters_tr[g * m * k..].as_ptr(),
[k, 1],
buf[g * k * n..].as_ptr(),
[n, 1],
grad_img[g * m * n..].as_mut_ptr(),
[n, 1],
);
}
}

{
// weight_g^T += img * patches^T
// (C, O * K * K) += (C, H * W) * (H * W, O * K * K)
// LHS: (G, C, H * W)
// RHS: (G, H * W, O/G * K * K)
// OUT: (G, C, O/G * K * K)
let m = op.chan_in;
let k = op.h_in * op.w_in;
let n = op.chan_out * op.kernel * op.kernel;
Self::matmul(
(m, k, n),
true,
img.as_ptr(),
[k, 1],
buf.as_ptr(),
[1, k],
grad_filters_tr.as_mut_ptr(),
[n, 1],
);
let n = (op.chan_out / op.groups) * op.kernel * op.kernel;
for g in 0..op.groups {
Self::matmul(
(m, k, n),
true,
img[g * m * k..].as_ptr(),
[k, 1],
buf[g * k * n..].as_ptr(),
[1, k],
grad_filters_tr[g * m * n..].as_mut_ptr(),
[n, 1],
);
}
}
Ok(())
}
Expand All @@ -169,14 +179,25 @@ impl<E: Dtype> ConvTrans2DKernel<E> for Cpu
where
Self: MatMulImpl<E>,
{
fn alloc<S: Shape>(&self, s: S) -> Result<Tensor<S, E, Self>, Self::Err> {
self.try_zeros_like(&s)
}

fn forward<L: Shape, R: Shape, O: Shape>(
&self,
op: ConvTrans2DOp,
lhs: &Tensor<L, E, Self>,
rhs: &Tensor<R, E, Self>,
out: &mut Tensor<O, E, Self>,
) -> Result<(), Self::Err> {
let mut patches = self.try_alloc_zeros::<E>(op.inp_patches_shape().num_elements())?;
let patches = (
op.groups * op.chan_in,
op.kernel,
op.kernel,
op.h_out,
op.w_out,
);
let mut patches = self.try_alloc_zeros::<E>(patches.num_elements())?;
let [lstride, ostride] = match L::NUM_DIMS {
3 => [0; 2],
4 => [lhs.strides[0], out.strides[0]],
Expand Down Expand Up @@ -207,17 +228,24 @@ where
out: &impl Tensorlike<O, E, Self>,
grad_out: &Self::Vec,
) -> Result<(), Self::Err> {
let f_tr_shape = op.filters_tr_shape();
let mut patches = self.try_alloc_zeros::<E>(op.out_patches_shape().num_elements())?;
let f_tr_shape = [
op.groups,
op.chan_in,
op.chan_out / op.groups,
op.kernel,
op.kernel,
];
let patches_shape = [op.chan_out, op.kernel, op.kernel, op.h_in, op.w_in];
let mut patches = self.try_alloc_zeros::<E>(patches_shape.num_elements())?;
let mut f1023 = self.try_alloc_zeros::<E>(f_tr_shape.num_elements())?;
let mut grad_f1023 = self.try_alloc_zeros::<E>(f_tr_shape.num_elements())?;

{
// transpose filters in f1023
let buf = rhs.data.as_ref();
let mut f_idx = NdIndex::new(f_tr_shape, f_tr_shape.strides());
while let Some((i, [c, o, k1, k2])) = f_idx.next_with_idx() {
let idx = o * rhs.strides[0]
while let Some((i, [g, c, o, k1, k2])) = f_idx.next_with_idx() {
let idx = (g * (op.chan_out / op.groups) + o) * rhs.strides[0]
+ c * rhs.strides[1]
+ k1 * rhs.strides[2]
+ k2 * rhs.strides[3];
Expand Down Expand Up @@ -247,8 +275,8 @@ where
{
// untranspose filters
let mut f_idx = NdIndex::new(f_tr_shape, f_tr_shape.strides());
while let Some((i, [c, o, k1, k2])) = f_idx.next_with_idx() {
let idx = o * rhs.strides[0]
while let Some((i, [g, c, o, k1, k2])) = f_idx.next_with_idx() {
let idx = (g * (op.chan_out / op.groups) + o) * rhs.strides[0]
+ c * rhs.strides[1]
+ k1 * rhs.strides[2]
+ k2 * rhs.strides[3];
Expand Down
Loading

0 comments on commit 3958d6e

Please sign in to comment.