From 71fcada9d83d0df59eb0e673b9bb07ab1782fadc Mon Sep 17 00:00:00 2001
From: Corey Lowman
Date: Fri, 12 May 2023 11:50:46 -0400
Subject: [PATCH] Updating cuda kernel for conv transpose

---
 src/tensor_ops/convtrans2d/convtrans2d.cu |  55 +++++----
 src/tensor_ops/convtrans2d/cuda_kernel.rs | 134 ++++++++++++++++------
 2 files changed, 132 insertions(+), 57 deletions(-)

diff --git a/src/tensor_ops/convtrans2d/convtrans2d.cu b/src/tensor_ops/convtrans2d/convtrans2d.cu
index d6e842704..86eadf9f0 100644
--- a/src/tensor_ops/convtrans2d/convtrans2d.cu
+++ b/src/tensor_ops/convtrans2d/convtrans2d.cu
@@ -1,9 +1,11 @@
 #include "cuda_fp16.h"

 struct Conv2DOp {
+    size_t kernel;
     size_t stride;
     size_t padding;
-    size_t kernel;
+    size_t dilation;
+    size_t groups;
     size_t batch;
     size_t chan_in;
     size_t chan_out;
@@ -16,12 +18,12 @@ struct Conv2DOp {
 template<typename T>
 __device__ void unfold_input_into_patches(
     const Conv2DOp op,
-    const T *image, // 4d (Batch, Channels, Height, Width)
+    const T *image, // 4d (Batch, Groups * Channels, Height, Width)
     const size_t *strides, // 4d image strides
-    T *patches // 6d (Batch, Channels, KernelSize, KernelSize, HeightOut, WidthOut)
+    T *patches // 6d (Batch, Groups * Channels, KernelSize, KernelSize, HeightOut, WidthOut)
 ) {
     unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
-    if (i >= op.batch * op.chan_in * op.h_out * op.w_out) {
+    if (i >= op.batch * op.groups * op.chan_in * op.h_out * op.w_out) {
         return;
     }

@@ -30,8 +32,8 @@ __device__ void unfold_input_into_patches(
     idx /= op.w_out;
     const size_t oh = idx % op.h_out;
     idx /= op.h_out;
-    const size_t c = idx % op.chan_in;
-    idx /= op.chan_in;
+    const size_t c = idx % (op.chan_in * op.groups);
+    idx /= (op.chan_in * op.groups);
     const size_t b = idx % op.batch;
     idx /= op.batch;

@@ -44,15 +46,15 @@ __device__ void unfold_input_into_patches(

     for (int k1 = 0;k1 < op.kernel;k1++) {
         const size_t y_ks = oh + op.padding;
-        const size_t y_s = y_ks - k1;
+        const size_t y_s = y_ks - op.dilation * k1;
         const size_t y = y_s / op.stride;
-        const bool k1_invalid = (y_ks < k1 || y_s % op.stride != 0 || y >= op.h_in);
+        const bool k1_invalid = (y_ks < op.dilation * k1 || y_s % op.stride != 0 || y >= op.h_in);
         for (int k2 = 0;k2 < op.kernel;k2++) {
             const size_t x_ks = ow + op.padding;
-            const size_t x_s = x_ks - k2;
+            const size_t x_s = x_ks - op.dilation * k2;
             const size_t x = x_s / op.stride;
-            const bool invalid = k1_invalid || (x_ks < k2 || x_s % op.stride != 0 || x >= op.w_in);
+            const bool invalid = k1_invalid || (x_ks < op.dilation * k2 || x_s % op.stride != 0 || x >= op.w_in);
             *patches = invalid ? zero : image[y * strides[2] + x * strides[3]];
             patches += op.h_out * op.w_out;
         }
@@ -87,9 +89,9 @@ __device__ void unfold_output_into_patches(
     T zero = 0.0;

     for (int k1 = 0;k1 < op.kernel;k1++) {
-        const size_t oh = y * op.stride + k1 - op.padding;
+        const size_t oh = y * op.stride + op.dilation * k1 - op.padding;
         for (int k2 = 0;k2 < op.kernel;k2++) {
-            const size_t ow = x * op.stride + k2 - op.padding;
+            const size_t ow = x * op.stride + op.dilation * k2 - op.padding;
             *patches = (oh >= op.h_out || ow >= op.w_out) ? zero : image_out[oh * op.w_out + ow];
             patches += op.h_in * op.w_in;
         }
@@ -101,7 +103,7 @@ __device__ void transpose_filters(
     const Conv2DOp op,
     const T *filters, // 4d (ChanOut, ChanIn, KernelSize, KernelSize)
     const size_t *strides, // 4d filters strides
-    T *filters_tr // 4d (ChanIn, ChanOut, KernelSize, KernelSize)
+    T *filters_tr // 5d (Groups, ChanIn, ChanOut/Groups, KernelSize, KernelSize)
 ) {
     unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
     if (i >= op.chan_in * op.chan_out * op.kernel * op.kernel) {
@@ -113,19 +115,25 @@ __device__ void transpose_filters(
     idx /= op.kernel;
     const size_t k1 = idx % op.kernel;
     idx /= op.kernel;
-    const size_t o = idx % op.chan_out;
-    idx /= op.chan_out;
     const size_t c = idx % op.chan_in;
+    idx /= op.chan_in;
+    const size_t o = idx % op.chan_out;
+    const size_t og = o % o_per_g;
+    const size_t g = o / o_per_g;

     auto i_no = o * strides[0] + c * strides[1] + k1 * strides[2] + k2 * strides[3];
-
-    filters_tr[i] = filters[i_no];
+    filters_tr += k2;
+    filters_tr += k1 * op.kernel;
+    filters_tr += og * (op.kernel * op.kernel);
+    filters_tr += c * (o_per_g * op.kernel * op.kernel);
+    filters_tr += g * (op.chan_in * o_per_g * op.kernel * op.kernel);
+    *filters_tr = filters[i_no];
 }

 template<typename T>
 __device__ void sum_transposed_filters(
     const Conv2DOp op,
-    const T *filters_tr, // 5d (Batch, ChanIn, ChanOut, KernelSize, KernelSize)
+    const T *filters_tr, // 6d (Batch, Groups, ChanIn, ChanOut/Groups, KernelSize, KernelSize)
     T *filters, // 4d (ChanOut, ChanIn, KernelSize, KernelSize)
     const size_t *strides // 4d filter strides
 ) {
@@ -135,6 +143,8 @@ __device__ void sum_transposed_filters(
         return;
     }

+    const size_t o_per_g = op.chan_out / op.groups;
+
     unsigned int idx = i;
     const size_t k2 = idx % op.kernel;
     idx /= op.kernel;
@@ -143,12 +153,17 @@
     const size_t c = idx % op.chan_in;
     idx /= op.chan_in;
     const size_t o = idx % op.chan_out;
-    idx /= op.chan_out;
+    const size_t og = o % o_per_g;
+    const size_t g = o / o_per_g;

     auto i_tr = c * (op.chan_out * op.kernel * op.kernel) + o * (op.kernel * op.kernel) + k1 * (op.kernel) + k2;
     auto i_no = o * strides[0] + c * strides[1] + k1 * strides[2] + k2 * strides[3];

-    filters_tr += i_tr;
+    filters_tr += k2;
+    filters_tr += k1 * op.kernel;
+    filters_tr += og * (op.kernel * op.kernel);
+    filters_tr += c * (o_per_g * op.kernel * op.kernel);
+    filters_tr += g * (op.chan_in * o_per_g * op.kernel * op.kernel);

     T tmp = 0.0;
     for (int b = 0; b < op.batch; b++) {
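A note on the indexing above: for a transposed convolution, output position o is fed by input position i through kernel tap k only when o + padding - dilation * k is non-negative, divisible by the stride, and the quotient lands inside the input, which is exactly the y_ks / y_s test in unfold_input_into_patches. The standalone Rust sketch below mirrors that test on the host; the function contributing_input and the demo in main are names invented for this note, not part of the patch or of dfdx.

// Host-side sketch of the per-tap index test in unfold_input_into_patches:
// output position `o` reads input position `i` through kernel tap `k` only
// when o + padding - dilation * k is non-negative, divisible by `stride`,
// and the quotient lies inside the input.
fn contributing_input(
    o: usize,
    k: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
    in_size: usize,
) -> Option<usize> {
    let shifted = (o + padding).checked_sub(dilation * k)?; // tap reaches past the left edge
    if shifted % stride != 0 {
        return None; // falls between strided input positions
    }
    let i = shifted / stride;
    (i < in_size).then_some(i)
}

fn main() {
    // kernel = 3, stride = 2, padding = 1, dilation = 1, input width 4 => output width 7
    for o in 0..7 {
        let taps: Vec<Option<usize>> = (0..3)
            .map(|k| contributing_input(o, k, 2, 1, 1, 4))
            .collect();
        println!("out {o}: {taps:?}");
    }
}

Taps that come back None correspond to the positions the kernel fills with zero in the patch buffer.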
diff --git a/src/tensor_ops/convtrans2d/cuda_kernel.rs b/src/tensor_ops/convtrans2d/cuda_kernel.rs
index bb3a5a4a8..bd09bb6e9 100644
--- a/src/tensor_ops/convtrans2d/cuda_kernel.rs
+++ b/src/tensor_ops/convtrans2d/cuda_kernel.rs
@@ -61,6 +61,11 @@ where
     Self: HasCudaKernel<E>,
     CudaBlas: Gemm<E>,
 {
+    fn alloc<S: Shape>(&self, shape: S) -> Result<Tensor<S, E, Self>, Self::Err> {
+        let data = unsafe { self.alloc_empty::<E>(shape.num_elements()) }?;
+        Ok(self.build_tensor(shape, shape.strides(), data))
+    }
+
     fn forward<L: Shape, R: Shape, O: Shape>(
         &self,
         op: super::ConvTrans2DOp,
@@ -72,32 +77,54 @@ where
             self.dev.load_ptx(PTX_SRC.into(), Self::MOD, Self::FNS)?;
         }

-        let patches_numel = op.batch * op.chan_in * op.kernel * op.kernel * op.h_out * op.w_out;
+        let patches_numel =
+            op.batch * op.groups * op.chan_in * op.kernel * op.kernel * op.h_out * op.w_out;
         let mut patches = unsafe { self.get_workspace::<E>(patches_numel) }?;
         let mut patches = unsafe { patches.transmute_mut::<E>(patches_numel).unwrap() };

         let img_strides = self.dev.htod_copy(make_4d::<L>(lhs.strides).into())?;
         let unfold_fn = self.dev.get_func(Self::MOD, Self::FNS[0]).unwrap();
-        let cfg = launch_cfg::<128>((op.batch * op.chan_in * op.h_out * op.w_out) as u32);
+        let cfg =
+            launch_cfg::<128>((op.batch * op.groups * op.chan_in * op.h_out * op.w_out) as u32);
         let params = (op, lhs.data.as_ref(), &img_strides, &mut patches);
         unsafe { unfold_fn.launch(cfg, params) }?;

-        // (O, C * K * K) * (B, C * K * K, OH * OW) = (B, O, OH * OW)
-        let m = op.chan_out;
+        let out_buf = Arc::get_mut(&mut out.data).unwrap();
+
+        // LHS (G, O/G, C*K*K)
+        // RHS (B, G, C*K*K, OH*OW)
+        // OUT (B, G, O/G, OH*OW)
+        let m = op.chan_out / op.groups;
         let k = op.chan_in * op.kernel * op.kernel;
         let n = op.h_out * op.w_out;
         unsafe {
-            self.gemm_batch(
-                (op.batch, m, k, n),
-                rhs.data.as_ref(),
-                [0, k, 1],
-                &patches,
-                [k * n, n, 1],
-                Default::default(),
-                Arc::get_mut(&mut out.data).unwrap(),
-                [m * n, n, 1],
-            )
-            .unwrap();
+            if op.groups == 1 {
+                self.gemm_batch(
+                    (op.batch, m, k, n),
+                    rhs.data.as_ref(),
+                    [0, k, 1],
+                    &patches,
+                    [k * n, n, 1],
+                    Default::default(),
+                    out_buf,
+                    [m * n, n, 1],
+                )
+                .unwrap();
+            } else {
+                for i_batch in 0..op.batch {
+                    self.gemm_batch(
+                        (op.groups, m, k, n),
+                        rhs.data.as_ref(),
+                        [m * k, k, 1],
+                        &patches.slice(i_batch * op.groups * k * n..),
+                        [k * n, n, 1],
+                        Default::default(),
+                        &mut out_buf.slice_mut(i_batch * op.groups * m * n..),
+                        [m * n, n, 1],
+                    )
+                    .unwrap();
+                }
+            }
         }

         Ok(())
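The grouped branch above is pure offset arithmetic over contiguous buffers: filters are laid out as (G, O/G, C*K*K), patches as (B, G, C*K*K, OH*OW), and the output as (B, G, O/G, OH*OW), so each per-batch gemm_batch call starts at i_batch * groups * k * n in the patches and i_batch * groups * m * n in the output. The CPU sketch below reproduces that layout with plain slices; grouped_gemm is an illustrative helper written for this note, not a dfdx API.

// CPU sketch of the grouped forward GEMM layout used above:
//   lhs: (groups, m, k)     -- filters, shared across the batch
//   rhs: (b, groups, k, n)  -- unfolded patches
//   out: (b, groups, m, n)  -- conv-transpose output
// Batch offsets mirror the grouped branch: rhs starts at i_batch * groups * k * n,
// out starts at i_batch * groups * m * n, and each group advances by k * n / m * n.
fn grouped_gemm(
    lhs: &[f32],
    rhs: &[f32],
    out: &mut [f32],
    (b, groups, m, k, n): (usize, usize, usize, usize, usize),
) {
    for i_batch in 0..b {
        for g in 0..groups {
            let lhs_g = &lhs[g * m * k..][..m * k];
            let rhs_g = &rhs[i_batch * groups * k * n + g * k * n..][..k * n];
            let out_g = &mut out[i_batch * groups * m * n + g * m * n..][..m * n];
            for row in 0..m {
                for col in 0..n {
                    out_g[row * n + col] = (0..k)
                        .map(|kk| lhs_g[row * k + kk] * rhs_g[kk * n + col])
                        .sum();
                }
            }
        }
    }
}

fn main() {
    let (b, groups, m, k, n) = (2, 2, 1, 2, 3);
    let lhs = vec![1.0f32; groups * m * k];
    let rhs = vec![1.0f32; b * groups * k * n];
    let mut out = vec![0.0f32; b * groups * m * n];
    grouped_gemm(&lhs, &rhs, &mut out, (b, groups, m, k, n));
    assert!(out.iter().all(|&v| v == k as f32));
}

With groups equal to 1 this degenerates to a single batched matmul, which is why the kernel keeps the one-call gemm_batch fast path.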
@@ -119,8 +146,8 @@ where
         let mut patches = unsafe { self.get_workspace::<E>(patches_numel) }?;
         let mut patches = unsafe { patches.transmute_mut::<E>(patches_numel).unwrap() };

-        let mut f_b1023 = unsafe { self.alloc_empty::<E>(filters_numel) }?;
-        let mut grad_f_b1023 = unsafe { self.alloc_empty::<E>(op.batch * filters_numel) }?;
+        let mut ftr = unsafe { self.alloc_empty::<E>(filters_numel) }?;
+        let mut grad_ftr = unsafe { self.alloc_empty::<E>(op.batch * filters_numel) }?;

         let f_strides = self.dev.htod_copy(rhs.strides.into())?;
         self.par_stream.wait_for_default()?;
@@ -132,31 +159,32 @@
             unsafe { unfold_fn.launch(cfg, (op, grad_out, &mut patches)) }?;
         }

-        {
+        unsafe {
             // prepare filters for backward operations by
             // swapping dims 0 and 1 and adding a batch dimension
             let tr_fn = self.dev.get_func(Self::MOD, Self::FNS[2]).unwrap();
             let cfg = launch_cfg::<128>(rhs.shape.num_elements() as u32);
-            unsafe {
-                tr_fn.launch_on_stream(
-                    self.par_stream.as_ref(),
-                    cfg,
-                    (op, rhs.data.as_ref(), &f_strides, &mut f_b1023),
-                )
-            }?;
+            tr_fn.launch_on_stream(
+                self.par_stream.as_ref(),
+                cfg,
+                (op, rhs.data.as_ref(), &f_strides, &mut ftr),
+            )?;

             self.par_stream.wait_for_default()?;

             // img_g += filters * patches
-            // (B, C, H * W) += (B, C, O * K * K) * (B, O * K * K, H * W)
+            // LHS = (G, C, O/G*K*K)
+            // RHS = (B, G, O/G*K*K, H*W)
+            // OUT = (B, G, C, H*W)
             let m = op.chan_in;
-            let k = op.chan_out * op.kernel * op.kernel;
+            let k = (op.chan_out / op.groups) * op.kernel * op.kernel;
             let n = op.h_in * op.w_in;
-            unsafe {
-                self.blas.set_stream(Some(self.par_stream.as_ref()))?;
+            self.blas.set_stream(Some(self.par_stream.as_ref()))?;
+            if op.groups == 1 {
+                // optimizing here for common case
                 self.gemm_batch(
                     (op.batch, m, k, n),
-                    &f_b1023,
+                    &ftr,
                     [0, k, 1],
                     &patches,
                     [k * n, n, 1],
@@ -165,17 +193,34 @@
                     [m * n, n, 1],
                 )
                 .unwrap();
-                self.blas.set_stream(None)?;
+            } else {
+                for i_batch in 0..op.batch {
+                    self.gemm_batch(
+                        (op.groups, m, k, n),
+                        &ftr,
+                        [m * k, k, 1],
+                        &patches.slice(i_batch * op.groups * k * n..),
+                        [k * n, n, 1],
+                        <E>::ONE,
+                        &mut grad_lhs.slice_mut(i_batch * op.groups * m * n..),
+                        [m * n, n, 1],
+                    )
+                    .unwrap();
+                }
             }
+            self.blas.set_stream(None)?;
         }

-        {
+        unsafe {
             // weight_g += img * patches^T
-            // (B, C, O * K * K) += (B, C, H * W) * (B, H * W, O * K * K)
+            // LHS = (B, G, C, H*W)
+            // RHS = (B, H*W, G, O/G*K*K)
+            // OUT = (B, G, C, O/G*K*K)
             let m = op.chan_in;
             let k = op.h_in * op.w_in;
-            let n = op.chan_out * op.kernel * op.kernel;
-            unsafe {
+            let n = (op.chan_out / op.groups) * op.kernel * op.kernel;
+            if op.groups == 1 {
+                // optimizing here for common case
                 self.gemm_batch(
                     (op.batch, m, k, n),
                     lhs.data.as_ref(),
@@ -183,17 +228,32 @@
                     &patches,
                     [k * n, 1, k],
                     Default::default(),
-                    &mut grad_f_b1023,
+                    &mut grad_ftr,
                     [m * n, n, 1],
                 )
                 .unwrap();
+            } else {
+                let lhs_buf = lhs.data.as_ref();
+                for i_batch in 0..op.batch {
+                    self.gemm_batch(
+                        (op.groups, m, k, n),
+                        &lhs_buf.slice(i_batch * op.groups * m * k..),
+                        [m * k, k, 1],
+                        &patches.slice(i_batch * op.groups * k * n..),
+                        [k * n, 1, k],
+                        Default::default(),
+                        &mut grad_ftr.slice_mut(i_batch * op.groups * m * n..),
+                        [m * n, n, 1],
+                    )
+                    .unwrap();
+                }
             }

             // sum all the gradients collected in our broadcasted grad_f
             // into grad_rhs
             let sum_fn = self.dev.get_func(Self::MOD, Self::FNS[3]).unwrap();
             let cfg = launch_cfg::<128>(rhs.shape.num_elements() as u32);
-            unsafe { sum_fn.launch(cfg, (op, &grad_f_b1023, grad_rhs, &f_strides)) }?;
+            sum_fn.launch(cfg, (op, &grad_ftr, grad_rhs, &f_strides))?;
         }

         self.dev.wait_for(self.par_stream.as_ref())?;
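As an end-to-end sanity check for the grouped and dilated path, here is a naive scatter-add reference for the transposed-convolution forward pass. The layouts (input as (groups * c_in, h_in, w_in), weights as (groups, c_in, c_out_per_group, kernel, kernel)) are assumptions made for this sketch rather than dfdx's internal ones; the part that matters is the coordinate arithmetic, which matches the oh = y * stride + dilation * k1 - padding mapping in the kernels above.

// Naive scatter-add reference for a grouped, dilated ConvTranspose2D forward
// pass on a single image. Assumed layouts (row-major, fully contiguous):
//   input:   (groups * c_in, h_in, w_in)
//   weights: (groups, c_in, c_out_per_group, kernel, kernel)
//   output:  (groups * c_out_per_group, h_out, w_out)
struct ConvTransOp {
    groups: usize,
    c_in: usize, // input channels per group
    c_out_per_group: usize,
    kernel: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
    h_in: usize,
    w_in: usize,
}

fn conv_transpose2d_ref(op: &ConvTransOp, input: &[f32], weights: &[f32]) -> Vec<f32> {
    let h_out = (op.h_in - 1) * op.stride + op.dilation * (op.kernel - 1) + 1 - 2 * op.padding;
    let w_out = (op.w_in - 1) * op.stride + op.dilation * (op.kernel - 1) + 1 - 2 * op.padding;
    let mut out = vec![0.0; op.groups * op.c_out_per_group * h_out * w_out];
    for g in 0..op.groups {
        for ci in 0..op.c_in {
            for oc in 0..op.c_out_per_group {
                for y in 0..op.h_in {
                    for x in 0..op.w_in {
                        for k1 in 0..op.kernel {
                            for k2 in 0..op.kernel {
                                // oh = y * stride + dilation * k1 - padding, skipped when it
                                // falls outside the output; same mapping as the CUDA kernels.
                                let oh = y * op.stride + op.dilation * k1;
                                let ow = x * op.stride + op.dilation * k2;
                                if oh < op.padding || ow < op.padding {
                                    continue;
                                }
                                let (oh, ow) = (oh - op.padding, ow - op.padding);
                                if oh >= h_out || ow >= w_out {
                                    continue;
                                }
                                let i_in = ((g * op.c_in + ci) * op.h_in + y) * op.w_in + x;
                                let i_w = (((g * op.c_in + ci) * op.c_out_per_group + oc)
                                    * op.kernel + k1) * op.kernel + k2;
                                let i_out = ((g * op.c_out_per_group + oc) * h_out + oh) * w_out + ow;
                                out[i_out] += input[i_in] * weights[i_w];
                            }
                        }
                    }
                }
            }
        }
    }
    out
}

fn main() {
    let op = ConvTransOp {
        groups: 2, c_in: 1, c_out_per_group: 1,
        kernel: 3, stride: 2, padding: 1, dilation: 1,
        h_in: 2, w_in: 2,
    };
    let input = vec![1.0; op.groups * op.c_in * op.h_in * op.w_in];
    let weights = vec![1.0; op.groups * op.c_in * op.c_out_per_group * op.kernel * op.kernel];
    let out = conv_transpose2d_ref(&op, &input, &weights);
    println!("{} outputs (3x3 per group): {:?}", out.len(), out);
}

Comparing a reference like this against the CUDA path for a handful of small shapes (odd strides, dilation above 1, groups above 1) exercises exactly the index math this patch touches.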