
Fixing compilation errors
coreylowman committed Jul 5, 2023
1 parent 2c24bd8 commit 0a18260
Showing 3 changed files with 4 additions and 6 deletions.
4 changes: 2 additions & 2 deletions src/tensor_ops/convtrans2d/convtrans2d.cu
@@ -129,11 +129,11 @@ __device__ void transpose_filters(
filters_tr += k1 * op.kernel;
filters_tr += cg * (op.kernel * op.kernel);
filters_tr += og * (c_per_g * op.kernel * op.kernel);
- filters_tr += g * (o_per_g * * c_per_g * op.kernel * op.kernel);
+ filters_tr += g * (o_per_g * c_per_g * op.kernel * op.kernel);
*filters_tr = filters[i_no];
}

- #define CONV_OP(TYPENAME, UNFOLD_INPUT, UNFOLD_OUTPUT, TR_FILTERS, SUM_TR_FILTERS) \
+ #define CONV_OP(TYPENAME, UNFOLD_INPUT, UNFOLD_OUTPUT, TR_FILTERS) \
extern "C" __global__ void UNFOLD_INPUT( \
const Conv2DOp op, \
const TYPENAME *image, \
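A minimal sketch of the offset arithmetic this hunk corrects, written as standalone Rust rather than dfdx's CUDA source: the strides above imply the transposed filters are laid out as [groups, o_per_g, c_per_g, kernel, kernel] from slowest- to fastest-varying dimension, and the removed line's stray `*` turned the last term into a syntax error. The helper name is illustrative, and the k2 (kernel-column) term is assumed to be accumulated just before the lines shown.

// Hypothetical standalone helper (not dfdx code) mirroring the strides above.
fn transposed_filter_offset(
    g: usize,       // group index
    og: usize,      // output channel within its group
    cg: usize,      // input channel within its group
    k1: usize,      // kernel row
    k2: usize,      // kernel column (assumed added before the hunk shown)
    o_per_g: usize, // output channels per group
    c_per_g: usize, // input channels per group
    kernel: usize,  // square kernel size
) -> usize {
    k2 + k1 * kernel
        + cg * (kernel * kernel)
        + og * (c_per_g * kernel * kernel)
        + g * (o_per_g * c_per_g * kernel * kernel)
}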
1 change: 0 additions & 1 deletion src/tensor_ops/convtrans2d/cpu_kernel.rs
@@ -254,7 +254,6 @@ where
let lhs = lhs.data.as_ref();

let rhs = rhs.data.as_ref();
- let grad_rhs = grad_rhs.data.as_mut();
for i_batch in 0..op.batch {
self.convtrans2d_backward(
&op,
5 changes: 2 additions & 3 deletions src/tensor_ops/convtrans2d/cuda_kernel.rs
@@ -100,7 +100,7 @@ where
// generate patches for matmul
let unfold_fn = self.dev.get_func(Self::MOD, Self::FNS[0]).unwrap();
let cfg = launch_cfg::<128>((op.batch * op.chan_in * op.h_out * op.w_out) as u32);
- unsafe { unfold_fn.launch(cfg, (op, lhs.data.as_ref(), &img_strides, &mut patches)) }?;
+ unfold_fn.launch(cfg, (op, lhs.data.as_ref(), &img_strides, &mut patches))?;

// prepare filters for backward operations by
// swapping dims 0 and 1 and adding a batch dimension
@@ -163,7 +163,6 @@
}

let rhs = rhs.data.as_ref();
- let grad_rhs = grad_rhs.data.as_mut();

unsafe {
self.par_stream.wait_for_default()?;
@@ -238,7 +237,7 @@
&patches.slice(i_batch * op.groups * k * n..),
[k * n, 1, k],
Default::default(),
- grad_rhs.slice_mut(i_batch * op.groups * m * n..),
+ &mut grad_rhs.slice_mut(i_batch * op.groups * m * n..),
[m * n, n, 1],
)
.unwrap();
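The Rust-side changes drop the `let grad_rhs = grad_rhs.data.as_mut();` rebinding and pass `&mut grad_rhs.slice_mut(..)` straight into the call shown at the end of the diff, which suggests grad_rhs now arrives as a mutable device buffer rather than a tensor wrapping one. Below is a minimal sketch of that call-site shape using toy stand-in types (none of these are dfdx or cudarc APIs), only to illustrate why the temporary view is now borrowed mutably at the call.

// Toy stand-ins for a device buffer and a mutable sub-slice view of it.
struct Buffer(Vec<f32>);
struct ViewMut<'a>(&'a mut [f32]);

impl Buffer {
    // Returns a mutable view starting at `offset`, analogous to `slice_mut(offset..)`.
    fn slice_mut(&mut self, offset: usize) -> ViewMut<'_> {
        ViewMut(&mut self.0[offset..])
    }
}

// A consumer that, like the call in the diff, takes `&mut` of a view.
fn accumulate(view: &mut ViewMut<'_>, value: f32) {
    for x in view.0.iter_mut() {
        *x += value;
    }
}

fn main() {
    // `grad_rhs` is the buffer itself here, so there is no `.data` to unwrap;
    // the temporary view lives only for the duration of the call.
    let grad_rhs: &mut Buffer = &mut Buffer(vec![0.0; 8]);
    accumulate(&mut grad_rhs.slice_mut(4), 1.0);
    assert_eq!(grad_rhs.0[4], 1.0);
}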
