Updating cuda kernel for conv transpose
coreylowman committed May 12, 2023
1 parent 3958d6e commit 71fcada
Showing 2 changed files with 132 additions and 57 deletions.
55 changes: 35 additions & 20 deletions src/tensor_ops/convtrans2d/convtrans2d.cu
@@ -1,9 +1,11 @@
#include "cuda_fp16.h"

struct Conv2DOp {
size_t kernel;
size_t stride;
size_t padding;
size_t kernel;
size_t dilation;
size_t groups;
size_t batch;
size_t chan_in;
size_t chan_out;
@@ -16,12 +18,12 @@ struct Conv2DOp {
template<typename T>
__device__ void unfold_input_into_patches(
const Conv2DOp op,
const T *image, // 4d (Batch, Channels, Height, Width)
const T *image, // 4d (Batch, Groups * Channels, Height, Width)
const size_t *strides, // 4d image strides
T *patches // 6d (Batch, Channels, KernelSize, KernelSize, HeightOut, WidthOut)
T *patches // 6d (Batch, Groups * Channels, KernelSize, KernelSize, HeightOut, WidthOut)
) {
unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i >= op.batch * op.chan_in * op.h_out * op.w_out) {
if (i >= op.batch * op.groups * op.chan_in * op.h_out * op.w_out) {
return;
}

@@ -30,8 +32,8 @@ __device__ void unfold_input_into_patches(
idx /= op.w_out;
const size_t oh = idx % op.h_out;
idx /= op.h_out;
const size_t c = idx % op.chan_in;
idx /= op.chan_in;
const size_t c = idx % (op.chan_in * op.groups);
idx /= (op.chan_in * op.groups);
const size_t b = idx % op.batch;
idx /= op.batch;

@@ -44,15 +46,15 @@

for (int k1 = 0;k1 < op.kernel;k1++) {
const size_t y_ks = oh + op.padding;
const size_t y_s = y_ks - k1;
const size_t y_s = y_ks - op.dilation * k1;
const size_t y = y_s / op.stride;
const bool k1_invalid = (y_ks < k1 || y_s % op.stride != 0 || y >= op.h_in);
const bool k1_invalid = (y_ks < op.dilation * k1 || y_s % op.stride != 0 || y >= op.h_in);
for (int k2 = 0;k2 < op.kernel;k2++) {
const size_t x_ks = ow + op.padding;
const size_t x_s = x_ks - k2;
const size_t x_s = x_ks - op.dilation * k2;
const size_t x = x_s / op.stride;

const bool invalid = k1_invalid || (x_ks < k2 || x_s % op.stride != 0 || x >= op.w_in);
const bool invalid = k1_invalid || (x_ks < op.dilation * k2 || x_s % op.stride != 0 || x >= op.w_in);
*patches = invalid ? zero : image[y * strides[2] + x * strides[3]];
patches += op.h_out * op.w_out;
}
@@ -87,9 +89,9 @@ __device__ void unfold_output_into_patches(
T zero = 0.0;

for (int k1 = 0;k1 < op.kernel;k1++) {
const size_t oh = y * op.stride + k1 - op.padding;
const size_t oh = y * op.stride + op.dilation * k1 - op.padding;
for (int k2 = 0;k2 < op.kernel;k2++) {
const size_t ow = x * op.stride + k2 - op.padding;
const size_t ow = x * op.stride + op.dilation * k2 - op.padding;
*patches = (oh >= op.h_out || ow >= op.w_out) ? zero : image_out[oh * op.w_out + ow];
patches += op.h_in * op.w_in;
}
@@ -101,7 +103,7 @@ __device__ void transpose_filters(
const Conv2DOp op,
const T *filters, // 4d (ChanOut, ChanIn, KernelSize, KernelSize)
const size_t *strides, // 4d filters strides
T *filters_tr // 4d (ChanIn, ChanOut, KernelSize, KernelSize)
T *filters_tr // 5d (Groups, ChanIn, ChanOut/Groups, KernelSize, KernelSize)
) {
unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i >= op.chan_in * op.chan_out * op.kernel * op.kernel) {
@@ -113,19 +115,25 @@ __device__ void transpose_filters(
idx /= op.kernel;
const size_t k1 = idx % op.kernel;
idx /= op.kernel;
const size_t o = idx % op.chan_out;
idx /= op.chan_out;
const size_t c = idx % op.chan_in;
idx /= op.chan_in;
const size_t o = idx % op.chan_out;
const size_t og = o % o_per_g;
const size_t g = o / o_per_g;

auto i_no = o * strides[0] + c * strides[1] + k1 * strides[2] + k2 * strides[3];

filters_tr[i] = filters[i_no];
filters_tr += k2;
filters_tr += k1 * op.kernel;
filters_tr += og * (op.kernel * op.kernel);
filters_tr += c * (o_per_g * op.kernel * op.kernel);
filters_tr += g * (op.chan_in * o_per_g * op.kernel * op.kernel);
*filters_tr = filters[i_no];
}

template<typename T>
__device__ void sum_transposed_filters(
const Conv2DOp op,
const T *filters_tr, // 5d (Batch, ChanIn, ChanOut, KernelSize, KernelSize)
const T *filters_tr, // 6d (Batch, Groups, ChanIn, ChanOut/Groups, KernelSize, KernelSize)
T *filters, // 4d (ChanOut, ChanIn, KernelSize, KernelSize)
const size_t *strides // 4d filter strides
) {
@@ -135,6 +143,8 @@ __device__ void sum_transposed_filters(
return;
}

const size_t o_per_g = op.chan_out / op.groups;

unsigned int idx = i;
const size_t k2 = idx % op.kernel;
idx /= op.kernel;
@@ -143,12 +153,17 @@ __device__ void sum_transposed_filters(
const size_t c = idx % op.chan_in;
idx /= op.chan_in;
const size_t o = idx % op.chan_out;
idx /= op.chan_out;
const size_t og = o % o_per_g;
const size_t g = o / o_per_g;

auto i_tr = c * (op.chan_out * op.kernel * op.kernel) + o * (op.kernel * op.kernel) + k1 * (op.kernel) + k2;
auto i_no = o * strides[0] + c * strides[1] + k1 * strides[2] + k2 * strides[3];

filters_tr += i_tr;
filters_tr += k2;
filters_tr += k1 * op.kernel;
filters_tr += og * (op.kernel * op.kernel);
filters_tr += c * (o_per_g * op.kernel * op.kernel);
filters_tr += g * (op.chan_in * o_per_g * op.kernel * op.kernel);

T tmp = 0.0;
for (int b = 0; b < op.batch; b++) {
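The chain of `filters_tr += ...` offsets in `transpose_filters` and `sum_transposed_filters` above encodes a 5d layout of (Groups, ChanIn, ChanOut/Groups, KernelSize, KernelSize). The Rust sketch below mirrors that pointer arithmetic on the CPU; it is illustrative only, and the function name `transposed_filter_index` is hypothetical rather than part of this commit.

// Illustrative sketch of the transposed-filter indexing used in the kernel above.
// Layout assumed: (Groups, ChanIn, ChanOut/Groups, KernelSize, KernelSize),
// matching the cumulative offsets applied to `filters_tr`.
fn transposed_filter_index(
    o: usize,        // output channel in the original (ChanOut, ChanIn, K, K) filter
    c: usize,        // input channel
    k1: usize,       // kernel row
    k2: usize,       // kernel column
    chan_in: usize,
    chan_out: usize,
    groups: usize,
    kernel: usize,
) -> usize {
    let o_per_g = chan_out / groups; // output channels per group
    let og = o % o_per_g;            // position of `o` within its group
    let g = o / o_per_g;             // which group `o` belongs to
    // k2 + k1*K + og*K*K + c*(O/G)*K*K + g*C*(O/G)*K*K, written in Horner form
    k2 + kernel * (k1 + kernel * (og + o_per_g * (c + chan_in * g)))
}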
134 changes: 97 additions & 37 deletions src/tensor_ops/convtrans2d/cuda_kernel.rs
@@ -61,6 +61,11 @@ where
Self: HasCudaKernel<E>,
CudaBlas: Gemm<E>,
{
fn alloc<S: Shape>(&self, shape: S) -> Result<Tensor<S, E, Self>, Self::Err> {
let data = unsafe { self.alloc_empty::<E>(shape.num_elements()) }?;
Ok(self.build_tensor(shape, shape.strides(), data))
}

fn forward<L: Shape, R: Shape, O: Shape>(
&self,
op: super::ConvTrans2DOp,
@@ -72,32 +77,54 @@ where
self.dev.load_ptx(PTX_SRC.into(), Self::MOD, Self::FNS)?;
}

let patches_numel = op.batch * op.chan_in * op.kernel * op.kernel * op.h_out * op.w_out;
let patches_numel =
op.batch * op.groups * op.chan_in * op.kernel * op.kernel * op.h_out * op.w_out;
let mut patches = unsafe { self.get_workspace::<E>(patches_numel) }?;
let mut patches = unsafe { patches.transmute_mut::<E>(patches_numel).unwrap() };

let img_strides = self.dev.htod_copy(make_4d::<L>(lhs.strides).into())?;
let unfold_fn = self.dev.get_func(Self::MOD, Self::FNS[0]).unwrap();
let cfg = launch_cfg::<128>((op.batch * op.chan_in * op.h_out * op.w_out) as u32);
let cfg =
launch_cfg::<128>((op.batch * op.groups * op.chan_in * op.h_out * op.w_out) as u32);
let params = (op, lhs.data.as_ref(), &img_strides, &mut patches);
unsafe { unfold_fn.launch(cfg, params) }?;

// (O, C * K * K) * (B, C * K * K, OH * OW) = (B, O, OH * OW)
let m = op.chan_out;
let out_buf = Arc::get_mut(&mut out.data).unwrap();

// LHS (G, O/G, C*K*K)
// RHS (B, G, C*K*K, OH*OW)
// OUT (B, G, O/G, OH*OW)
let m = op.chan_out / op.groups;
let k = op.chan_in * op.kernel * op.kernel;
let n = op.h_out * op.w_out;
unsafe {
self.gemm_batch(
(op.batch, m, k, n),
rhs.data.as_ref(),
[0, k, 1],
&patches,
[k * n, n, 1],
Default::default(),
Arc::get_mut(&mut out.data).unwrap(),
[m * n, n, 1],
)
.unwrap();
if op.groups == 1 {
self.gemm_batch(
(op.batch, m, k, n),
rhs.data.as_ref(),
[0, k, 1],
&patches,
[k * n, n, 1],
Default::default(),
out_buf,
[m * n, n, 1],
)
.unwrap();
} else {
for i_batch in 0..op.batch {
self.gemm_batch(
(op.groups, m, k, n),
rhs.data.as_ref(),
[m * k, k, 1],
&patches.slice(i_batch * op.groups * k * n..),
[k * n, n, 1],
Default::default(),
&mut out_buf.slice_mut(i_batch * op.groups * m * n..),
[m * n, n, 1],
)
.unwrap();
}
}
}

Ok(())
@@ -119,8 +146,8 @@ where
let mut patches = unsafe { self.get_workspace::<E>(patches_numel) }?;
let mut patches = unsafe { patches.transmute_mut::<E>(patches_numel).unwrap() };

let mut f_b1023 = unsafe { self.alloc_empty::<E>(filters_numel) }?;
let mut grad_f_b1023 = unsafe { self.alloc_empty::<E>(op.batch * filters_numel) }?;
let mut ftr = unsafe { self.alloc_empty::<E>(filters_numel) }?;
let mut grad_ftr = unsafe { self.alloc_empty::<E>(op.batch * filters_numel) }?;
let f_strides = self.dev.htod_copy(rhs.strides.into())?;

self.par_stream.wait_for_default()?;
@@ -132,31 +159,32 @@
unsafe { unfold_fn.launch(cfg, (op, grad_out, &mut patches)) }?;
}

{
unsafe {
// prepare filters for backward operations by
// swapping dims 0 and 1 and adding a batch dimension
let tr_fn = self.dev.get_func(Self::MOD, Self::FNS[2]).unwrap();
let cfg = launch_cfg::<128>(rhs.shape.num_elements() as u32);
unsafe {
tr_fn.launch_on_stream(
self.par_stream.as_ref(),
cfg,
(op, rhs.data.as_ref(), &f_strides, &mut f_b1023),
)
}?;
tr_fn.launch_on_stream(
self.par_stream.as_ref(),
cfg,
(op, rhs.data.as_ref(), &f_strides, &mut ftr),
)?;

self.par_stream.wait_for_default()?;

// img_g += filters * patches
// (B, C, H * W) += (B, C, O * K * K) * (B, O * K * K, H * W)
// LHS = (G, C, O/G*K*K)
// RHS = (B, G, O/G*K*K, H*W)
// OUT = (B, G, C, H*W)
let m = op.chan_in;
let k = op.chan_out * op.kernel * op.kernel;
let k = (op.chan_out / op.groups) * op.kernel * op.kernel;
let n = op.h_in * op.w_in;
unsafe {
self.blas.set_stream(Some(self.par_stream.as_ref()))?;
self.blas.set_stream(Some(self.par_stream.as_ref()))?;
if op.groups == 1 {
// optimizing here for common case
self.gemm_batch(
(op.batch, m, k, n),
&f_b1023,
&ftr,
[0, k, 1],
&patches,
[k * n, n, 1],
@@ -165,35 +193,67 @@
[m * n, n, 1],
)
.unwrap();
self.blas.set_stream(None)?;
} else {
for i_batch in 0..op.batch {
self.gemm_batch(
(op.groups, m, k, n),
&ftr,
[m * k, k, 1],
&patches.slice(i_batch * op.groups * k * n..),
[k * n, n, 1],
<E>::ONE,
&mut grad_lhs.slice_mut(i_batch * op.groups * m * n..),
[m * n, n, 1],
)
.unwrap();
}
}
self.blas.set_stream(None)?;
}

{
unsafe {
// weight_g += img * patches^T
// (B, C, O * K * K) += (B, C, H * W) * (B, H * W, O * K * K)
// LHS = (B, G, C, H*W)
// RHS = (B, H*W, G, O/G*K*K)
// OUT = (B, G, C, O/G*K*K)
let m = op.chan_in;
let k = op.h_in * op.w_in;
let n = op.chan_out * op.kernel * op.kernel;
unsafe {
let n = (op.chan_out / op.groups) * op.kernel * op.kernel;
if op.groups == 1 {
// optimizing here for common case
self.gemm_batch(
(op.batch, m, k, n),
lhs.data.as_ref(),
[m * k, k, 1],
&patches,
[k * n, 1, k],
Default::default(),
&mut grad_f_b1023,
&mut grad_ftr,
[m * n, n, 1],
)
.unwrap();
} else {
let lhs_buf = lhs.data.as_ref();
for i_batch in 0..op.batch {
self.gemm_batch(
(op.groups, m, k, n),
&lhs_buf.slice(i_batch * op.groups * m * k..),
[m * k, k, 1],
&patches.slice(i_batch * op.groups * k * n..),
[k * n, 1, k],
Default::default(),
&mut grad_ftr.slice_mut(i_batch * op.groups * m * n..),
[m * n, n, 1],
)
.unwrap();
}
}

// sum all the gradients collected in our broadcasted grad_f
// into grad_rhs
let sum_fn = self.dev.get_func(Self::MOD, Self::FNS[3]).unwrap();
let cfg = launch_cfg::<128>(rhs.shape.num_elements() as u32);
unsafe { sum_fn.launch(cfg, (op, &grad_f_b1023, grad_rhs, &f_strides)) }?;
sum_fn.launch(cfg, (op, &grad_ftr, grad_rhs, &f_strides))?;
}

self.dev.wait_for(self.par_stream.as_ref())?;

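For `groups > 1`, the forward pass and both backward GEMMs above fall back to a per-batch loop, launching one batched GEMM over the groups with lhs strides [m * k, k, 1], rhs strides [k * n, n, 1], and output strides [m * n, n, 1]. The naive CPU reference below spells out the same shape bookkeeping for the forward case; it is a sketch for illustration only (the function name `grouped_gemm_reference` is made up here, and the real code calls the cuBLAS-backed `gemm_batch` shown in the diff). Note that the gradient-w.r.t.-input path uses the same loop but accumulates into `grad_lhs` with a beta of `<E>::ONE` instead of overwriting.

// CPU reference for the grouped GEMM layout used in the forward pass:
//   filters  (G, O/G, C*K*K)      -> lhs, shared across batches
//   patches  (B, G, C*K*K, OH*OW) -> rhs
//   output   (B, G, O/G, OH*OW)   -> out
// with m = O/G, k = C*K*K, n = OH*OW.
fn grouped_gemm_reference(
    groups: usize, m: usize, k: usize, n: usize, batch: usize,
    filters: &[f32],  // len = groups * m * k
    patches: &[f32],  // len = batch * groups * k * n
    out: &mut [f32],  // len = batch * groups * m * n
) {
    for b in 0..batch {
        for g in 0..groups {
            // per-group views, following strides [m*k, k, 1], [k*n, n, 1], [m*n, n, 1]
            let lhs = &filters[g * m * k..][..m * k];
            let rhs = &patches[(b * groups + g) * k * n..][..k * n];
            let dst = &mut out[(b * groups + g) * m * n..][..m * n];
            for i in 0..m {
                for j in 0..n {
                    let mut acc = 0.0;
                    for p in 0..k {
                        acc += lhs[i * k + p] * rhs[p * n + j];
                    }
                    dst[i * n + j] = acc;
                }
            }
        }
    }
}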