Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Breaking] Adding dilation/groups to ConvTranspose2D #783

Merged
merged 11 commits into from
Jul 5, 2023
99 changes: 80 additions & 19 deletions src/nn/convtrans.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,29 @@ pub mod builder {
const KERNEL_SIZE: usize,
const STRIDE: usize = 1,
const PADDING: usize = 0,
const DILATION: usize = 1,
const GROUPS: usize = 1,
>;
}

impl<const I: usize, const O: usize, const K: usize, const S: usize, const P: usize, E, D>
BuildOnDevice<D, E> for builder::ConvTrans2D<I, O, K, S, P>
impl<
const I: usize,
const O: usize,
const K: usize,
const S: usize,
const P: usize,
const L: usize,
const G: usize,
E,
D,
> BuildOnDevice<D, E> for builder::ConvTrans2D<I, O, K, S, P, L, G>
where
E: Dtype,
D: Device<E>,
ConvTrans2D<I, O, K, S, P, E, D>: BuildModule<D, E>,
Const<{ O / G }>: Sized,
ConvTrans2D<I, O, K, S, P, L, G, E, D>: BuildModule<D, E>,
{
type Built = ConvTrans2D<I, O, K, S, P, E, D>;
type Built = ConvTrans2D<I, O, K, S, P, L, G, E, D>;
fn try_build_on_device(device: &D) -> Result<Self::Built, <D>::Err> {
Self::Built::try_build(device)
}
Expand All @@ -45,26 +57,43 @@ where
/// - `KERNEL_SIZE`: The size of the kernel applied to both width and height of the images.
/// - `STRIDE`: How far to move the kernel each step. Defaults to `1`
/// - `PADDING`: How much zero padding to add around the images. Defaults to `0`.
/// - `DILATION`: Controls the spacing between kernel points. Defaults to `1`.
/// - `GROUPS`: Controls the connections between inputs and outputs.
/// `IN_CHAN` and `OUT_CHAN` must both be divisible by `GROUPS`. For example,
#[derive(Debug, Clone)]
pub struct ConvTrans2D<
const IN_CHAN: usize,
const OUT_CHAN: usize,
const KERNEL_SIZE: usize,
const STRIDE: usize,
const PADDING: usize,
const DILATION: usize,
const GROUPS: usize,
E: Dtype,
D: Storage<E>,
> {
pub weight: Tensor<Rank4<OUT_CHAN, IN_CHAN, KERNEL_SIZE, KERNEL_SIZE>, E, D>,
> where
Const<{ OUT_CHAN / GROUPS }>: Sized,
{
pub weight: Tensor<Rank4<IN_CHAN, { OUT_CHAN / GROUPS }, KERNEL_SIZE, KERNEL_SIZE>, E, D>,
}

impl<const I: usize, const O: usize, const K: usize, const S: usize, const P: usize, E, D>
TensorCollection<E, D> for ConvTrans2D<I, O, K, S, P, E, D>
impl<
const I: usize,
const O: usize,
const K: usize,
const S: usize,
const P: usize,
const L: usize,
const G: usize,
E,
D,
> TensorCollection<E, D> for ConvTrans2D<I, O, K, S, P, L, G, E, D>
where
E: Dtype + Float + SampleUniform,
D: Device<E>,
Const<{ O / G }>: Sized,
{
type To<E2: Dtype, D2: Device<E2>> = ConvTrans2D<I, O, K, S, P, E2, D2>;
type To<E2: Dtype, D2: Device<E2>> = ConvTrans2D<I, O, K, S, P, L, G, E2, D2>;

fn iter_tensors<V: ModuleVisitor<Self, E, D>>(
visitor: &mut V,
Expand All @@ -85,26 +114,58 @@ where
}

#[cfg(feature = "nightly")]
impl<const C: usize, const O: usize, const K: usize, const S: usize, const P: usize, E, D, Img>
Module<Img> for ConvTrans2D<C, O, K, S, P, E, D>
impl<
const C: usize,
const O: usize,
const K: usize,
const S: usize,
const P: usize,
const L: usize,
const G: usize,
E,
D,
Img,
> Module<Img> for ConvTrans2D<C, O, K, S, P, L, G, E, D>
where
E: Dtype,
D: Device<E>,
Img: TryConvTrans2DTo<Tensor<Rank4<O, C, K, K>, E, D>, S, P> + HasErr<Err = D::Err>,
Const<{ O / G }>: Sized,
(Img, Tensor<Rank4<C, { O / G }, K, K>, E, D>):
TryConvTrans2D<Const<S>, Const<P>, Const<L>, Const<G>>,
{
type Output = Img::Output;
type Error = D::Err;
type Output = <(Img, Tensor<Rank4<C, { O / G }, K, K>, E, D>) as TryConvTrans2D<
Const<S>,
Const<P>,
Const<L>,
Const<G>,
>>::Convolved;
type Error = <(Img, Tensor<Rank4<C, { O / G }, K, K>, E, D>) as TryConvTrans2D<
Const<S>,
Const<P>,
Const<L>,
Const<G>,
>>::Error;

fn try_forward(&self, x: Img) -> Result<Self::Output, D::Err> {
x.try_convtrans2d_to(self.weight.clone())
fn try_forward(&self, x: Img) -> Result<Self::Output, Self::Error> {
(x, self.weight.clone()).try_convtrans2d(Const, Const, Const, Const)
}
}

impl<const I: usize, const O: usize, const K: usize, const S: usize, const P: usize, E, D>
NonMutableModule for ConvTrans2D<I, O, K, S, P, E, D>
impl<
const I: usize,
const O: usize,
const K: usize,
const S: usize,
const P: usize,
const L: usize,
const G: usize,
E,
D,
> NonMutableModule for ConvTrans2D<I, O, K, S, P, L, G, E, D>
where
E: Dtype,
D: Storage<E>,
Const<{ O / G }>: Sized,
{
}

Expand Down Expand Up @@ -187,7 +248,7 @@ mod tests {

assert_ne!(
g.get(&m.weight).array(),
[[[[TestDtype::zero(); 3]; 3]; 2]; 4]
[[[[TestDtype::zero(); 3]; 3]; 4]; 2]
);

opt.update(&mut m, &g).expect("unused params");
Expand Down
1 change: 1 addition & 0 deletions src/nn/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ mod batchnorm2d;
mod bias2d;
#[cfg(feature = "nightly")]
mod conv;
#[cfg(feature = "nightly")]
mod convtrans;
mod dropout;
mod ema;
Expand Down
46 changes: 24 additions & 22 deletions src/tensor_ops/conv2d/conv2d.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,29 +22,29 @@ __device__ void unfold_input_into_patches(
const size_t *strides, // 4d image strides
T *patches // 6d (Batch, Groups * Channels, KernelSize, KernelSize, HeightOut, WidthOut)
) {
const size_t n = op.batch * op.groups * op.chan_in * op.h_out * op.w_out;
const size_t n = op.batch * op.chan_in * op.h_out * op.w_out;
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) {
unsigned int idx = i;
const size_t ow = idx % op.w_out;
idx /= op.w_out;
const size_t oh = idx % op.h_out;
idx /= op.h_out;
const size_t c = idx % (op.chan_in * op.groups);
idx /= (op.chan_in * op.groups);
const size_t c = idx % op.chan_in;
idx /= op.chan_in;
const size_t b = idx % op.batch;

const T *image_i = image + b * strides[0] + c * strides[1];
T *patches_i = patches + oh * op.w_out + ow;
patches_i += c * (op.kernel * op.kernel * op.h_out * op.w_out);
patches_i += b * (op.groups * op.chan_in * op.kernel * op.kernel * op.h_out * op.w_out);
patches_i += b * (op.chan_in * op.kernel * op.kernel * op.h_out * op.w_out);

T zero = 0.0;

for (int k1 = 0;k1 < op.kernel;k1++) {
const size_t y = oh * op.stride + op.dilation * k1 - op.padding;
for (int k2 = 0;k2 < op.kernel;k2++) {
const size_t x = ow * op.stride + op.dilation * k2 - op.padding;
*patches_i = (y >= op.h_in || x >= op.w_in) ? zero : image[y * strides[2] + x * strides[3]];
*patches_i = (y >= op.h_in || x >= op.w_in) ? zero : image_i[y * strides[2] + x * strides[3]];
patches_i += op.h_out * op.w_out;
}
}
Expand Down Expand Up @@ -86,7 +86,7 @@ __device__ void unfold_output_into_patches(
const size_t ow = ow_s / op.stride;

const bool invalid = k1_invalid || (ow_ks < op.dilation * k2 || ow_s % op.stride != 0 || ow >= op.w_out);
*patches_i = invalid ? zero : image_out[oh * op.w_out + ow];
*patches_i = invalid ? zero : image_i[oh * op.w_out + ow];
patches_i += op.h_in * op.w_in;
}
}
Expand All @@ -96,64 +96,66 @@ __device__ void unfold_output_into_patches(
template<typename T>
__device__ void transpose_filters(
const Conv2DOp op,
const T *filters, // 4d (ChanOut, ChanIn, KernelSize, KernelSize)
const T *filters, // 4d (ChanOut, ChanIn/Groups, KernelSize, KernelSize)
const size_t *strides, // 4d filters strides
T *filters_tr // 5d (Groups, ChanIn, ChanOut/Groups, KernelSize, KernelSize)
T *filters_tr // 5d (Groups, ChanIn/Groups, ChanOut/Groups, KernelSize, KernelSize)
) {
const size_t n = op.chan_in * op.chan_out * op.kernel * op.kernel;
const size_t c_per_g = op.chan_in / op.groups;
const size_t o_per_g = op.chan_out / op.groups;
const size_t n = c_per_g * op.chan_out * op.kernel * op.kernel;

for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) {
unsigned int idx = i;
const size_t k2 = idx % op.kernel;
idx /= op.kernel;
const size_t k1 = idx % op.kernel;
idx /= op.kernel;
const size_t c = idx % op.chan_in;
idx /= op.chan_in;
const size_t cg = idx % c_per_g;
idx /= c_per_g;
const size_t o = idx % op.chan_out;
const size_t og = o % o_per_g;
const size_t g = o / o_per_g;

auto i_no = o * strides[0] + c * strides[1] + k1 * strides[2] + k2 * strides[3];
auto i_no = o * strides[0] + cg * strides[1] + k1 * strides[2] + k2 * strides[3];
T *filters_tr_i = filters_tr + k2;
filters_tr_i += k1 * op.kernel;
filters_tr_i += og * (op.kernel * op.kernel);
filters_tr_i += c * (o_per_g * op.kernel * op.kernel);
filters_tr_i += g * (op.chan_in * o_per_g * op.kernel * op.kernel);
filters_tr_i += cg * (o_per_g * op.kernel * op.kernel);
filters_tr_i += g * (c_per_g * o_per_g * op.kernel * op.kernel);
*filters_tr_i = filters[i_no];
}
}

template<typename T>
__device__ void sum_transposed_filters(
const Conv2DOp op,
const T *filters_tr, // 6d (Batch, Groups, ChanIn, ChanOut/Groups, KernelSize, KernelSize)
T *filters, // 4d (ChanOut, ChanIn, KernelSize, KernelSize)
const T *filters_tr, // 6d (Batch, Groups, ChanIn/Groups, ChanOut/Groups, KernelSize, KernelSize)
T *filters, // 4d (ChanOut, ChanIn/Groups, KernelSize, KernelSize)
const size_t *strides // 4d filter strides
) {
const size_t n = op.chan_out * op.chan_in * op.kernel * op.kernel;
const size_t o_per_g = op.chan_out / op.groups;
const size_t c_per_g = op.chan_in / op.groups;
const size_t n = op.chan_out * c_per_g * op.kernel * op.kernel;

for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) {
unsigned int idx = i;
const size_t k2 = idx % op.kernel;
idx /= op.kernel;
const size_t k1 = idx % op.kernel;
idx /= op.kernel;
const size_t c = idx % op.chan_in;
idx /= op.chan_in;
const size_t cg = idx % c_per_g;
idx /= c_per_g;
const size_t o = idx % op.chan_out;
const size_t og = o % o_per_g;
const size_t g = o / o_per_g;

auto i_no = o * strides[0] + c * strides[1] + k1 * strides[2] + k2 * strides[3];
auto i_no = o * strides[0] + cg * strides[1] + k1 * strides[2] + k2 * strides[3];

const T *filters_tr_i = filters_tr + k2;
filters_tr_i += k1 * op.kernel;
filters_tr_i += og * (op.kernel * op.kernel);
filters_tr_i += c * (o_per_g * op.kernel * op.kernel);
filters_tr_i += g * (op.chan_in * o_per_g * op.kernel * op.kernel);
filters_tr_i += cg * (o_per_g * op.kernel * op.kernel);
filters_tr_i += g * (c_per_g * o_per_g * op.kernel * op.kernel);

T tmp = 0.0;
for (int b = 0; b < op.batch; b++) {
Expand Down
38 changes: 17 additions & 21 deletions src/tensor_ops/conv2d/cpu_kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ impl Cpu {
{
{
let mut i = 0;
for c in 0..(op.groups * op.chan_in) {
for c in 0..op.chan_in {
for k1 in 0..op.kernel {
for k2 in 0..op.kernel {
for oh in 0..op.h_out {
Expand All @@ -73,9 +73,11 @@ impl Cpu {
}
}

// (G, O / G, C * K * K) * (G, C * K * K, OH * OW) = (G, O / G, OH * OW)
// filters: (G, O/G, C/G*K*K)
// buf: (G, C/G*K*K, OH*OW)
// output: (G, O/G, OH*OW)
let m = op.chan_out / op.groups;
let k = op.chan_in * op.kernel * op.kernel;
let k = (op.chan_in / op.groups) * op.kernel * op.kernel;
let n = op.w_out * op.h_out;
for g in 0..op.groups {
Self::matmul(
Expand Down Expand Up @@ -128,8 +130,8 @@ impl Cpu {

{
// img_g += filters^T * unfold(grad_out)
// (G, C, H * W) += (G, C, O/G * K * K) * (G, O/G * K * K, H * W)
let m = op.chan_in;
// (G, C/G, H * W) += (G, C/G, O/G * K * K) * (G, O/G * K * K, H * W)
let m = op.chan_in / op.groups;
let k = (op.chan_out / op.groups) * op.kernel * op.kernel;
let n = op.h_in * op.w_in;
for g in 0..op.groups {
Expand All @@ -148,8 +150,8 @@ impl Cpu {

{
// weight_g^T += img * unfold(patches)^T
// (G, C, O/G * K * K) += (G, C, H * W) * (G, H * W, O/G * K * K)
let m = op.chan_in;
// (G, C/G, O/G * K * K) += (G, C/G, H * W) * (G, H * W, O/G * K * K)
let m = op.chan_in / op.groups;
let k = op.h_in * op.w_in;
let n = (op.chan_out / op.groups) * op.kernel * op.kernel;
for g in 0..op.groups {
Expand Down Expand Up @@ -184,13 +186,7 @@ where
rhs: &Tensor<R, E, Self>,
out: &mut Tensor<O, E, Self>,
) -> Result<(), Self::Err> {
let patches = (
op.groups * op.chan_in,
op.kernel,
op.kernel,
op.h_out,
op.w_out,
);
let patches = (op.chan_in, op.kernel, op.kernel, op.h_out, op.w_out);
let mut patches = self.try_alloc_zeros::<E>(patches.num_elements())?;
let [lstride, ostride] = match L::NUM_DIMS {
3 => [0; 2],
Expand Down Expand Up @@ -224,7 +220,7 @@ where
) -> Result<(), Self::Err> {
let f_tr_shape = [
op.groups,
op.chan_in,
op.chan_in / op.groups,
op.chan_out / op.groups,
op.kernel,
op.kernel,
Expand All @@ -238,9 +234,9 @@ where
// transpose filters in f1023
let buf = rhs.data.as_ref();
let mut f_idx = NdIndex::new(f_tr_shape, f_tr_shape.strides());
while let Some((i, [g, c, o, k1, k2])) = f_idx.next_with_idx() {
let idx = (g * (op.chan_out / op.groups) + o) * rhs.strides[0]
+ c * rhs.strides[1]
while let Some((i, [g, c_over_g, o_over_g, k1, k2])) = f_idx.next_with_idx() {
let idx = (g * (op.chan_out / op.groups) + o_over_g) * rhs.strides[0]
+ c_over_g * rhs.strides[1]
+ k1 * rhs.strides[2]
+ k2 * rhs.strides[3];
f1023[i] = buf[idx];
Expand Down Expand Up @@ -269,9 +265,9 @@ where
{
// untranspose filters
let mut f_idx = NdIndex::new(f_tr_shape, f_tr_shape.strides());
while let Some((i, [g, c, o, k1, k2])) = f_idx.next_with_idx() {
let idx = (g * (op.chan_out / op.groups) + o) * rhs.strides[0]
+ c * rhs.strides[1]
while let Some((i, [g, c_over_g, o_over_g, k1, k2])) = f_idx.next_with_idx() {
let idx = (g * (op.chan_out / op.groups) + o_over_g) * rhs.strides[0]
+ c_over_g * rhs.strides[1]
+ k1 * rhs.strides[2]
+ k2 * rhs.strides[3];
grad_rhs[idx] += grad_f1023[i];
Expand Down
Loading
Loading