From cab64bd093f48efb98fb5b1281163285d398ff64 Mon Sep 17 00:00:00 2001 From: jcrist1 Date: Thu, 7 Sep 2023 19:26:18 +0200 Subject: [PATCH] conv1d - added 1d convolution (#807) * conv1d - added 1d convolution * conv1d - quick port of cuda code from conv2d -> conv1d * conv1d - extra nightly to get through clippy * conv1d - remove duplicates * conv1d - clean up comments * conv1d - remove cudnn * fixes * conv1 - some fixes * trying something that doesn't work * debug * more debug * more debug * more debug * more debug * more debug * more debug * more debug * more debug * more debug * more debug * more debug * conv1d - more debugging --------- Co-authored-by: Corey Lowman --- src/nn/conv1d.rs | 280 +++++++++++++++++++++++++ src/nn/{conv.rs => conv2d.rs} | 0 src/nn/mod.rs | 8 +- src/tensor_ops/conv1d/conv1d.cu | 183 +++++++++++++++++ src/tensor_ops/conv1d/cpu_kernel.rs | 262 ++++++++++++++++++++++++ src/tensor_ops/conv1d/cuda_kernel.rs | 289 ++++++++++++++++++++++++++ src/tensor_ops/conv1d/mod.rs | 291 ++++++++++++++++++++++++++ src/tensor_ops/conv1d/tests.rs | 294 +++++++++++++++++++++++++++ src/tensor_ops/conv2d/cpu_kernel.rs | 8 +- src/tensor_ops/conv2d/cuda_kernel.rs | 4 +- src/tensor_ops/matmul/cuda_kernel.rs | 20 ++ src/tensor_ops/mod.rs | 5 + 12 files changed, 1635 insertions(+), 9 deletions(-) create mode 100644 src/nn/conv1d.rs rename src/nn/{conv.rs => conv2d.rs} (100%) create mode 100644 src/tensor_ops/conv1d/conv1d.cu create mode 100644 src/tensor_ops/conv1d/cpu_kernel.rs create mode 100644 src/tensor_ops/conv1d/cuda_kernel.rs create mode 100644 src/tensor_ops/conv1d/mod.rs create mode 100644 src/tensor_ops/conv1d/tests.rs diff --git a/src/nn/conv1d.rs b/src/nn/conv1d.rs new file mode 100644 index 000000000..eb1d545cc --- /dev/null +++ b/src/nn/conv1d.rs @@ -0,0 +1,280 @@ +use num_traits::Float; +use rand_distr::uniform::SampleUniform; + +use crate::{shapes::*, tensor::*, tensor_ops::*}; + +use super::*; + +pub mod builder { + #[derive(Debug)] + pub struct Conv1D< + const IN_CHAN: usize, + const OUT_CHAN: usize, + const KERNEL_SIZE: usize, + const STRIDE: usize = 1, + const PADDING: usize = 0, + const DILATION: usize = 1, + const GROUPS: usize = 1, + >; +} + +impl< + const I: usize, + const O: usize, + const K: usize, + const S: usize, + const P: usize, + const L: usize, + const G: usize, + E, + D, + > BuildOnDevice for builder::Conv1D +where + E: Dtype, + D: Device, + Const<{ I / G }>: Sized, + Conv1D: BuildModule, +{ + type Built = Conv1D; + fn try_build_on_device(device: &D) -> Result::Err> { + Self::Built::try_build(device) + } +} + +/// **Requires Nightly** Performs *unbiased* 1d convolutions on 2d and 3d inputs. +/// +/// **Pytorch Equivalent**: `torch.nn.Conv1d(..., bias=False)` +/// +/// Generics: +/// - `IN_CHAN`: The number of channels in the input. +/// - `OUT_CHAN`: The number of channels in the output of the layer. +/// - `KERNEL_SIZE`: The size of the kernel applied along the length of the input. +/// - `STRIDE`: How far to move the kernel each step. Defaults to `1`. +/// - `PADDING`: How much zero padding to add at both ends of the input. Defaults to `0`. +/// - `DILATION`: Controls the spacing between kernel points. Defaults to `1`. +/// - `GROUPS`: Controls the connections between inputs and outputs. +/// `IN_CHAN` and `OUT_CHAN` must both be divisible by `GROUPS`. For example, with `GROUPS = 2`, +/// the first half of the output channels is computed only from the first half of the input channels, +/// and the second half only from the second half. +/// +/// See [conv animations](https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md) for helpful +/// visualization of all of these parameters.
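+///
+/// A minimal usage sketch (the module parameters and shapes below are illustrative, not part of this patch):
+/// ```ignore
+/// # use dfdx::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// // 3 input channels, 8 output channels, kernel size 5 (stride 1, padding 0, dilation 1, groups 1)
+/// let m = dev.build_module::<Conv1D<3, 8, 5>, f32>();
+/// let x: Tensor<Rank2<3, 28>, f32, _> = dev.sample_normal();
+/// // output length = (28 + 2*0 - 1*(5 - 1) - 1) / 1 + 1 = 24
+/// let _: Tensor<Rank2<8, 24>, f32, _> = m.forward(x);
+/// ```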
+ +#[derive(Debug, Clone)] +pub struct Conv1D< + const IN_CHAN: usize, + const OUT_CHAN: usize, + const KERNEL_SIZE: usize, + const STRIDE: usize, + const PADDING: usize, + const DILATION: usize, + const GROUPS: usize, + E: Dtype, + D: Storage, +> where + Const<{ IN_CHAN / GROUPS }>: Sized, +{ + pub weight: Tensor, E, D>, +} + +impl< + const I: usize, + const O: usize, + const K: usize, + const S: usize, + const P: usize, + const L: usize, + const G: usize, + E, + D, + > TensorCollection for Conv1D +where + Const<{ I / G }>: Sized, + E: Dtype + Float + SampleUniform, + D: Device, +{ + type To> = Conv1D; + + fn iter_tensors>( + visitor: &mut V, + ) -> Result>, V::Err> { + visitor.visit_fields( + Self::tensor( + "weight", + |s| &s.weight, + |s| &mut s.weight, + TensorOptions::reset_with(|t| { + let b = E::ONE / E::from_usize(I * K).unwrap().sqrt(); + t.try_fill_with_distr(rand_distr::Uniform::new(-b, b)) + }), + ), + |weight| Conv1D { weight }, + ) + } +} + +impl< + const I: usize, + const O: usize, + const K: usize, + const S: usize, + const P: usize, + const L: usize, + const G: usize, + E, + D, + Img, + > Module for Conv1D +where + Const<{ I / G }>: Sized, + E: Dtype, + D: Device, + (Img, Tensor, E, D>): TryConv1D, Const

, Const, Const>, +{ + type Output = <(Img, Tensor, E, D>) as TryConv1D< + Const, + Const

, + Const, + Const, + >>::Convolved; + type Error = <(Img, Tensor, E, D>) as TryConv1D< + Const, + Const

, + Const, + Const, + >>::Error; + + fn try_forward(&self, x: Img) -> Result { + (x, self.weight.clone()).try_conv1d(Const, Const, Const, Const) + } +} + +impl< + const I: usize, + const O: usize, + const K: usize, + const S: usize, + const P: usize, + const L: usize, + const G: usize, + E: Dtype, + D: Storage, + > NonMutableModule for Conv1D +where + Const<{ I / G }>: Sized, +{ +} + +#[cfg(test)] +mod tests { + use crate::{ + optim::*, + tensor::{AsArray, SampleTensor, ZerosTensor}, + tests::*, + }; + + use super::{builder::Conv1D, *}; + + #[rustfmt::skip] + #[test] + fn test_forward_3d_sizes() { + let dev: TestDevice = Default::default(); + let x = dev.zeros::>(); + let _: Tensor, _, _, _> = dev.build_module::, TestDtype>().forward(x.clone()); + let _: Tensor, _, _, _> = dev.build_module::, TestDtype>().forward(x.clone()); + let _: Tensor, _, _, _> = dev.build_module::, TestDtype>().forward(x.clone()); + let _: Tensor, _, _, _> = dev.build_module::, TestDtype>().forward(x.clone()); + let _: Tensor, _, _, _> = dev.build_module::, TestDtype>().forward(x.clone()); + let _: Tensor, _, _, _> = dev.build_module::, TestDtype>().forward(x.clone()); + let _: Tensor, _, _, _> = dev.build_module::, TestDtype>().forward(x.clone()); + let _: Tensor, _, _, _> = dev.build_module::, TestDtype>().forward(x.clone()); + let _: Tensor, _, _, _> = dev.build_module::, TestDtype>().forward(x.clone()); + } + + #[test] + fn test_grouped_forward_sizes() { + let dev: TestDevice = Default::default(); + + let x = dev.ones::>(); + + let m = dev.build_module::, TestDtype>(); + let _: Tensor, _, _> = m.weight; + let _: Tensor, _, _> = m.forward(x.clone()); + println!("1"); + + let m = dev.build_module::, TestDtype>(); + let _: Tensor, _, _> = m.weight; + let _: Tensor, _, _> = m.forward(x.clone()); + println!("2"); + + let m = dev.build_module::, TestDtype>(); + let _: Tensor, _, _> = m.weight; + let _: Tensor, _, _> = m.forward(x.clone()); + println!("3"); + + let m = dev.build_module::, TestDtype>(); + let _: Tensor, _, _> = m.weight; + let _: Tensor, _, _> = m.forward(x.clone()); + println!("4"); + + let m = dev.build_module::, TestDtype>(); + let _: Tensor, _, _> = m.weight; + let _: Tensor, _, _> = m.forward(x); + println!("5"); + } + + #[rustfmt::skip] + #[test] + fn test_forward_4d_sizes() { + let dev: TestDevice = Default::default(); + let x = dev.zeros::>(); + let _: Tensor, _, _, _> = dev.build_module::, TestDtype>().forward(x.clone()); + let _: Tensor, _, _, _> = dev.build_module::, TestDtype>().forward(x.clone()); + let _: Tensor, _, _, _> = dev.build_module::, TestDtype>().forward(x.clone()); + let _: Tensor, _, _, _> = dev.build_module::, TestDtype>().forward(x.clone()); + let _: Tensor, _, _, _> = dev.build_module::, TestDtype>().forward(x.clone()); + let _: Tensor, _, _, _> = dev.build_module::, TestDtype>().forward(x.clone()); + let _: Tensor, _, _, _> = dev.build_module::, TestDtype>().forward(x.clone()); + let _: Tensor, _, _, _> = dev.build_module::, TestDtype>().forward(x.clone()); + let _: Tensor, _, _, _> = dev.build_module::, TestDtype>().forward(x.clone()); + } + + #[test] + fn test_2_conv_sizes() { + let dev = Cpu::default(); + type A = Conv1D<1, 2, 3>; + type B = Conv1D<2, 4, 3>; + let _: Tensor, _, _> = dev + .build_module::<(A, B), TestDtype>() + .forward(dev.zeros::>()); + } + + #[test] + fn test_3_conv_sizes() { + type A = Conv1D<1, 2, 3>; + type B = Conv1D<2, 4, 3>; + type C = Conv1D<4, 1, 1, 1, 1>; + + let dev = Cpu::default(); + let _: Tensor, _, _> = dev + .build_module::<(A, B, C), 
TestDtype>() + .forward_mut(dev.zeros::>()); + } + + #[test] + fn test_conv_with_optimizer() { + let dev: TestDevice = Default::default(); + + let mut m = dev.build_module::, TestDtype>(); + + let weight_init = m.weight.clone(); + + let mut opt = Sgd::new(&m, Default::default()); + let out = m.forward(dev.sample_normal::>().leaky_trace()); + let g = out.square().mean().backward(); + + assert_ne!(g.get(&m.weight).array(), [[[TestDtype::zero(); 3]; 2]; 4]); + + opt.update(&mut m, &g).expect("unused params"); + + assert_ne!(weight_init.array(), m.weight.array()); + } +} diff --git a/src/nn/conv.rs b/src/nn/conv2d.rs similarity index 100% rename from src/nn/conv.rs rename to src/nn/conv2d.rs diff --git a/src/nn/mod.rs b/src/nn/mod.rs index 17ae73fd7..08aa74cf7 100644 --- a/src/nn/mod.rs +++ b/src/nn/mod.rs @@ -189,7 +189,9 @@ mod batchnorm1d; mod batchnorm2d; mod bias2d; #[cfg(feature = "nightly")] -mod conv; +mod conv1d; +#[cfg(feature = "nightly")] +mod conv2d; #[cfg(feature = "nightly")] mod convtrans; mod dropout; @@ -243,7 +245,7 @@ pub mod modules { pub use super::batchnorm2d::BatchNorm2D; pub use super::bias2d::Bias2D; #[cfg(feature = "nightly")] - pub use super::conv::Conv2D; + pub use super::conv2d::Conv2D; #[cfg(feature = "nightly")] pub use super::convtrans::ConvTrans2D; pub use super::dropout::{Dropout, DropoutOneIn}; @@ -279,7 +281,7 @@ pub mod builders { pub use super::batchnorm2d::builder::BatchNorm2D; pub use super::bias2d::builder::Bias2D; #[cfg(feature = "nightly")] - pub use super::conv::builder::Conv2D; + pub use super::conv2d::builder::Conv2D; #[cfg(feature = "nightly")] pub use super::convtrans::builder::ConvTrans2D; pub use super::dropout::{Dropout, DropoutOneIn}; diff --git a/src/tensor_ops/conv1d/conv1d.cu b/src/tensor_ops/conv1d/conv1d.cu new file mode 100644 index 000000000..09d07e58d --- /dev/null +++ b/src/tensor_ops/conv1d/conv1d.cu @@ -0,0 +1,183 @@ +#include "cuda_fp16.h" + +struct Conv1DOp { + size_t kernel; + size_t stride; + size_t padding; + size_t dilation; + size_t groups; + size_t batch; + size_t chan_in; + size_t chan_out; + size_t l_in; + size_t l_out; +}; + +template +__device__ void unfold_input_into_patches( + const Conv1DOp op, + const T *image, // 3d (Batch, Groups * Channels, Length) + const size_t *strides, // 3d image strides + T *patches // 4d (Batch, Groups * Channels, KernelSize, LengthOut) +) { + const size_t n = op.batch * op.chan_in * op.l_out; + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; + i += blockDim.x * gridDim.x) { + unsigned int idx = i; + const size_t ol = idx % op.l_out; + idx /= op.l_out; + const size_t c = idx % op.chan_in; + idx /= op.chan_in; + const size_t b = idx % op.batch; + + const T *image_i = image + b * strides[0] + c * strides[1]; + T *patches_i = patches + ol; + patches_i += c * (op.kernel * op.l_out); + patches_i += b * (op.chan_in * op.kernel * op.l_out); + + T zero = 0.0; + + for (int k1 = 0; k1 < op.kernel; k1++) { + const size_t y = ol * op.stride + op.dilation * k1 - op.padding; + *patches_i = (y >= op.l_in) ? 
zero : image_i[y * strides[2]]; + patches_i += op.l_out; + } + } +} + +template +__device__ void unfold_output_into_patches( + const Conv1DOp op, + const T *image_out, // 3d (Batch, ChanOut, LengthOut) + T *patches // 4d (Batch, ChanOut, KernelSize, LengthIn) +) { + const size_t n = op.batch * op.chan_out * op.l_in; + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; + i += blockDim.x * gridDim.x) { + unsigned int idx = i; + const size_t y = idx % op.l_in; + idx /= op.l_in; + const size_t o = idx % op.chan_out; + idx /= op.chan_out; + const size_t b = idx % op.batch; + + const T *image_i = + image_out + b * (op.chan_out * op.l_out) + o * op.l_out; + T *patches_i = patches + y; + patches_i += o * (op.kernel * op.l_in); + patches_i += b * (op.chan_out * op.kernel * op.l_in); + + T zero = 0.0; + + for (int k1 = 0; k1 < op.kernel; k1++) { + const size_t ol_ks = y + op.padding; + const size_t ol_s = ol_ks - op.dilation * k1; + const size_t ol = ol_s / op.stride; + const bool invalid = + (ol_ks < op.dilation * k1 || ol_s % op.stride != 0 || ol >= op.l_out); + + *patches_i = invalid ? zero : image_i[ol]; + patches_i += op.l_in; + } + } +} + +template +__device__ void transpose_filters( + const Conv1DOp op, + const T *filters, // 3d (ChanOut, ChanIn/Groups, KernelSize) + const size_t *strides, // 4d filters strides + T *filters_tr // 4d (Groups, ChanIn/Groups, ChanOut/Groups, KernelSize) +) { + const size_t c_per_g = op.chan_in / op.groups; + const size_t o_per_g = op.chan_out / op.groups; + const size_t n = c_per_g * op.chan_out * op.kernel; + + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; + i += blockDim.x * gridDim.x) { + unsigned int idx = i; + const size_t k1 = idx % op.kernel; + idx /= op.kernel; + const size_t cg = idx % c_per_g; + idx /= c_per_g; + const size_t o = idx % op.chan_out; + const size_t og = o % o_per_g; + const size_t g = o / o_per_g; + + auto i_no = o * strides[0] + cg * strides[1] + k1 * strides[2]; + T *filters_tr_i = filters_tr + k1; + filters_tr_i += og * op.kernel; + filters_tr_i += cg * (o_per_g * op.kernel); + filters_tr_i += g * (c_per_g * o_per_g * op.kernel); + *filters_tr_i = filters[i_no]; + } +} + +template +__device__ void +sum_transposed_filters(const Conv1DOp op, + const T *filters_tr, // 5d (Batch, Groups, ChanIn/Groups, + // ChanOut/Groups, KernelSize) + T *filters, // 3d (ChanOut, ChanIn/Groups, KernelSize) + const size_t *strides // 3d filter strides +) { + const size_t o_per_g = op.chan_out / op.groups; + const size_t c_per_g = op.chan_in / op.groups; + const size_t n = op.chan_out * c_per_g * op.kernel; + + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; + i += blockDim.x * gridDim.x) { + unsigned int idx = i; + const size_t k1 = idx % op.kernel; + idx /= op.kernel; + const size_t cg = idx % c_per_g; + idx /= c_per_g; + const size_t o = idx % op.chan_out; + const size_t og = o % o_per_g; + const size_t g = o / o_per_g; + + auto i_no = o * strides[0] + cg * strides[1] + k1 * strides[2]; + + const T *filters_tr_i = filters_tr + k1; + filters_tr_i += og * op.kernel; + filters_tr_i += cg * (o_per_g * op.kernel); + filters_tr_i += g * (c_per_g * o_per_g * op.kernel); + + T tmp = 0.0; + for (int b = 0; b < op.batch; b++) { + tmp += *filters_tr_i; + filters_tr_i += n; + } + + filters[i_no] += tmp; + } +} + +#define CONV_OP(TYPENAME, UNFOLD_INPUT, UNFOLD_OUTPUT, TR_FILTERS, \ + SUM_TR_FILTERS) \ + extern "C" __global__ void UNFOLD_INPUT( \ + const Conv1DOp op, const TYPENAME *image, const size_t 
*strides, \ + TYPENAME *patches) { \ + unfold_input_into_patches(op, image, strides, patches); \ + } \ + extern "C" __global__ void UNFOLD_OUTPUT( \ + const Conv1DOp op, const TYPENAME *image_out, TYPENAME *patches) { \ + unfold_output_into_patches(op, image_out, patches); \ + } \ + extern "C" __global__ void TR_FILTERS( \ + const Conv1DOp op, const TYPENAME *filters, const size_t *strides, \ + TYPENAME *filters_tr) { \ + transpose_filters(op, filters, strides, filters_tr); \ + } \ + extern "C" __global__ void SUM_TR_FILTERS( \ + const Conv1DOp op, const TYPENAME *filters_tr, TYPENAME *filters, \ + const size_t *strides) { \ + sum_transposed_filters(op, filters_tr, filters, strides); \ + } + +CONV_OP(__half, unfold_input_into_patches_f16, unfold_output_into_patches_f16, + transpose_filters_f16, sum_transposed_filters_f16); +CONV_OP(float, unfold_input_into_patches_f32, unfold_output_into_patches_f32, + transpose_filters_f32, sum_transposed_filters_f32); +CONV_OP(double, unfold_input_into_patches_f64, unfold_output_into_patches_f64, + transpose_filters_f64, sum_transposed_filters_f64); diff --git a/src/tensor_ops/conv1d/cpu_kernel.rs b/src/tensor_ops/conv1d/cpu_kernel.rs new file mode 100644 index 000000000..c392a9824 --- /dev/null +++ b/src/tensor_ops/conv1d/cpu_kernel.rs @@ -0,0 +1,262 @@ +use crate::prelude::{HasAxes, HasShape}; +use crate::shapes::{Dtype, Shape}; +use crate::tensor::{cpu::*, *}; +use crate::tensor_ops::matmul::cpu_kernel::MatMulImpl; + +use super::{Conv1DKernel, Conv1DOp}; + +use std::sync::Arc; + +impl Conv1DOp { + #[inline(always)] + fn unfold_idx(&self, [k1, y]: [usize; 2]) -> Option<[usize; 1]> { + let mut ol = y + self.padding; + if ol < self.dilation * k1 { + return None; + } + ol -= self.dilation * k1; + if ol % self.stride != 0 { + return None; + } + ol /= self.stride; + if ol >= self.l_out { + return None; + } + + Some([ol]) + } +} + +impl Cpu { + #[inline] + fn fwd_conv1d( + &self, + op: &Conv1DOp, + img: &[E], + filters: &[E], + out: &mut [E], + buf: &mut [E], + ) -> Result<(), CpuError> + where + Self: MatMulImpl, + { + { + let mut i = 0; + for c in 0..(op.groups * op.chan_in) { + for k1 in 0..op.kernel { + for ol in 0..op.l_out { + let y = (ol * op.stride + op.dilation * k1).wrapping_sub(op.padding); + if y < op.l_in { + buf[i] = img[c * op.l_in + y]; + } + i += 1; + } + } + } + } + println!("img: {img:?}"); + + println!("Buff {buf:?}"); + + // (G, O / G, C * K) * (G, C * K , OL) = (G, O / G, OL) + let m = op.chan_out / op.groups; + // todo: examine why this fails the test + // let k = (op.chan_in / op.groups) * op.kernel; + let k = op.chan_in * op.kernel; + + let n = op.l_out; + for g in 0..op.groups { + Self::matmul( + (m, k, n), + false, + filters[g * m * k..].as_ptr(), + [k, 1], + buf[g * k * n..].as_ptr(), + [n, 1], + out[g * m * n..].as_mut_ptr(), + [n, 1], + ); + } + + println!("Buff OUT {out:?}"); + Ok(()) + } + + #[inline] + #[allow(clippy::too_many_arguments)] + fn bwd_conv1d( + &self, + op: &Conv1DOp, + img: &[E], + grad_img: &mut [E], + filters_tr: &[E], + grad_filters_tr: &mut [E], + grad_out: &[E], + buf: &mut [E], + ) -> Result<(), CpuError> + where + Self: MatMulImpl, + { + { + let mut i = 0; + for o in 0..op.chan_out { + for k1 in 0..op.kernel { + for y in 0..op.l_in { + if let Some([ol]) = op.unfold_idx([k1, y]) { + buf[i] = grad_out[o * op.l_out + ol]; + } + i += 1; + } + } + } + } + println!("Grad Buff OUT {buf:?}"); + println!("filters {filters_tr:?}"); + + { + // img_g += filters^T * unfold(grad_out) + // (G, C, H * W) += (G, C, O/G * 
K * K) * (G, O/G * K * K, L) + let m = op.chan_in; + let k = (op.chan_out / op.groups) * op.kernel; + let n = op.l_in; + for g in 0..op.groups { + Self::matmul( + (m, k, n), + true, + filters_tr[g * m * k..].as_ptr(), + [k, 1], + buf[g * k * n..].as_ptr(), + [n, 1], + grad_img[g * m * n..].as_mut_ptr(), + [n, 1], + ); + } + } + + { + // weight_g^T += img * unfold(patches)^T + // (G, C, O/G * K * K) += (G, C, H * W) * (G, H * W, O/G * K) + let m = op.chan_in; + let k = op.l_in; + let n = (op.chan_out / op.groups) * op.kernel; + for g in 0..op.groups { + Self::matmul( + (m, k, n), + true, + img[g * m * k..].as_ptr(), + [k, 1], + buf[g * k * n..].as_ptr(), + [1, k], + grad_filters_tr[g * m * n..].as_mut_ptr(), + [n, 1], + ); + } + } + Ok(()) + } +} + +impl Conv1DKernel for Cpu +where + Self: MatMulImpl, +{ + fn alloc(&self, s: S) -> Result, Self::Err> { + self.try_zeros_like(&s) + } + + fn forward( + &self, + op: Conv1DOp, + lhs: &Tensor, + rhs: &Tensor, + out: &mut Tensor, + ) -> Result<(), Self::Err> { + let patches = (op.groups * op.chan_in, op.kernel, op.l_out); + let mut patches = self.try_alloc_zeros::(patches.num_elements())?; + let [lstride, ostride] = match L::NUM_DIMS { + 2 => [0; 2], + 3 => [lhs.strides[0], out.strides[0]], + _ => unreachable!(), + }; + + let lhs = lhs.data.as_ref(); + let rhs = rhs.data.as_ref(); + let out = Arc::make_mut(&mut out.data); + use crate::prelude::HasShape; + for i_batch in 0..op.batch { + let lhs_slice = &lhs[i_batch * lstride..]; + println!("LHS Slice: {lhs_slice:?}"); + println!("Patches: {patches:?}"); + println!("rhs: {:?}", rhs); + self.fwd_conv1d( + &op, + lhs_slice, + rhs, + &mut out[i_batch * ostride..], + &mut patches, + )?; + } + Ok(()) + } + + fn backward( + &self, + op: Conv1DOp, + lhs: &Tensor, + grad_lhs: &mut Self::Vec, + rhs: &Tensor, + grad_rhs: &mut Self::Vec, + out: &impl Tensorlike, + grad_out: &Self::Vec, + ) -> Result<(), Self::Err> { + let f_tr_shape = [op.groups, op.chan_in, op.chan_out / op.groups, op.kernel]; + let patches_shape = [op.chan_out, op.kernel, op.kernel, op.l_in]; + let mut patches = self.try_alloc_zeros::(patches_shape.num_elements())?; + let mut f1023 = self.try_alloc_zeros::(f_tr_shape.num_elements())?; + let mut grad_f1023 = self.try_alloc_zeros::(f_tr_shape.num_elements())?; + + { + // transpose filters in f1023 + let buf = rhs.data.as_ref(); + let mut f_idx = NdIndex::new(f_tr_shape, f_tr_shape.strides()); + while let Some((i, [g, c, o, k1])) = f_idx.next_with_idx() { + let idx = (g * (op.chan_out / op.groups) + o) * rhs.strides[0] + + c * rhs.strides[1] + + k1 * rhs.strides[2]; + f1023[i] = buf[idx]; + } + } + + let [lstride, ostride] = match L::NUM_DIMS { + 2 => [0; 2], + 3 => [lhs.strides[0], out.strides()[0]], + _ => unreachable!(), + }; + let lhs = lhs.data.as_ref(); + + for i_batch in 0..op.batch { + self.bwd_conv1d( + &op, + &lhs[i_batch * lstride..], + &mut grad_lhs[i_batch * lstride..], + &f1023, + &mut grad_f1023, + &grad_out[i_batch * ostride..], + &mut patches, + )?; + } + + { + // untranspose filters + let mut f_idx = NdIndex::new(f_tr_shape, f_tr_shape.strides()); + while let Some((i, [g, c, o, k1])) = f_idx.next_with_idx() { + let idx = (g * (op.chan_out / op.groups) + o) * rhs.strides[0] + + c * rhs.strides[1] + + k1 * rhs.strides[2]; + grad_rhs[idx] += grad_f1023[i]; + } + } + + Ok(()) + } +} diff --git a/src/tensor_ops/conv1d/cuda_kernel.rs b/src/tensor_ops/conv1d/cuda_kernel.rs new file mode 100644 index 000000000..fb5713003 --- /dev/null +++ b/src/tensor_ops/conv1d/cuda_kernel.rs @@ -0,0 
+1,289 @@ +use cudarc::cublas::{CudaBlas, Gemm}; +use cudarc::driver::{DeviceRepr, LaunchAsync, ValidAsZeroBits}; + +use crate::{ + shapes::*, + tensor::{launch_cfg, Cuda, Tensor, Tensorlike}, +}; + +use core::iter::repeat; +use std::sync::Arc; + +unsafe impl DeviceRepr for super::Conv1DOp {} + +const PTX_SRC: &str = include_str!(concat!(env!("OUT_DIR"), "/conv1d.ptx")); + +trait HasCudaKernel { + const MOD: &'static str; + const FNS: &'static [&'static str]; +} + +#[cfg(feature = "f16")] +impl HasCudaKernel for Cuda { + const MOD: &'static str = "conv1d_f16"; + const FNS: &'static [&'static str] = &[ + "unfold_input_into_patches_f16", + "unfold_output_into_patches_f16", + "transpose_filters_f16", + "sum_transposed_filters_f16", + ]; +} + +impl HasCudaKernel for Cuda { + const MOD: &'static str = "conv1d_f32"; + const FNS: &'static [&'static str] = &[ + "unfold_input_into_patches_f32", + "unfold_output_into_patches_f32", + "transpose_filters_f32", + "sum_transposed_filters_f32", + ]; +} + +impl HasCudaKernel for Cuda { + const MOD: &'static str = "conv1d_f64"; + const FNS: &'static [&'static str] = &[ + "unfold_input_into_patches_f64", + "unfold_output_into_patches_f64", + "transpose_filters_f64", + "sum_transposed_filters_f64", + ]; +} + +fn make_3d(strides: S::Concrete) -> [usize; 3] { + match S::NUM_DIMS { + 2 => [0, strides[0], strides[1]], + 3 => [strides[0], strides[1], strides[2]], + _ => unreachable!("Only implemented for 2d & 3d arrays"), + } +} + +impl super::Conv1DKernel for Cuda +where + Self: HasCudaKernel, + CudaBlas: Gemm, +{ + fn alloc(&self, shape: S) -> Result, Self::Err> { + let data = unsafe { self.alloc_empty::(shape.num_elements()) }?; + Ok(self.build_tensor(shape, shape.strides(), data)) + } + fn forward( + &self, + op: super::Conv1DOp, + img: &Tensor, + fil: &Tensor, + out: &mut Tensor, + ) -> Result<(), Self::Err> { + if !self.dev.has_func(Self::MOD, Self::FNS[0]) { + self.dev.load_ptx(PTX_SRC.into(), Self::MOD, Self::FNS)?; + } + + let patches_item_numel = op.groups * op.chan_in * op.kernel * op.l_out; + let patches_numel = op.batch * patches_item_numel; + + let mut patches = unsafe { self.get_workspace::(patches_numel) }?; + let mut patches = unsafe { patches.transmute_mut::(patches_numel).unwrap() }; + + let img_strides = self.dev.htod_copy(make_3d::(img.strides).into())?; + + let out_buf = Arc::get_mut(&mut out.data).unwrap(); + + unsafe { + let unfold_fn = self.dev.get_func(Self::MOD, Self::FNS[0]).unwrap(); + let cfg = launch_cfg::<128>((op.batch * op.chan_in * op.l_out) as u32); + let params = (op, img.data.as_ref(), &img_strides, &mut patches); + unfold_fn.launch(cfg, params)?; + let mut data: Vec = repeat(E::ONE).take(patches_numel).collect(); + self.dev.dtoh_sync_copy_into(&patches, data.as_mut_slice()); + + println!("buff {:?}", data); + + // LHS (G, O/G, C/G*K) + // RHS (B, G, C/G*K, OL) + // OUT (B, G, O/G, OL) + let m = op.chan_out / op.groups; + let k = op.chan_in * op.kernel; + let n = op.l_out; + if op.groups == 1 { + // optimizing here for common case + self.gemm_batch( + (op.batch, m, k, n), + fil.data.as_ref(), + [0, k, 1], + &patches, + [k * n, n, 1], + Default::default(), + out_buf, + [m * n, n, 1], + ) + .unwrap(); + + let mut out_data: Vec = out.as_vec(); + + println!("out {:?}", out_data); + } else { + for i_batch in 0..op.batch { + self.gemm_batch( + (op.groups, m, k, n), + fil.data.as_ref(), + [m * k, k, 1], + &patches.slice(i_batch * op.groups * k * n..), + [k * n, n, 1], + Default::default(), + &mut out_buf.slice_mut(i_batch * op.groups 
* m * n..), + [m * n, n, 1], + ) + .unwrap(); + } + } + } + + Ok(()) + } + + fn backward( + &self, + op: super::Conv1DOp, + lhs: &Tensor, + grad_lhs: &mut Self::Vec, + rhs: &Tensor, + grad_rhs: &mut Self::Vec, + grad_stuff: &impl Tensorlike, + grad_out: &Self::Vec, + ) -> Result<(), Self::Err> { + let patches_item_numel = op.chan_out * op.kernel * op.l_in; + let patches_numel = op.batch * patches_item_numel; + let filters_numel = + op.groups * (op.chan_in / op.groups) * (op.chan_out / op.groups) * op.kernel; + + let mut patches = unsafe { self.get_workspace::(patches_numel) }?; + let mut patches = unsafe { patches.transmute_mut::(patches_numel).unwrap() }; + + let mut ftr = unsafe { self.alloc_empty::(filters_numel) }?; + let mut grad_ftr = unsafe { self.alloc_empty::(op.batch * filters_numel) }?; + let f_strides = self.dev.htod_copy(rhs.strides.into())?; + + self.par_stream.wait_for_default()?; + let mut data: Vec = repeat(E::ONE) + .take(grad_stuff.shape().num_elements()) + .collect(); + self.dev + .dtoh_sync_copy_into(&grad_out.slice(..), data.as_mut_slice()); + + println!("grad out {:?}", data); + + unsafe { + // unfold grad_out into patches + let unfold_fn = self.dev.get_func(Self::MOD, Self::FNS[1]).unwrap(); + let cfg = launch_cfg::<128>((op.batch * op.chan_out * op.l_in) as u32); + unfold_fn.launch(cfg, (op, grad_out, &mut patches))?; + } + + let mut data: Vec = repeat(E::ONE).take(patches_numel).collect(); + self.dev.dtoh_sync_copy_into(&patches, data.as_mut_slice()); + + println!("grad buff {:?}", data); + + unsafe { + // prepare filters for backward operations by + // swapping dims 0 and 1 + let tr_fn = self.dev.get_func(Self::MOD, Self::FNS[2]).unwrap(); + let cfg = launch_cfg::<128>(rhs.shape.num_elements() as u32); + tr_fn.launch_on_stream( + self.par_stream.as_ref(), + cfg, + (op, rhs.data.as_ref(), &f_strides, &mut ftr), + )?; + + self.par_stream.wait_for_default()?; + + println!("grad buff {:?}", data); + + // img_g += filters * patches + // LHS = (G, C/G, O/G*K) + // RHS = (B, G, O/G*K, L) + // OUT = (B, G, C/G, L) + let m = op.chan_in / op.groups; + let k = (op.chan_out / op.groups) * op.kernel; + let n = op.l_in; + self.blas.set_stream(Some(self.par_stream.as_ref()))?; + if op.groups == 1 { + // optimizing here for common case + self.gemm_batch( + (op.batch, m, k, n), + &ftr, + [0, k, 1], + &patches, + [k * n, n, 1], + ::ONE, + grad_lhs, + [m * n, n, 1], + ) + .unwrap(); + } else { + for i_batch in 0..op.batch { + self.gemm_batch( + (op.groups, m, k, n), + &ftr, + [m * k, k, 1], + &patches.slice(i_batch * op.groups * k * n..), + [k * n, n, 1], + ::ONE, + &mut grad_lhs.slice_mut(i_batch * op.groups * m * n..), + [m * n, n, 1], + ) + .unwrap(); + } + } + self.blas.set_stream(None)?; + } + + unsafe { + // weight_g += img * patches^T + // LHS = (B, G, C/G, L) + // RHS = (B, L, G, O/G*K) + // OUT = (B, G, C/G, O/G*K) + let m = op.chan_in / op.groups; + let k = op.l_in; + let n = (op.chan_out / op.groups) * op.kernel; + if op.groups == 1 { + // optimizing here for common case + self.gemm_batch( + (op.batch, m, k, n), + lhs.data.as_ref(), + [m * k, k, 1], + &patches, + [k * n, 1, k], + Default::default(), + &mut grad_ftr, + [m * n, n, 1], + ) + .unwrap(); + } else { + let lhs_buf = lhs.data.as_ref(); + for i_batch in 0..op.batch { + self.gemm_batch( + (op.groups, m, k, n), + &lhs_buf.slice(i_batch * op.groups * m * k..), + [m * k, k, 1], + &patches.slice(i_batch * op.groups * k * n..), + [k * n, 1, k], + Default::default(), + &mut grad_ftr.slice_mut(i_batch * op.groups * m * 
n..), + [m * n, n, 1], + ) + .unwrap(); + } + } + + // sum all the gradients collected in our broadcasted grad_f + // into grad_rhs + let sum_fn = self.dev.get_func(Self::MOD, Self::FNS[3]).unwrap(); + let cfg = launch_cfg::<128>(rhs.shape.num_elements() as u32); + sum_fn.launch(cfg, (op, &grad_ftr, grad_rhs, &f_strides))?; + } + + self.dev.wait_for(self.par_stream.as_ref())?; + + Ok(()) + } +} diff --git a/src/tensor_ops/conv1d/mod.rs b/src/tensor_ops/conv1d/mod.rs new file mode 100644 index 000000000..0f87dc72d --- /dev/null +++ b/src/tensor_ops/conv1d/mod.rs @@ -0,0 +1,291 @@ +use crate::{shapes::*, tensor::*, tensor_ops::ReshapeTo}; + +mod cpu_kernel; +#[cfg(all(not(feature = "cudnn"), feature = "cuda"))] +mod cuda_kernel; + +#[cfg(test)] +mod tests; + +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub(super) struct Conv1DOp { + pub kernel: usize, + pub stride: usize, + pub padding: usize, + pub dilation: usize, + pub groups: usize, + pub batch: usize, + pub chan_in: usize, + pub chan_out: usize, + pub l_in: usize, + pub l_out: usize, +} + +pub(super) trait Conv1DKernel: Storage { + fn alloc(&self, s: S) -> Result, Self::Err>; + + fn forward( + &self, + op: Conv1DOp, + lhs: &Tensor, + rhs: &Tensor, + out: &mut Tensor, + ) -> Result<(), Self::Err>; + + #[allow(clippy::too_many_arguments)] + fn backward( + &self, + op: Conv1DOp, + lhs: &Tensor, + grad_lhs: &mut Self::Vec, + rhs: &Tensor, + grad_rhs: &mut Self::Vec, + out: &impl Tensorlike, + grad_out: &Self::Vec, + ) -> Result<(), Self::Err>; +} + +/// Applies a 1d convolution to a tensor. +/// +/// [Const] dims **require nightly**: +/// ```ignore +/// #![feature(generic_const_exprs)] +/// # use dfdx::prelude::*; +/// # let dev: Cpu = Default::default(); +/// let x: Tensor, f32, _> = dev.sample_normal(); +/// let w: Tensor, f32, _> = dev.sample_normal(); +/// let y = (x, w).conv1d( +/// Const::<1>, // stride +/// Const::<0>, // padding +/// Const::<1>, // dilation +/// Const::<1>, // groups +/// ); +/// ``` +/// +/// [usize] dims can be used on stable: +/// ```rust +/// # use dfdx::prelude::*; +/// # let dev: Cpu = Default::default(); +/// let x: Tensor<_, f32, _> = dev.sample_normal_like(&( +/// 2, // batch size +/// 3, // input channels +/// 32, // length +/// )); +/// let w: Tensor<_, f32, _> = dev.sample_normal_like(&( +/// 6, // output channels +/// 3, // input channels +/// 3, // kernel size +/// )); +/// let y = (x, w).conv1d( +/// 1, // stride +/// 0, // padding +/// 1, // dilation +/// 1, // groups +/// ); +/// ``` +pub trait TryConv1D: Sized { + type Convolved; + type Error: std::fmt::Debug; + + /// Applies a 1D convolution to the input tensor. + fn conv1d( + self, + stride: Stride, + padding: Padding, + dilation: Dilation, + groups: Groups, + ) -> Self::Convolved { + self.try_conv1d(stride, padding, dilation, groups).unwrap() + } + + /// Fallibly applies a 1D convolution to the input tensor. 
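+    ///
+    /// The output length follows the usual convolution arithmetic (see the `Convolved` impls below):
+    /// `l_out = (l_in + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1`.
+    /// For example, `l_in = 32`, `kernel = 3`, `stride = 1`, `padding = 0`, `dilation = 1` gives `l_out = 30`.
+    ///
+    /// A minimal fallible-call sketch (with `x` and `w` shaped as in the `usize` example above):
+    /// ```ignore
+    /// let y = (x, w).try_conv1d(1, 0, 1, 1)?;
+    /// ```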
+ fn try_conv1d( + self, + stride: Stride, + padding: Padding, + dilation: Dilation, + groups: Groups, + ) -> Result; +} + +impl< + const KERNEL: usize, + const STRIDE: usize, + const PADDING: usize, + const DILATION: usize, + Groups: Dim, + const DIM: usize, + > TryConv1D, Const, Const, Groups> + for (Const, Const) +where + Const<{ (DIM + 2 * PADDING - DILATION * (KERNEL - 1) - 1) / STRIDE + 1 }>: Sized, +{ + type Convolved = Const<{ (DIM + 2 * PADDING - DILATION * (KERNEL - 1) - 1) / STRIDE + 1 }>; + type Error = std::convert::Infallible; + fn try_conv1d( + self, + _: Const, + _: Const, + _: Const, + _: Groups, + ) -> Result { + Ok(Const) + } +} + +impl + TryConv1D for (usize, Kernel) +{ + type Convolved = usize; + type Error = std::convert::Infallible; + fn try_conv1d( + self, + stride: Stride, + padding: Padding, + dilation: Dilation, + _: Groups, + ) -> Result { + let (dim, kernel) = self; + Ok((dim + 2 * padding.size() - 1) + .checked_sub(dilation.size() * (kernel.size() - 1)) + .unwrap() + / stride.size() + + 1) + } +} + +impl + TryConv1D + for ( + Tensor<(InpChan, L), E, D, T>, + Tensor<(OutChan, >::Output, Kernel), E, D>, + ) +where + InpChan: Dim, + OutChan: Dim, + Kernel: Dim, + Stride: Dim, + Padding: Dim, + Dilation: Dim, + Groups: Dim, + L: Dim, + E: Dtype, + D: Conv1DKernel + crate::tensor_ops::reshape_to::ReshapeKernel, + T: Tape, + InpChan: std::ops::Div, + >::Output: Dim, + (L, Kernel): TryConv1D, + <(L, Kernel) as TryConv1D>::Convolved: Dim, +{ + type Convolved = Tensor< + ( + OutChan, + <(L, Kernel) as TryConv1D>::Convolved, + ), + E, + D, + T, + >; + type Error = D::Err; + + fn try_conv1d( + self, + stride: Stride, + padding: Padding, + dilation: Dilation, + groups: Groups, + ) -> Result { + let (img, filters) = self; + let (inp_chan, l) = img.shape; + let img = img.try_reshape_like(&(Const::<1>, inp_chan, l))?; + let out = (img, filters).try_conv1d(stride, padding, dilation, groups)?; + let (_, out_chan, out_l) = out.shape; + out.try_reshape_like(&(out_chan, out_l)) + } +} + +impl + TryConv1D + for ( + Tensor<(Batch, InpChan, L), E, D, T>, + Tensor<(OutChan, >::Output, Kernel), E, D>, + ) +where + InpChan: Dim, + OutChan: Dim, + Kernel: Dim, + Stride: Dim, + Padding: Dim, + Dilation: Dim, + Groups: Dim, + Batch: Dim, + L: Dim, + E: Dtype, + D: Conv1DKernel, + T: Tape, + InpChan: std::ops::Div, + >::Output: Dim, + (L, Kernel): TryConv1D, + <(L, Kernel) as TryConv1D>::Convolved: Dim, +{ + type Convolved = Tensor< + ( + Batch, + OutChan, + <(L, Kernel) as TryConv1D>::Convolved, + ), + E, + D, + T, + >; + type Error = D::Err; + + fn try_conv1d( + self, + stride: Stride, + padding: Padding, + dilation: Dilation, + groups: Groups, + ) -> Result { + let (img, filters) = self; + assert_eq!(img.shape.1.size(), filters.shape.1.size() * groups.size()); + let (batch, _, l) = img.shape; + let (out_chan, inp_chan, kernel) = filters.shape; + assert!(out_chan.size() % groups.size() == 0); + if img.strides != img.shape.strides() || filters.strides != filters.shape.strides() { + panic!("Image & filter inputs to conv1d must be contiguous"); + } + let l_out = (l, kernel).conv1d(stride, padding, dilation, groups); + let op = Conv1DOp { + stride: stride.size(), + padding: padding.size(), + kernel: kernel.size(), + dilation: dilation.size(), + groups: groups.size(), + batch: batch.size(), + chan_in: inp_chan.size(), + chan_out: out_chan.size(), + l_in: l.size(), + l_out: l_out.size(), + }; + let (lhs, ltape) = img.split_tape(); + let (rhs, rtape) = filters.split_tape(); + let mut out = 
lhs.device.alloc((batch, out_chan, l_out))?; + let mut tape = ltape.merge(rtape); + lhs.device.forward(op, &lhs, &rhs, &mut out)?; + let lhs_ghost = lhs.ghost(); + let rhs_ghost = rhs.ghost(); + let out_ghost = out.ghost(); + tape.add_backward_op(move |grads| { + grads.try_alloc_for(&rhs_ghost)?; + grads.try_alloc_for(&lhs_ghost)?; + grads.try_alloc_for(&out_ghost)?; + let (grad_lhs, grad_rhs, grad_out) = + grads.muts_and_ref(&lhs_ghost, &rhs_ghost, &out_ghost); + lhs.device + .backward(op, &lhs, grad_lhs, &rhs, grad_rhs, &out_ghost, grad_out) + }); + Ok(out.put_tape(tape)) + } +} diff --git a/src/tensor_ops/conv1d/tests.rs b/src/tensor_ops/conv1d/tests.rs new file mode 100644 index 000000000..f5f4632dc --- /dev/null +++ b/src/tensor_ops/conv1d/tests.rs @@ -0,0 +1,294 @@ +use super::*; +use crate::{tensor_ops::*, tests::*}; + +#[test] +/// Produced by +/// ```python +/// q = torch.nn.Conv1d(1, 2, 2) +/// x = torch.sample_normal(1, 2, 3, requires_grad=True) +/// q(x).exp().mean().backward() +/// ``` +fn test_conv1d_default_stride_and_padding() { + let dev: TestDevice = Default::default(); + let weight = dev + .tensor([[[0.02272552, 0.5804308]], [[0.7020492, 0.05399668]]]) + .to_dtype::(); + let bias = dev.tensor([-0.3903233, 0.45088166]).to_dtype::(); + let x = dev + .tensor([[-1.6405383, 1.1242371]]) + .to_dtype::(); + let result = (x.leaky_trace(), weight.clone()) + .conv1d(Const::<1>, Const::<0>, Const::<1>, Const::<1>) + + bias.leaky_trace().broadcast::<_, _>(); + assert_close_to_literal!(result, [[0.2249364], [-0.6401519]]); + let g = result.exp().mean().backward(); + assert_close_to_literal!(g.get(&x), [[0.19929343, 0.37765408]]); + assert_close_to_literal!( + g.get(&weight), + [[[-1.0271764, 0.70390904]], [[-0.43245602, 0.2963558]]] + ); + assert_close_to_literal!(g.get(&bias), [0.6261215, 0.26360616]); +} + +#[test] +/// Produced by +/// ```python +/// q = torch.nn.Conv1d(1, 2, 2, stride=2) +/// x = torch.sample_normal(1, 2, 3, requires_grad=True) +/// q(x).exp().mean().backward() +/// ``` +fn test_conv1d_stride_2() { + let dev: TestDevice = Default::default(); + let weight = dev + .tensor([[[-0.4296614, 0.27693725]], [[-0.3809104, 0.19169092]]]) + .to_dtype::(); + let bias = dev + .tensor([-0.29623124, -0.09120554]) + .to_dtype::(); + let x = dev + .tensor([[-0.31544453, 0.47184715]]) + .to_dtype::(); + + let result = (x.leaky_trace(), weight.clone()) + .conv1d(Const::<2>, Const::<0>, Const::<1>, Const::<1>) + + bias.leaky_trace().broadcast::<_, _>(); + assert_close_to_literal!(result, [[-0.03002486], [0.11939937]]); + + let g = result.exp().mean().backward(); + + assert_close_to_literal!(g.get(&x), [[-0.42308503, 0.2423735]]); + + assert_close_to_literal!( + g.get(&weight), + [[[-0.15305707, 0.22894529]], [[-0.17772458, 0.26584336]]] + ); + + assert_close_to_literal!(g.get(&bias), [0.48521072, 0.5634099]); +} + +#[test] +fn test_conv1d_padding_1() { + let dev: TestDevice = Default::default(); + let weight = dev + .tensor([ + [[0.45220423, -0.3358205], [0.16038167, 0.09695387]], + [[0.19551754, -0.3192072], [-0.49848652, -0.49257886]], + [[0.21106702, 0.40513265], [0.08618081, -0.15866321]], + ]) + .to_dtype::(); + let bias = dev + .tensor([-0.01069266, 0.22007078, -0.4849882]) + .to_dtype::(); + let x = dev + .tensor([[0.10943512, -1.7794625], [1.1263468, 0.5267281]]) + .to_dtype::(); + + let result = (x.leaky_trace(), weight.clone()) + .conv1d(Const::<1>, Const::<1>, Const::<1>, Const::<1>) + + bias.leaky_trace().broadcast::<_, Axis<1>>(); + + assert_close_to_literal!( + result, 
+ [ + [0.06176047, 0.868088, -0.7308956], + [-0.36967635, -0.01143935, -0.3904122], + [-0.6193623, -1.1693113, -0.8151802] + ] + ); + + let g = result.exp().mean().backward(); + + assert_close_to_literal!( + g.get(&x), + [[0.10849564, -0.06070386], [-0.04517691, -0.05858667]] + ); + + assert_close_to_literal!( + g.get(&weight), + [ + [[-0.06622871, -0.45809978], [0.32632908, 0.27255058]], + [[-0.12179004, -0.1870675], [0.16333483, 0.1443328]], + [[-0.08372552, -0.05486213], [0.06477002, 0.08554335]] + ] + ); + + assert_close_to_literal!(g.get(&bias), [0.43639296, 0.26181796, 0.14349198]); +} + +#[test] +fn test_conv1d_stride_3_padding_4() { + let dev: TestDevice = Default::default(); + let weight = dev + .tensor([ + [[-0.4961109, -0.41855216, -0.31035745]], + [[-0.28658125, 0.09752917, -0.4264508]], + ]) + .to_dtype::(); + let bias = dev.tensor([0.04796273, 0.17420131]).to_dtype::(); + let x = dev + .tensor([[0.09930344, 1.0408987]]) + .to_dtype::(); + + let result = (x.leaky_trace(), weight.clone()) + .conv1d(Const::<3>, Const::<4>, Const::<1>, Const::<1>) + + bias.leaky_trace().broadcast::<_, _>(); + + assert_close_to_literal!( + result, + [ + [0.04796273, -0.31665158, 0.04796273], + [0.17420131, -0.26000577, 0.17420131] + ] + ); + + let g = result.exp().mean().backward(); + + assert_close_to_literal!(g.get(&x), [[-0.03829185, -0.09248922]]); + + assert_close_to_literal!( + g.get(&weight), + [ + [[0., 0.01205849, 0.12639713]], + [[0., 0.01276127, 0.13376366]] + ] + ); + + assert_close_to_literal!(g.get(&bias), [0.4711413, 0.5252729]); +} + +#[test] +fn test_conv1d_s4p3k2() { + let dev = TestDevice::seed_from_u64(432); + + let weight: Tensor, TestDtype, _> = dev.sample_normal(); + println!("weight data {:?}", weight.as_vec()); + let bias: Tensor, TestDtype, _> = dev.sample_normal(); + println!("bias data {:?}", bias.as_vec()); + let x: Tensor, TestDtype, _> = dev.sample_normal(); + println!("x data {:?}", x.as_vec()); + + let out = + (x.leaky_trace(), weight.clone()).conv1d(Const::<4>, Const::<3>, Const::<1>, Const::<1>); + let out = out + bias.broadcast::<_, Axis<1>>(); + println!("out data {:?}, {:?}", out.as_vec(), out.shape()); + + assert_close_to_literal!( + out, + [ + [0.44691145, 1.3863211, -2.0541177], + [0.1279889, -0.96598804, 1.6030374], + [-0.66274095, -1.2659106, -0.38152635], + ] + ); +} + +#[test] +fn test_batched_conv1d() { + let dev: TestDevice = Default::default(); + let x: Tensor, TestDtype, _> = dev.sample_normal(); + let w: Tensor, TestDtype, _> = dev.sample_normal(); + + let y: Tensor, _, _, _> = + (x.leaky_trace(), w.clone()).conv1d(Const::<3>, Const::<2>, Const::<1>, Const::<1>); + let y0 = y.retaped::(); + let grads0 = y.square().mean().backward(); + let x0 = grads0.get(&x); + let w0 = grads0.get(&w); + + let x = x + .broadcast::, _>() + .reshape::>(); + assert_eq!(x.strides, x.shape.strides()); + + let y: Tensor, _, _, _> = + (x.leaky_trace(), w.clone()).conv1d(Const::<3>, Const::<2>, Const::<1>, Const::<1>); + for i in 0..10 { + assert_close_to_tensor!(y0, y.retaped::().select(dev.tensor(i))); + } + + let grads = y.square().mean().backward(); + + assert_close_to_tensor!(w0, grads.get(&w), 1e-3); + + let x_grad = grads.get(&x) * 10.0; + for i in 0..10 { + assert_close_to_tensor!(x0, x_grad.clone().select(dev.tensor(i))); + } +} + +#[test] +fn test_conv1d_dilated() { + let dev: TestDevice = Default::default(); + let x = dev.tensor([[0., 1., 2., 4., 5.]]).to_dtype::(); + let w = dev.tensor([[[0.1, 0.5]]]).to_dtype::(); + let y = (x.leaky_trace(), 
w.clone()).conv1d(Const::<1>, Const::<0>, Const::<2>, Const::<1>); + assert_close_to_literal!(y, [[1.0, 2.1, 2.7]]); + let grads = y.mean().backward(); + assert_close_to_literal!( + grads.get(&x), + [[0.03333335, 0.03333335, 0.2, 0.1666667, 0.1666667],] + ); + assert_close_to_literal!(grads.get(&w), [[[1.0, 11.0 / 3.0]]]); +} + +#[test] +fn test_conv1d_grouped_forward() { + const NUM_GROUPS: usize = 3; + let dev: TestDevice = Default::default(); + let x: Tensor, TestDtype, _> = dev.sample_normal(); + let w: Tensor, TestDtype, _> = dev.sample_normal(); + + let y = (x.leaky_trace(), w.clone()).conv1d( + Const::<1>, + Const::<0>, + Const::<1>, + Const::, + ); + + for i in 0..NUM_GROUPS { + let x_group = x + .clone() + .slice((.., 3 * i..3 * (i + 1), ..)) + .realize::<(Const<2>, Const<3>, Const<14>)>(); + let w_group = w + .clone() + .slice((5 * i..5 * (i + 1), .., ..)) + .realize::<(Const<5>, Const<3>, Const<3>)>(); + let y_group = (x_group, w_group).conv1d(Const::<1>, Const::<0>, Const::<1>, Const::<1>); + let y_group_true = y + .retaped::() + .slice((.., 5 * i..5 * (i + 1), ..)) + .realize::<(Const<2>, Const<5>, Const<12>)>(); + assert_close_to_tensor!(y_group, y_group_true); + } + + let grads = y.exp().sum().backward(); + let x_grad = grads.get(&x); + let w_grad = grads.get(&w); + + for i in 0..NUM_GROUPS { + let x_group = x + .clone() + .slice((.., 3 * i..3 * (i + 1), ..)) + .realize::<(Const<2>, Const<3>, Const<14>)>(); + let w_group = w + .clone() + .slice((5 * i..5 * (i + 1), .., ..)) + .realize::<(Const<5>, Const<3>, Const<3>)>(); + let y_group = (x_group.leaky_trace(), w_group.clone()) + .conv1d(Const::<1>, Const::<0>, Const::<1>, Const::<1>); + let grads = y_group.exp().sum().backward(); + + let x_grad_group_true = x_grad + .clone() + .slice((.., 3 * i..3 * (i + 1), ..)) + .realize::<(Const<2>, Const<3>, Const<14>)>(); + let w_grad_group_true = w_grad + .clone() + .slice((5 * i..5 * (i + 1), .., ..)) + .realize::<(Const<5>, Const<3>, Const<3>)>(); + + assert_close_to_tensor!(grads.get(&x_group), x_grad_group_true); + assert_close_to_tensor!(grads.get(&w_group), w_grad_group_true); + } +} diff --git a/src/tensor_ops/conv2d/cpu_kernel.rs b/src/tensor_ops/conv2d/cpu_kernel.rs index 8342e77ec..5edf859b1 100644 --- a/src/tensor_ops/conv2d/cpu_kernel.rs +++ b/src/tensor_ops/conv2d/cpu_kernel.rs @@ -41,7 +41,7 @@ impl Conv2DOp { impl Cpu { #[inline] - fn fwd( + fn fwd_conv2d( &self, op: &Conv2DOp, img: &[E], @@ -96,7 +96,7 @@ impl Cpu { #[inline] #[allow(clippy::too_many_arguments)] - fn bwd( + fn bwd_conv2d( &self, op: &Conv2DOp, img: &[E], @@ -197,7 +197,7 @@ where let rhs = rhs.data.as_ref(); let out = Arc::make_mut(&mut out.data); for i_batch in 0..op.batch { - self.fwd( + self.fwd_conv2d( &op, &lhs[i_batch * lstride..], rhs, @@ -251,7 +251,7 @@ where let lhs = lhs.data.as_ref(); for i_batch in 0..op.batch { - self.bwd( + self.bwd_conv2d( &op, &lhs[i_batch * lstride..], &mut grad_lhs[i_batch * lstride..], diff --git a/src/tensor_ops/conv2d/cuda_kernel.rs b/src/tensor_ops/conv2d/cuda_kernel.rs index 4836ed65c..e44610b50 100644 --- a/src/tensor_ops/conv2d/cuda_kernel.rs +++ b/src/tensor_ops/conv2d/cuda_kernel.rs @@ -108,7 +108,7 @@ where // RHS (B, G, C/G*K*K, OH*OW) // OUT (B, G, O/G, OH*OW) let m = op.chan_out / op.groups; - let k = (op.chan_in / op.groups) * op.kernel * op.kernel; + let k = op.chan_in * op.kernel * op.kernel; let n = op.h_out * op.w_out; if op.groups == 1 { // optimizing here for common case @@ -129,7 +129,7 @@ where (op.groups, m, k, n), fil.data.as_ref(), [m * k, k, 
1], - &patches.slice(i_batch * op.groups * k * n..), + &patches.slice(i_batch * k * n..), [k * n, n, 1], Default::default(), &mut out_buf.slice_mut(i_batch * op.groups * m * n..), diff --git a/src/tensor_ops/matmul/cuda_kernel.rs b/src/tensor_ops/matmul/cuda_kernel.rs index b6787d848..253e89721 100644 --- a/src/tensor_ops/matmul/cuda_kernel.rs +++ b/src/tensor_ops/matmul/cuda_kernel.rs @@ -40,6 +40,17 @@ fn gemm_cfg( beta, ldc: out_stride as i32, }; + println!( + "TRUE! lda: {}, ldb {}, ldc: {}, transa: {:?}, transb: {:?} {}, {}, {}", + cfg.lda, + cfg.ldb, + cfg.ldc, + cfg.transa, + cfg.transb, + m.size(), + k.size(), + n.size(), + ); (cfg, true) } else { // out is stored in column major format @@ -55,6 +66,15 @@ fn gemm_cfg( beta, ldc: out_stride as i32, }; + println!( + "FALSE! lda: {}, ldb {}, ldc: {}, {}, {}, {}", + cfg.lda, + cfg.ldb, + cfg.ldc, + m.size(), + k.size(), + n.size(), + ); (cfg, false) } } diff --git a/src/tensor_ops/mod.rs b/src/tensor_ops/mod.rs index 6607a1ab1..ce6fc055e 100644 --- a/src/tensor_ops/mod.rs +++ b/src/tensor_ops/mod.rs @@ -277,6 +277,11 @@ pub use var_to::VarTo; pub(crate) use to_dtype::ToDtypeKernel; pub(crate) use upscale2d::Upscale2DKernel; +#[cfg(feature = "nightly")] +mod conv1d; +#[cfg(feature = "nightly")] +pub use conv1d::TryConv1D; + #[cfg(feature = "nightly")] mod conv2d; #[cfg(feature = "nightly")]