Adds 1d convolutions (#863)
* conv1d - added 1d convolution (#807)

* conv1d - added 1d convolution

* conv1d - quick port of cuda code from conv2d -> conv1d

* conv1d - extra nightly to get through clippy

* conv1d - remove duplicates

* conv1d - clean up comments

* conv1d - remove cudnn

* fixes

---------

Co-authored-by: Corey Lowman <clowman1993@gmail.com>

* Simplify conv1d tests

* Fixing cpu kernel

* Updating cuda kernels

* Fixing bugs

* Reverting debugging changes

* Fixing grouped conv1d cuda kernels

---------

Co-authored-by: jcrist1 <jan.cristina@gmail.com>
coreylowman and jcrist1 committed Sep 7, 2023
1 parent 47855ea commit 67f2568
Showing 11 changed files with 1,489 additions and 13 deletions.
281 changes: 281 additions & 0 deletions src/nn/conv1d.rs
@@ -0,0 +1,281 @@
use num_traits::Float;
use rand_distr::uniform::SampleUniform;

use crate::{shapes::*, tensor::*, tensor_ops::*};

use super::*;

pub mod builder {
#[derive(Debug)]
pub struct Conv1D<
const IN_CHAN: usize,
const OUT_CHAN: usize,
const KERNEL_SIZE: usize,
const STRIDE: usize = 1,
const PADDING: usize = 0,
const DILATION: usize = 1,
const GROUPS: usize = 1,
>;
}
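// `builder::Conv1D` above is a zero-sized marker type that only carries the layer
// hyperparameters as const generics; the `BuildOnDevice` impl below turns it into
// a concrete `Conv1D` module with its weight tensor allocated on the chosen device.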

impl<
const I: usize,
const O: usize,
const K: usize,
const S: usize,
const P: usize,
const L: usize,
const G: usize,
E,
D,
> BuildOnDevice<D, E> for builder::Conv1D<I, O, K, S, P, L, G>
where
E: Dtype,
D: Device<E>,
Const<{ I / G }>: Sized,
Conv1D<I, O, K, S, P, L, G, E, D>: BuildModule<D, E>,
{
type Built = Conv1D<I, O, K, S, P, L, G, E, D>;
fn try_build_on_device(device: &D) -> Result<Self::Built, <D>::Err> {
Self::Built::try_build(device)
}
}

/// **Requires Nightly** Performs *unbiased* 1d convolutions on 2d (channels, length) and 3d (batch, channels, length) inputs.
///
/// **Pytorch Equivalent**: `torch.nn.Conv1d(..., bias=False)`
///
/// Generics:
/// - `IN_CHAN`: The number of channels in the input.
/// - `OUT_CHAN`: The number of channels in the output of the layer.
/// - `KERNEL_SIZE`: The size of the kernel applied along the length of the input.
/// - `STRIDE`: How far to move the kernel each step. Defaults to `1`.
/// - `PADDING`: How much zero padding to add around the input. Defaults to `0`.
/// - `DILATION`: Controls the spacing between kernel points. Defaults to `1`.
/// - `GROUPS`: Controls the connections between inputs and outputs.
/// `IN_CHAN` and `OUT_CHAN` must both be divisible by `GROUPS`. For example, with
/// `GROUPS = 2`, the first half of the output channels is computed only from the
/// first half of the input channels, and the second half from the second half.
///
/// See [conv animations](https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md) for helpful
/// visualization of all of these parameters.
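///
/// The output length follows the usual convolution formula (integer division):
/// `L_out = (L_in + 2 * PADDING - DILATION * (KERNEL_SIZE - 1) - 1) / STRIDE + 1`.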

#[derive(Debug, Clone)]
pub struct Conv1D<
const IN_CHAN: usize,
const OUT_CHAN: usize,
const KERNEL_SIZE: usize,
const STRIDE: usize,
const PADDING: usize,
const DILATION: usize,
const GROUPS: usize,
E: Dtype,
D: Storage<E>,
> where
Const<{ IN_CHAN / GROUPS }>: Sized,
{
pub weight: Tensor<Rank3<OUT_CHAN, { IN_CHAN / GROUPS }, KERNEL_SIZE>, E, D>,
}

impl<
const I: usize,
const O: usize,
const K: usize,
const S: usize,
const P: usize,
const L: usize,
const G: usize,
E,
D,
> TensorCollection<E, D> for Conv1D<I, O, K, S, P, L, G, E, D>
where
Const<{ I / G }>: Sized,
E: Dtype + Float + SampleUniform,
D: Device<E>,
{
type To<E2: Dtype, D2: Device<E2>> = Conv1D<I, O, K, S, P, L, G, E2, D2>;

fn iter_tensors<V: ModuleVisitor<Self, E, D>>(
visitor: &mut V,
) -> Result<Option<Self::To<V::E2, V::D2>>, V::Err> {
visitor.visit_fields(
Self::tensor(
"weight",
|s| &s.weight,
|s| &mut s.weight,
TensorOptions::reset_with(|t| {
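// Uniform init with bound sqrt(GROUPS / (IN_CHAN * KERNEL_SIZE)) = sqrt(1 / fan_in),
// matching PyTorch's default Conv1d weight initialization bound of 1 / sqrt(fan_in).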
let scale = E::from_f64(G as f64 / (I * K) as f64).unwrap();
let b = scale.sqrt();
t.try_fill_with_distr(rand_distr::Uniform::new(-b, b))
}),
),
|weight| Conv1D { weight },
)
}
}

impl<
const I: usize,
const O: usize,
const K: usize,
const S: usize,
const P: usize,
const L: usize,
const G: usize,
E,
D,
Img,
> Module<Img> for Conv1D<I, O, K, S, P, L, G, E, D>
where
Const<{ I / G }>: Sized,
E: Dtype,
D: Device<E>,
(Img, Tensor<Rank3<O, { I / G }, K>, E, D>): TryConv1D<Const<S>, Const<P>, Const<L>, Const<G>>,
{
type Output = <(Img, Tensor<Rank3<O, { I / G }, K>, E, D>) as TryConv1D<
Const<S>,
Const<P>,
Const<L>,
Const<G>,
>>::Convolved;
type Error = <(Img, Tensor<Rank3<O, { I / G }, K>, E, D>) as TryConv1D<
Const<S>,
Const<P>,
Const<L>,
Const<G>,
>>::Error;

fn try_forward(&self, x: Img) -> Result<Self::Output, Self::Error> {
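// Pair the input with the kernel and dispatch to the conv1d tensor op; stride,
// padding, dilation, and groups are passed as `Const` values of the generic parameters.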
(x, self.weight.clone()).try_conv1d(Const, Const, Const, Const)
}
}
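
// `Conv1D` has no buffers or other mutable state, so marking it `NonMutableModule`
// lets `forward_mut` reuse the plain `forward` implementation.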

impl<
const I: usize,
const O: usize,
const K: usize,
const S: usize,
const P: usize,
const L: usize,
const G: usize,
E: Dtype,
D: Storage<E>,
> NonMutableModule for Conv1D<I, O, K, S, P, L, G, E, D>
where
Const<{ I / G }>: Sized,
{
}

#[cfg(test)]
mod tests {
use crate::{
optim::*,
tensor::{AsArray, OnesTensor, SampleTensor, ZerosTensor},
tests::*,
};

use super::{builder::Conv1D, *};

#[rustfmt::skip]
#[test]
fn test_forward_2d_sizes() {
let dev: TestDevice = Default::default();
let x = dev.zeros::<Rank2<3, 10>>();
let _: Tensor<Rank2<2, 8>, _, _, _> = dev.build_module::<Conv1D<3, 2, 3>, TestDtype>().forward(x.clone());
let _: Tensor<Rank2<4, 8>, _, _, _> = dev.build_module::<Conv1D<3, 4, 3>, TestDtype>().forward(x.clone());
let _: Tensor<Rank2<4, 9>, _, _, _> = dev.build_module::<Conv1D<3, 4, 2>, TestDtype>().forward(x.clone());
let _: Tensor<Rank2<4, 7>, _, _, _> = dev.build_module::<Conv1D<3, 4, 4>, TestDtype>().forward(x.clone());
let _: Tensor<Rank2<2, 4>, _, _, _> = dev.build_module::<Conv1D<3, 2, 3, 2>, TestDtype>().forward(x.clone());
let _: Tensor<Rank2<2, 3>, _, _, _> = dev.build_module::<Conv1D<3, 2, 3, 3>, TestDtype>().forward(x.clone());
let _: Tensor<Rank2<2, 10>, _, _, _> = dev.build_module::<Conv1D<3, 2, 3, 1, 1>, TestDtype>().forward(x.clone());
let _: Tensor<Rank2<2, 12>, _, _, _> = dev.build_module::<Conv1D<3, 2, 3, 1, 2>, TestDtype>().forward(x.clone());
let _: Tensor<Rank2<2, 6>, _, _, _> = dev.build_module::<Conv1D<3, 2, 3, 2, 2>, TestDtype>().forward(x.clone());
}

#[test]
fn test_grouped_forward_sizes() {
let dev: TestDevice = Default::default();

let x = dev.ones::<Rank2<16, 10>>();

let m = dev.build_module::<Conv1D<16, 32, 3, 1, 0, 1>, TestDtype>();
let _: Tensor<Rank3<32, 16, 3>, _, _> = m.weight;
let _: Tensor<Rank2<32, 8>, _, _> = m.forward(x.clone());
println!("1");

let m = dev.build_module::<Conv1D<16, 32, 3, 1, 0, 1, 2>, TestDtype>();
let _: Tensor<Rank3<32, 8, 3>, _, _> = m.weight;
let _: Tensor<Rank2<32, 8>, _, _> = m.forward(x.clone());
println!("2");

let m = dev.build_module::<Conv1D<16, 32, 3, 1, 0, 1, 4>, TestDtype>();
let _: Tensor<Rank3<32, 4, 3>, _, _> = m.weight;
let _: Tensor<Rank2<32, 8>, _, _> = m.forward(x.clone());
println!("3");

let m = dev.build_module::<Conv1D<16, 32, 3, 1, 0, 1, 8>, TestDtype>();
let _: Tensor<Rank3<32, 2, 3>, _, _> = m.weight;
let _: Tensor<Rank2<32, 8>, _, _> = m.forward(x.clone());
println!("4");

let m = dev.build_module::<Conv1D<16, 32, 3, 1, 0, 1, 16>, TestDtype>();
let _: Tensor<Rank3<32, 1, 3>, _, _> = m.weight;
let _: Tensor<Rank2<32, 8>, _, _> = m.forward(x);
println!("5");
}

#[rustfmt::skip]
#[test]
fn test_forward_3d_sizes() {
let dev: TestDevice = Default::default();
let x = dev.zeros::<Rank3<5, 3, 10>>();
let _: Tensor<Rank3<5, 2, 8>, _, _, _> = dev.build_module::<Conv1D<3, 2, 3>, TestDtype>().forward(x.clone());
let _: Tensor<Rank3<5, 4, 8>, _, _, _> = dev.build_module::<Conv1D<3, 4, 3>, TestDtype>().forward(x.clone());
let _: Tensor<Rank3<5, 4, 9>, _, _, _> = dev.build_module::<Conv1D<3, 4, 2>, TestDtype>().forward(x.clone());
let _: Tensor<Rank3<5, 4, 7>, _, _, _> = dev.build_module::<Conv1D<3, 4, 4>, TestDtype>().forward(x.clone());
let _: Tensor<Rank3<5, 2, 4>, _, _, _> = dev.build_module::<Conv1D<3, 2, 3, 2>, TestDtype>().forward(x.clone());
let _: Tensor<Rank3<5, 2, 3>, _, _, _> = dev.build_module::<Conv1D<3, 2, 3, 3>, TestDtype>().forward(x.clone());
let _: Tensor<Rank3<5, 2, 10>, _, _, _> = dev.build_module::<Conv1D<3, 2, 3, 1, 1>, TestDtype>().forward(x.clone());
let _: Tensor<Rank3<5, 2, 12>, _, _, _> = dev.build_module::<Conv1D<3, 2, 3, 1, 2>, TestDtype>().forward(x.clone());
let _: Tensor<Rank3<5, 2, 6>, _, _, _> = dev.build_module::<Conv1D<3, 2, 3, 2, 2>, TestDtype>().forward(x.clone());
}

#[test]
fn test_2_conv_sizes() {
let dev = Cpu::default();
type A = Conv1D<1, 2, 3>;
type B = Conv1D<2, 4, 3>;
let _: Tensor<Rank2<4, 6>, _, _> = dev
.build_module::<(A, B), TestDtype>()
.forward(dev.zeros::<Rank2<1, 10>>());
}

#[test]
fn test_3_conv_sizes() {
type A = Conv1D<1, 2, 3>;
type B = Conv1D<2, 4, 3>;
type C = Conv1D<4, 1, 1, 1, 1>;

let dev = Cpu::default();
let _: Tensor<Rank2<1, 8>, _, _> = dev
.build_module::<(A, B, C), TestDtype>()
.forward_mut(dev.zeros::<Rank2<1, 10>>());
}

#[test]
fn test_conv_with_optimizer() {
let dev: TestDevice = Default::default();

let mut m = dev.build_module::<Conv1D<2, 4, 3>, TestDtype>();

let weight_init = m.weight.clone();

let mut opt = Sgd::new(&m, Default::default());
let out = m.forward(dev.sample_normal::<Rank3<8, 2, 28>>().leaky_trace());
let g = out.square().mean().backward();

assert_ne!(g.get(&m.weight).array(), [[[TestDtype::zero(); 3]; 2]; 4]);

opt.update(&mut m, &g).expect("unused params");

assert_ne!(weight_init.array(), m.weight.array());
}
}
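
For context, here is a minimal usage sketch (not part of this commit). It assumes the crate's nightly feature is enabled, that `dfdx::prelude` re-exports the device, tensor, and module-building items used below, and that the `Conv1D` builder is exposed alongside `Conv2D`:

use dfdx::prelude::*;

fn main() {
    let dev = Cpu::default();
    // 3 input channels -> 2 output channels, kernel size 3 (stride 1, no padding, no dilation).
    let conv = dev.build_module::<Conv1D<3, 2, 3>, f32>();
    // Unbatched input of shape (channels, length) = (3, 10).
    let x: Tensor<Rank2<3, 10>, f32, _> = dev.sample_normal();
    // Output length = (10 + 2*0 - 1*(3 - 1) - 1) / 1 + 1 = 8.
    let y: Tensor<Rank2<2, 8>, f32, _> = conv.forward(x);
    println!("{:?}", y.array());
}

A batched input works the same way: a `Rank3<B, 3, 10>` tensor produces a `Rank3<B, 2, 8>` output, as exercised by the tests above.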
src/nn/conv.rs → src/nn/conv2d.rs: file renamed without changes.
8 changes: 5 additions & 3 deletions src/nn/mod.rs
@@ -189,7 +189,9 @@ mod batchnorm1d;
 mod batchnorm2d;
 mod bias2d;
 #[cfg(feature = "nightly")]
-mod conv;
+mod conv1d;
+#[cfg(feature = "nightly")]
+mod conv2d;
 #[cfg(feature = "nightly")]
 mod convtrans;
 mod dropout;
@@ -243,7 +245,7 @@ pub mod modules {
     pub use super::batchnorm2d::BatchNorm2D;
     pub use super::bias2d::Bias2D;
     #[cfg(feature = "nightly")]
-    pub use super::conv::Conv2D;
+    pub use super::conv2d::Conv2D;
     #[cfg(feature = "nightly")]
     pub use super::convtrans::ConvTrans2D;
     pub use super::dropout::{Dropout, DropoutOneIn};
@@ -279,7 +281,7 @@ pub mod builders {
     pub use super::batchnorm2d::builder::BatchNorm2D;
     pub use super::bias2d::builder::Bias2D;
     #[cfg(feature = "nightly")]
-    pub use super::conv::builder::Conv2D;
+    pub use super::conv2d::builder::Conv2D;
     #[cfg(feature = "nightly")]
     pub use super::convtrans::builder::ConvTrans2D;
     pub use super::dropout::{Dropout, DropoutOneIn};
14 changes: 8 additions & 6 deletions src/tensor_ops/attention_reshape/mod.rs
@@ -100,12 +100,14 @@ mod tests {
         let sequence_length = 1;
         let past_length = 3;
 
-        let qkv: Tensor<(usize, Const<{ NUM_HEADS * HEAD_DIM * 3 }>), TestDtype, _> =
-            dev.zeros_like(&(sequence_length, Const)) + 1.0;
-        let past_key: Tensor<(Const<NUM_HEADS>, Const<HEAD_DIM>, usize), TestDtype, _> =
-            dev.zeros_like(&(Const, Const, past_length)) + 2.0;
-        let past_value: Tensor<(Const<NUM_HEADS>, usize, Const<HEAD_DIM>), TestDtype, _> =
-            dev.zeros_like(&(Const, past_length, Const)) + 3.0;
+        let qkv: Tensor<(usize, Const<{ NUM_HEADS * HEAD_DIM * 3 }>), f32, _> =
+            dev.ones_like(&(sequence_length, Const));
+        let past_key: Tensor<(Const<NUM_HEADS>, Const<HEAD_DIM>, usize), f32, _> =
+            dev.ones_like(&(Const, Const, past_length));
+        let past_key = past_key * 2.0;
+        let past_value: Tensor<(Const<NUM_HEADS>, usize, Const<HEAD_DIM>), f32, _> =
+            dev.ones_like(&(Const, past_length, Const));
+        let past_value = past_value * 3.0;
 
         let (q, k, v) = dev.attention_reshape(&qkv, &past_key, &past_value);
 