diff --git a/build.rs b/build.rs index 6f19b61e5..1048382a4 100644 --- a/build.rs +++ b/build.rs @@ -123,7 +123,7 @@ mod cuda { .arg("--query-gpu=compute_cap") .arg("--format=csv") .output() - .unwrap(); + .expect("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH."); let out = std::str::from_utf8(&out.stdout).unwrap(); let mut lines = out.lines(); assert_eq!(lines.next().unwrap(), "compute_cap"); @@ -136,7 +136,7 @@ mod cuda { let out = std::process::Command::new("nvcc") .arg("--list-gpu-code") .output() - .unwrap(); + .expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH."); let out = std::str::from_utf8(&out.stdout).unwrap(); let out = out.lines().collect::>(); @@ -188,12 +188,12 @@ mod cuda { .stdout(std::process::Stdio::piped()) .stderr(std::process::Stdio::piped()) .spawn() - .unwrap() + .expect("nvcc failed to start. Ensure that you have CUDA installed and that `nvcc` is in your PATH.") }) .collect::>(); for (kernel_path, child) in kernel_paths.iter().zip(children.into_iter()) { - let output = child.wait_with_output().unwrap(); + let output = child.wait_with_output().expect("nvcc failed to run. Ensure that you have CUDA installed and that `nvcc` is in your PATH."); assert!( output.status.success(), "nvcc error while compiling {kernel_path:?}:\n\n# stdout\n{:#}\n\n# stderr\n{:#}", diff --git a/src/nn/conv.rs b/src/nn/conv.rs index af21a6712..c53fe2d43 100644 --- a/src/nn/conv.rs +++ b/src/nn/conv.rs @@ -32,6 +32,7 @@ impl< where E: Dtype, D: Device, + Const<{ I / G }>: Sized, Conv2D: BuildModule, { type Built = Conv2D; @@ -62,6 +63,7 @@ where /// /// See [conv animations](https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md) for helpful /// visualization of all of these parameters. + #[derive(Debug, Clone)] pub struct Conv2D< const IN_CHAN: usize, @@ -73,8 +75,10 @@ pub struct Conv2D< const GROUPS: usize, E: Dtype, D: Storage, -> { - pub weight: Tensor, E, D>, +> where + Const<{ IN_CHAN / GROUPS }>: Sized, +{ + pub weight: Tensor, E, D>, } impl< @@ -89,6 +93,7 @@ impl< D, > TensorCollection for Conv2D where + Const<{ I / G }>: Sized, E: Dtype + Float + SampleUniform, D: Device, { @@ -112,9 +117,8 @@ where } } -#[cfg(feature = "nightly")] impl< - const C: usize, + const I: usize, const O: usize, const K: usize, const S: usize, @@ -124,19 +128,21 @@ impl< E, D, Img, - > Module for Conv2D + > Module for Conv2D where + Const<{ I / G }>: Sized, E: Dtype, D: Device, - (Img, Tensor, E, D>): TryConv2D, Const

, Const<L>, Const<G>>, + (Img, Tensor<Rank4<O, { I / G }, K, K>, E, D>): + TryConv2D<Const<S>, Const<P>, Const<L>, Const<G>>, { - type Output = <(Img, Tensor<Rank4<O, C, K, K>, E, D>) as TryConv2D< + type Output = <(Img, Tensor<Rank4<O, { I / G }, K, K>, E, D>) as TryConv2D< Const<S>, Const<P>, Const<L>, Const<G>, >>::Convolved; - type Error = <(Img, Tensor<Rank4<O, C, K, K>, E, D>) as TryConv2D< + type Error = <(Img, Tensor<Rank4<O, { I / G }, K, K>, E, D>) as TryConv2D< Const<S>, Const<P>

, Const, @@ -159,10 +165,11 @@ impl< E: Dtype, D: Storage, > NonMutableModule for Conv2D +where + Const<{ I / G }>: Sized, { } -#[cfg(feature = "nightly")] #[cfg(test)] mod tests { use crate::{ @@ -189,6 +196,33 @@ mod tests { let _: Tensor, _, _, _> = dev.build_module::, TestDtype>().forward(x.clone()); } + #[test] + fn test_grouped_forward_sizes() { + let dev: TestDevice = Default::default(); + + let x = dev.zeros::>(); + + let m = dev.build_module::, TestDtype>(); + let _: Tensor, _, _> = m.weight; + let _: Tensor, _, _> = m.forward(x.clone()); + + let m = dev.build_module::, TestDtype>(); + let _: Tensor, _, _> = m.weight; + let _: Tensor, _, _> = m.forward(x.clone()); + + let m = dev.build_module::, TestDtype>(); + let _: Tensor, _, _> = m.weight; + let _: Tensor, _, _> = m.forward(x.clone()); + + let m = dev.build_module::, TestDtype>(); + let _: Tensor, _, _> = m.weight; + let _: Tensor, _, _> = m.forward(x.clone()); + + let m = dev.build_module::, TestDtype>(); + let _: Tensor, _, _> = m.weight; + let _: Tensor, _, _> = m.forward(x); + } + #[rustfmt::skip] #[test] fn test_forward_4d_sizes() { diff --git a/src/nn/mod.rs b/src/nn/mod.rs index 33e4ebdc5..ddf87b9ed 100644 --- a/src/nn/mod.rs +++ b/src/nn/mod.rs @@ -188,6 +188,7 @@ mod add_into; mod batchnorm1d; mod batchnorm2d; mod bias2d; +#[cfg(feature = "nightly")] mod conv; mod convtrans; mod dropout; diff --git a/src/optim/adam/adam.cu b/src/optim/adam/adam.cu index 3b1dcf9e0..b5ee7268a 100644 --- a/src/optim/adam/adam.cu +++ b/src/optim/adam/adam.cu @@ -25,42 +25,37 @@ __device__ void adam_update( T* moment2, const T* grad ) { - unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i >= numel) { - return; - } - T beta1 = cfg.beta1; T beta2 = cfg.beta2; T lr = cfg.lr; T weight_decay = cfg.weight_decay; T eps = cfg.eps; - - T p = param[i]; - T g = grad[i]; - T m = moment1[i]; - T v = moment2[i]; T one = 1.0; T t = t_int; - - if (cfg.weight_decay_type == L2) { - g += weight_decay * p; + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { + T p = param[i]; + T g = grad[i]; + T m = moment1[i]; + T v = moment2[i]; + + if (cfg.weight_decay_type == L2) { + g += weight_decay * p; + } + + m = m * beta1 + g * (one - beta1); + v = v * beta2 + g * g * (one - beta2); + T m_hat = m * one / (one - powg(beta1, t)); + T v_hat = v * one / (one - powg(beta2, t)); + g = lr * m_hat / (sqrtg(v_hat) + eps); + + if (cfg.weight_decay_type == Decoupled) { + g += (weight_decay * lr) * p; + } + + moment1[i] = m; + moment2[i] = v; + param[i] -= g; } - - m = m * beta1 + g * (one - beta1); - v = v * beta2 + g * g * (one - beta2); - T m_hat = m * one / (one - powg(beta1, t)); - T v_hat = v * one / (one - powg(beta2, t)); - g = lr * m_hat / (sqrtg(v_hat) + eps); - - if (cfg.weight_decay_type == Decoupled) { - g += (weight_decay * lr) * p; - } - - moment1[i] = m; - moment2[i] = v; - param[i] -= g; } #define ADAM(TYPENAME, FN) \ diff --git a/src/optim/rmsprop/rmsprop.cu b/src/optim/rmsprop/rmsprop.cu index 0beb5b4bb..acdfe06ef 100644 --- a/src/optim/rmsprop/rmsprop.cu +++ b/src/optim/rmsprop/rmsprop.cu @@ -27,58 +27,55 @@ __device__ void rmsprop_update( T* grad_avg, const T* grad ) { - unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i >= numel) { - return; - } - T lr = cfg.lr; T alpha = cfg.alpha; T eps = cfg.eps; T momentum_ = cfg.momentum; T weight_decay = cfg.weight_decay; - - T p = param[i]; - T g = grad[i]; - T s_avg = square_avg[i]; - T g_avg = grad_avg[i]; - T m = momentum[i]; T one 
= 1.0; - if (cfg.weight_decay_type == L2) { - g += weight_decay * p; + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { + T p = param[i]; + T g = grad[i]; + T s_avg = square_avg[i]; + T g_avg = grad_avg[i]; + T m = momentum[i]; + + + if (cfg.weight_decay_type == L2) { + g += weight_decay * p; + } + + s_avg += (one - alpha) * (g * g - s_avg); + + T avg; + + if (cfg.centered) { + // ga = a * ga + (1 - a) * g + g_avg += (one - alpha) * (g - g_avg); + avg = sqrtg(s_avg - g_avg * g_avg + eps); + } else { + avg = sqrtg(s_avg + eps); + }; + + g /= avg; + + if (cfg.has_momentum) { + m = m * momentum_ + g; + g = m * lr; + } else { + g *= lr; + } + + if (cfg.weight_decay_type == Decoupled) { + g += weight_decay * lr * p; + } + + square_avg[i] = s_avg; + grad_avg[i] = g_avg; + momentum[i] = m; + param[i] -= g; } - - s_avg += (one - alpha) * (g * g - s_avg); - - T avg; - - if (cfg.centered) { - // ga = a * ga + (1 - a) * g - g_avg += (one - alpha) * (g - g_avg); - avg = sqrtg(s_avg - g_avg * g_avg + eps); - } else { - avg = sqrtg(s_avg + eps); - }; - - g /= avg; - - if (cfg.has_momentum) { - m = m * momentum_ + g; - g = m * lr; - } else { - g *= lr; - } - - if (cfg.weight_decay_type == Decoupled) { - g += weight_decay * lr * p; - } - - square_avg[i] = s_avg; - grad_avg[i] = g_avg; - momentum[i] = m; - param[i] -= g; } #define RMSPROP(TYPENAME, FN) \ diff --git a/src/optim/sgd/sgd.cu b/src/optim/sgd/sgd.cu index 226930011..2c666cb1d 100644 --- a/src/optim/sgd/sgd.cu +++ b/src/optim/sgd/sgd.cu @@ -28,40 +28,36 @@ __device__ void sgd_update( T* velocity, const T* grad ) { - unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i >= numel) { - return; - } - T weight_decay = cfg.weight_decay; T lr = cfg.lr; T momentum = cfg.momentum; - T p = param[i]; - T g = grad[i]; - T v = velocity[i]; - - if (cfg.weight_decay_type == L2) { - g += weight_decay * p; + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { + T p = param[i]; + T g = grad[i]; + T v = velocity[i]; + + if (cfg.weight_decay_type == L2) { + g += weight_decay * p; + } + + if (cfg.momentum_type == Classic) { + v = g + momentum * v; + g = v * lr; + } else if (cfg.momentum_type == Nesterov) { + v = g + momentum * v; + g = (g + momentum * v) * lr; + } else { + g *= lr; + } + + if (cfg.weight_decay_type == Decoupled) { + g += weight_decay * lr * p; + } + + velocity[i] = v; + param[i] -= g; } - - if (cfg.momentum_type == Classic) { - v = g + momentum * v; - g = v * lr; - } else if (cfg.momentum_type == Nesterov) { - v = g + momentum * v; - g = (g + momentum * v) * lr; - } else { - g *= lr; - } - - if (cfg.weight_decay_type == Decoupled) { - g += weight_decay * lr * p; - } - - velocity[i] = v; - param[i] -= g; } #define SGD(TYPENAME, FN) \ diff --git a/src/shapes/shape.rs b/src/shapes/shape.rs index e15915b65..ec16dbc9e 100644 --- a/src/shapes/shape.rs +++ b/src/shapes/shape.rs @@ -141,23 +141,23 @@ impl ConstDim for Const { impl core::ops::Add> for usize { type Output = usize; - fn add(self, rhs: Const) -> Self::Output { - self.size() + rhs.size() + fn add(self, _: Const) -> Self::Output { + self.size() + N } } impl core::ops::Add for Const { type Output = usize; fn add(self, rhs: usize) -> Self::Output { - self.size() + rhs.size() + N + rhs.size() } } #[cfg(feature = "nightly")] impl core::ops::Add> for Const where - Const<{ N + M }>: Sized, + Const<{ M + N }>: Sized, { - type Output = Const<{ N + M }>; + type Output = Const<{ M + N }>; fn 
add(self, _: Const) -> Self::Output { Const } @@ -165,28 +165,52 @@ where impl core::ops::Mul> for usize { type Output = usize; - fn mul(self, rhs: Const) -> Self::Output { - self.size() * rhs.size() + fn mul(self, _: Const) -> Self::Output { + self.size() * N } } impl core::ops::Mul for Const { type Output = usize; fn mul(self, rhs: usize) -> Self::Output { - self.size() * rhs.size() + N * rhs.size() } } #[cfg(feature = "nightly")] impl core::ops::Mul> for Const where - Const<{ N * M }>: Sized, + Const<{ M * N }>: Sized, { - type Output = Const<{ N * M }>; + type Output = Const<{ M * N }>; fn mul(self, _: Const) -> Self::Output { Const } } +impl core::ops::Div> for usize { + type Output = usize; + fn div(self, _: Const) -> Self::Output { + self.size() / N + } +} +impl core::ops::Div for Const { + type Output = usize; + fn div(self, rhs: usize) -> Self::Output { + N / rhs.size() + } +} + +#[cfg(feature = "nightly")] +impl core::ops::Div> for Const +where + Const<{ M / N }>: Sized, +{ + type Output = Const<{ M / N }>; + fn div(self, _: Const) -> Self::Output { + Const + } +} + /// Represents either `[T; N]` or `Vec` pub trait Array: IntoIterator { type Dim: Dim; diff --git a/src/tensor/mod.rs b/src/tensor/mod.rs index 12bc2b053..908249b61 100644 --- a/src/tensor/mod.rs +++ b/src/tensor/mod.rs @@ -145,6 +145,8 @@ mod gradients; mod masks; #[cfg(feature = "numpy")] pub(crate) mod numpy; +#[cfg(feature = "numpy")] +pub use numpy::NumpyDtype; #[cfg(feature = "safetensors")] pub mod safetensors; mod tensorlike; diff --git a/src/tensor_ops/axpy/axpy.cu b/src/tensor_ops/axpy/axpy.cu index 9e6907757..73d7e701f 100644 --- a/src/tensor_ops/axpy/axpy.cu +++ b/src/tensor_ops/axpy/axpy.cu @@ -2,11 +2,9 @@ template __device__ void axpy(const size_t n, T* a, const T alpha, const T* b, const T beta) { - unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i >= n) { - return; + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) { + a[i] = a[i] * alpha + b[i] * beta; } - a[i] = a[i] * alpha + b[i] * beta; } extern "C" __global__ void axpy_f16(const size_t n, __half* a, const __half alpha, const __half* b, const __half beta) { diff --git a/src/tensor_ops/boolean/boolean.cu b/src/tensor_ops/boolean/boolean.cu index 2739f2d93..95543502b 100644 --- a/src/tensor_ops/boolean/boolean.cu +++ b/src/tensor_ops/boolean/boolean.cu @@ -12,27 +12,17 @@ extern "C" __global__ void NAME( \ const size_t *rhs_strides, \ bool *out \ ) { \ - unsigned int out_i = blockIdx.x * blockDim.x + threadIdx.x; \ - if (out_i >= numel) { \ - return; \ + for (unsigned int out_i = blockIdx.x * blockDim.x + threadIdx.x; out_i < numel; out_i += blockDim.x * gridDim.x) { \ + unsigned int lhs_i = get_strided_index(out_i, num_dims, dims, lhs_strides); \ + unsigned int rhs_i = get_strided_index(out_i, num_dims, dims, rhs_strides); \ + out[out_i] = (bool)(lhs[lhs_i]) OP (bool)(rhs[rhs_i]); \ } \ -\ - unsigned int lhs_i = get_strided_index(out_i, num_dims, dims, lhs_strides); \ - unsigned int rhs_i = get_strided_index(out_i, num_dims, dims, rhs_strides); \ -\ - out[out_i] = (bool)(lhs[lhs_i]) OP (bool)(rhs[rhs_i]); \ } -extern "C" __global__ void boolean_not( - const size_t numel, - const bool *inp, - bool *out -) { - unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i >= numel) { - return; +extern "C" __global__ void boolean_not(const size_t numel, const bool *inp, bool *out) { + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { + 
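+        // grid-stride loop: each thread starts at its global index and steps by the total number of launched threads (blockDim.x * gridDim.x), so all `numel` elements are covered regardless of grid size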
out[i] = !(bool)(inp[i]); } - out[i] = !(bool)(inp[i]); } BOOLEAN_OP(boolean_and, &&); diff --git a/src/tensor_ops/choose/choose.cu b/src/tensor_ops/choose/choose.cu index 799af246f..ed57a2077 100644 --- a/src/tensor_ops/choose/choose.cu +++ b/src/tensor_ops/choose/choose.cu @@ -13,16 +13,12 @@ __device__ void choose_fwd( const size_t *rhs_strides, T *out ) { - unsigned int out_i = blockIdx.x * blockDim.x + threadIdx.x; - if (out_i >= numel) { - return; + for (unsigned int out_i = blockIdx.x * blockDim.x + threadIdx.x; out_i < numel; out_i += blockDim.x * gridDim.x) { + unsigned int lhs_i = get_strided_index(out_i, num_dims, dims, lhs_strides); + unsigned int rhs_i = get_strided_index(out_i, num_dims, dims, rhs_strides); + unsigned int cond_i = get_strided_index(out_i, num_dims, dims, cond_strides); + out[out_i] = cond[cond_i] ? lhs[lhs_i] : rhs[rhs_i]; } - - unsigned int lhs_i = get_strided_index(out_i, num_dims, dims, lhs_strides); - unsigned int rhs_i = get_strided_index(out_i, num_dims, dims, rhs_strides); - unsigned int cond_i = get_strided_index(out_i, num_dims, dims, cond_strides); - - out[out_i] = cond[cond_i] ? lhs[lhs_i] : rhs[rhs_i]; } template @@ -38,19 +34,12 @@ __device__ void choose_bwd( const size_t *rhs_strides, const T *grad_out ) { - unsigned int out_i = blockIdx.x * blockDim.x + threadIdx.x; - if (out_i >= numel) { - return; + for (unsigned int out_i = blockIdx.x * blockDim.x + threadIdx.x; out_i < numel; out_i += blockDim.x * gridDim.x) { + unsigned int lhs_i = get_strided_index(out_i, num_dims, dims, lhs_strides); + unsigned int rhs_i = get_strided_index(out_i, num_dims, dims, rhs_strides); + unsigned int cond_i = get_strided_index(out_i, num_dims, dims, cond_strides); + atomicAdd(cond[cond_i] ? grad_lhs + lhs_i : grad_rhs + rhs_i, grad_out[out_i]); } - - unsigned int lhs_i = get_strided_index(out_i, num_dims, dims, lhs_strides); - unsigned int rhs_i = get_strided_index(out_i, num_dims, dims, rhs_strides); - unsigned int cond_i = get_strided_index(out_i, num_dims, dims, cond_strides); - - auto go = grad_out[out_i]; - T* out_loc = cond[cond_i] ? 
grad_lhs + lhs_i : grad_rhs + rhs_i; - - atomicAdd(out_loc, go); } #define CHOOSE(TYPENAME, FWD, BWD) \ diff --git a/src/tensor_ops/cmp/cmp.cu b/src/tensor_ops/cmp/cmp.cu index bceeff44a..7fd5c68f7 100644 --- a/src/tensor_ops/cmp/cmp.cu +++ b/src/tensor_ops/cmp/cmp.cu @@ -12,14 +12,12 @@ extern "C" __global__ void FWD( \ bool *out, \ const size_t *out_strides \ ) { \ - unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; \ - if (i >= numel) { \ - return; \ + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \ + unsigned int lhs_i = get_strided_index(i, num_dims, dims, lhs_strides); \ + unsigned int rhs_i = get_strided_index(i, num_dims, dims, rhs_strides); \ + unsigned int out_i = get_strided_index(i, num_dims, dims, out_strides); \ + out[out_i] = lhs[lhs_i] SYMBOL rhs[rhs_i]; \ } \ - unsigned int lhs_i = get_strided_index(i, num_dims, dims, lhs_strides); \ - unsigned int rhs_i = get_strided_index(i, num_dims, dims, rhs_strides); \ - unsigned int out_i = get_strided_index(i, num_dims, dims, out_strides); \ - out[out_i] = lhs[lhs_i] SYMBOL rhs[rhs_i]; \ } \ \ extern "C" __global__ void SCALAR_FWD( \ @@ -32,13 +30,11 @@ extern "C" __global__ void SCALAR_FWD( \ bool *out, \ const size_t *out_strides \ ) { \ - unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; \ - if (i >= numel) { \ - return; \ + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \ + unsigned int lhs_i = get_strided_index(i, num_dims, dims, lhs_strides); \ + unsigned int out_i = get_strided_index(i, num_dims, dims, out_strides); \ + out[out_i] = lhs[lhs_i] SYMBOL scalar; \ } \ - unsigned int lhs_i = get_strided_index(i, num_dims, dims, lhs_strides); \ - unsigned int out_i = get_strided_index(i, num_dims, dims, out_strides); \ - out[out_i] = lhs[lhs_i] SYMBOL scalar; \ } CMP_OP(__half, eq_fwd_f16, scalar_eq_fwd_f16, ==) diff --git a/src/tensor_ops/concat/cuda_kernel.rs b/src/tensor_ops/concat/cuda_kernel.rs index 054fe1faa..147c6273b 100644 --- a/src/tensor_ops/concat/cuda_kernel.rs +++ b/src/tensor_ops/concat/cuda_kernel.rs @@ -67,7 +67,8 @@ impl super::ConcatKernel for Cuda { const BWD_KERNEL: &str = " #include \"cuda_fp16.h\" extern \"C\" __global__ void concat_bwd(const size_t numel, const $Ty *inp, $Ty *out) { - unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i < numel) { out[i] += inp[i]; } + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { + out[i] += inp[i]; + } } "; diff --git a/src/tensor_ops/concat_along/cuda_kernel.rs b/src/tensor_ops/concat_along/cuda_kernel.rs index e9c1c0b3f..7c0d1247c 100644 --- a/src/tensor_ops/concat_along/cuda_kernel.rs +++ b/src/tensor_ops/concat_along/cuda_kernel.rs @@ -104,38 +104,37 @@ extern \"C\" __global__ void fwd( const size_t *rhs_dims = info + 3 + 2 * num_dims; const size_t *rhs_strides = info + 3 + 3 * num_dims; - unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i >= numel) { return; } - - // out_dims will be (..., lhs_dims[ax] + rhs_dims[ax], ...) 
- - // striding lhs & rhs up to the concat'd axis - size_t i_tmp = i; - size_t lhs_i = 0; - size_t rhs_i = 0; - for (int d = num_dims - 1; d > axis; d--) { - size_t dim_i = i_tmp % lhs_dims[d]; - lhs_i += dim_i * lhs_strides[d]; - rhs_i += dim_i * rhs_strides[d]; - i_tmp /= lhs_dims[d]; - } - - // figure out if we are using lhs or rhs for this `i` - size_t i_along_axis = i_tmp % (lhs_dims[axis] + rhs_dims[axis]); - i_tmp /= (lhs_dims[axis] + rhs_dims[axis]); - - // striding lhs & rhs along the rest of the axes - for (int d = axis - 1; d >= 0;d--) { - size_t dim_i = i_tmp % lhs_dims[d]; - lhs_i += dim_i * lhs_strides[d]; - rhs_i += dim_i * rhs_strides[d]; - i_tmp /= lhs_dims[d]; - } - - if (i_along_axis < lhs_dims[axis]) { - out[i] = lhs[lhs_i + i_along_axis * lhs_strides[axis]]; - } else { - out[i] = rhs[rhs_i + (i_along_axis - lhs_dims[axis]) * rhs_strides[axis]]; + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { + // out_dims will be (..., lhs_dims[ax] + rhs_dims[ax], ...) + + // striding lhs & rhs up to the concat'd axis + size_t i_tmp = i; + size_t lhs_i = 0; + size_t rhs_i = 0; + for (int d = num_dims - 1; d > axis; d--) { + size_t dim_i = i_tmp % lhs_dims[d]; + lhs_i += dim_i * lhs_strides[d]; + rhs_i += dim_i * rhs_strides[d]; + i_tmp /= lhs_dims[d]; + } + + // figure out if we are using lhs or rhs for this `i` + size_t i_along_axis = i_tmp % (lhs_dims[axis] + rhs_dims[axis]); + i_tmp /= (lhs_dims[axis] + rhs_dims[axis]); + + // striding lhs & rhs along the rest of the axes + for (int d = axis - 1; d >= 0;d--) { + size_t dim_i = i_tmp % lhs_dims[d]; + lhs_i += dim_i * lhs_strides[d]; + rhs_i += dim_i * rhs_strides[d]; + i_tmp /= lhs_dims[d]; + } + + if (i_along_axis < lhs_dims[axis]) { + out[i] = lhs[lhs_i + i_along_axis * lhs_strides[axis]]; + } else { + out[i] = rhs[rhs_i + (i_along_axis - lhs_dims[axis]) * rhs_strides[axis]]; + } } } @@ -153,38 +152,37 @@ extern \"C\" __global__ void bwd( const size_t *rhs_dims = info + 3 + 2 * num_dims; const size_t *rhs_strides = info + 3 + 3 * num_dims; - unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i >= numel) { return; } - - // out_dims will be (..., lhs_dims[ax] + rhs_dims[ax], ...) - - // striding lhs & rhs up to the concat'd axis - size_t i_tmp = i; - size_t lhs_i = 0; - size_t rhs_i = 0; - for (int d = num_dims - 1; d > axis; d--) { - size_t dim_i = i_tmp % lhs_dims[d]; - lhs_i += dim_i * lhs_strides[d]; - rhs_i += dim_i * rhs_strides[d]; - i_tmp /= lhs_dims[d]; - } - - // figure out if we are using lhs or rhs for this `i` - size_t i_along_axis = i_tmp % (lhs_dims[axis] + rhs_dims[axis]); - i_tmp /= (lhs_dims[axis] + rhs_dims[axis]); - - // striding lhs & rhs along the rest of the axes - for (int d = axis - 1; d >= 0;d--) { - size_t dim_i = i_tmp % lhs_dims[d]; - lhs_i += dim_i * lhs_strides[d]; - rhs_i += dim_i * rhs_strides[d]; - i_tmp /= lhs_dims[d]; - } - - if (i_along_axis < lhs_dims[axis]) { - atomicAdd(grad_lhs + lhs_i + i_along_axis * lhs_strides[axis], grad_out[i]); - } else { - atomicAdd(grad_rhs + rhs_i + (i_along_axis - lhs_dims[axis]) * rhs_strides[axis], grad_out[i]); + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { + // out_dims will be (..., lhs_dims[ax] + rhs_dims[ax], ...) 
+ + // striding lhs & rhs up to the concat'd axis + size_t i_tmp = i; + size_t lhs_i = 0; + size_t rhs_i = 0; + for (int d = num_dims - 1; d > axis; d--) { + size_t dim_i = i_tmp % lhs_dims[d]; + lhs_i += dim_i * lhs_strides[d]; + rhs_i += dim_i * rhs_strides[d]; + i_tmp /= lhs_dims[d]; + } + + // figure out if we are using lhs or rhs for this `i` + size_t i_along_axis = i_tmp % (lhs_dims[axis] + rhs_dims[axis]); + i_tmp /= (lhs_dims[axis] + rhs_dims[axis]); + + // striding lhs & rhs along the rest of the axes + for (int d = axis - 1; d >= 0;d--) { + size_t dim_i = i_tmp % lhs_dims[d]; + lhs_i += dim_i * lhs_strides[d]; + rhs_i += dim_i * rhs_strides[d]; + i_tmp /= lhs_dims[d]; + } + + if (i_along_axis < lhs_dims[axis]) { + atomicAdd(grad_lhs + lhs_i + i_along_axis * lhs_strides[axis], grad_out[i]); + } else { + atomicAdd(grad_rhs + rhs_i + (i_along_axis - lhs_dims[axis]) * rhs_strides[axis], grad_out[i]); + } } } "; diff --git a/src/tensor_ops/conv2d/conv2d.cu b/src/tensor_ops/conv2d/conv2d.cu index 26ba88d23..bb14e53c6 100644 --- a/src/tensor_ops/conv2d/conv2d.cu +++ b/src/tensor_ops/conv2d/conv2d.cu @@ -22,33 +22,31 @@ __device__ void unfold_input_into_patches( const size_t *strides, // 4d image strides T *patches // 6d (Batch, Groups * Channels, KernelSize, KernelSize, HeightOut, WidthOut) ) { - unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i >= op.batch * op.groups * op.chan_in * op.h_out * op.w_out) { - return; - } - - unsigned int idx = i; - const size_t ow = idx % op.w_out; - idx /= op.w_out; - const size_t oh = idx % op.h_out; - idx /= op.h_out; - const size_t c = idx % (op.chan_in * op.groups); - idx /= (op.chan_in * op.groups); - const size_t b = idx % op.batch; - - image += b * strides[0] + c * strides[1]; - patches += oh * op.w_out + ow; - patches += c * (op.kernel * op.kernel * op.h_out * op.w_out); - patches += b * (op.groups * op.chan_in * op.kernel * op.kernel * op.h_out * op.w_out); - - T zero = 0.0; - - for (int k1 = 0;k1 < op.kernel;k1++) { - const size_t y = oh * op.stride + op.dilation * k1 - op.padding; - for (int k2 = 0;k2 < op.kernel;k2++) { - const size_t x = ow * op.stride + op.dilation * k2 - op.padding; - *patches = (y >= op.h_in || x >= op.w_in) ? zero : image[y * strides[2] + x * strides[3]]; - patches += op.h_out * op.w_out; + const size_t n = op.batch * op.groups * op.chan_in * op.h_out * op.w_out; + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) { + unsigned int idx = i; + const size_t ow = idx % op.w_out; + idx /= op.w_out; + const size_t oh = idx % op.h_out; + idx /= op.h_out; + const size_t c = idx % (op.chan_in * op.groups); + idx /= (op.chan_in * op.groups); + const size_t b = idx % op.batch; + + const T *image_i = image + b * strides[0] + c * strides[1]; + T *patches_i = patches + oh * op.w_out + ow; + patches_i += c * (op.kernel * op.kernel * op.h_out * op.w_out); + patches_i += b * (op.groups * op.chan_in * op.kernel * op.kernel * op.h_out * op.w_out); + + T zero = 0.0; + + for (int k1 = 0;k1 < op.kernel;k1++) { + const size_t y = oh * op.stride + op.dilation * k1 - op.padding; + for (int k2 = 0;k2 < op.kernel;k2++) { + const size_t x = ow * op.stride + op.dilation * k2 - op.padding; + *patches_i = (y >= op.h_in || x >= op.w_in) ? 
zero : image[y * strides[2] + x * strides[3]]; + patches_i += op.h_out * op.w_out; + } } } } @@ -59,40 +57,38 @@ __device__ void unfold_output_into_patches( const T *image_out, // 4d (Batch, ChanOut, HeightOut, WidthOut) T *patches // 6d (Batch, ChanOut, KernelSize, KernelSize, HeightIn, WidthIn) ) { - const unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i >= op.batch * op.chan_out * op.h_in * op.w_in) { - return; - } - - unsigned int idx = i; - const size_t x = idx % op.w_in; - idx /= op.w_in; - const size_t y = idx % op.h_in; - idx /= op.h_in; - const size_t o = idx % op.chan_out; - idx /= op.chan_out; - const size_t b = idx % op.batch; - - image_out += b * (op.chan_out * op.h_out * op.w_out) + o * (op.h_out * op.w_out); - patches += y * op.w_in + x; - patches += o * (op.kernel * op.kernel * op.h_in * op.w_in); - patches += b * (op.chan_out * op.kernel * op.kernel * op.h_in * op.w_in); - - T zero = 0.0; - - for (int k1 = 0;k1 < op.kernel;k1++) { - const size_t oh_ks = y + op.padding; - const size_t oh_s = oh_ks - op.dilation * k1; - const size_t oh = oh_s / op.stride; - const bool k1_invalid = (oh_ks < op.dilation * k1 || oh_s % op.stride != 0 || oh >= op.h_out); - for (int k2 = 0;k2 < op.kernel;k2++) { - const size_t ow_ks = x + op.padding; - const size_t ow_s = ow_ks - op.dilation * k2; - const size_t ow = ow_s / op.stride; - - const bool invalid = k1_invalid || (ow_ks < op.dilation * k2 || ow_s % op.stride != 0 || ow >= op.w_out); - *patches = invalid ? zero : image_out[oh * op.w_out + ow]; - patches += op.h_in * op.w_in; + const size_t n = op.batch * op.chan_out * op.h_in * op.w_in; + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) { + unsigned int idx = i; + const size_t x = idx % op.w_in; + idx /= op.w_in; + const size_t y = idx % op.h_in; + idx /= op.h_in; + const size_t o = idx % op.chan_out; + idx /= op.chan_out; + const size_t b = idx % op.batch; + + const T *image_i = image_out + b * (op.chan_out * op.h_out * op.w_out) + o * (op.h_out * op.w_out); + T *patches_i = patches + y * op.w_in + x; + patches_i += o * (op.kernel * op.kernel * op.h_in * op.w_in); + patches_i += b * (op.chan_out * op.kernel * op.kernel * op.h_in * op.w_in); + + T zero = 0.0; + + for (int k1 = 0;k1 < op.kernel;k1++) { + const size_t oh_ks = y + op.padding; + const size_t oh_s = oh_ks - op.dilation * k1; + const size_t oh = oh_s / op.stride; + const bool k1_invalid = (oh_ks < op.dilation * k1 || oh_s % op.stride != 0 || oh >= op.h_out); + for (int k2 = 0;k2 < op.kernel;k2++) { + const size_t ow_ks = x + op.padding; + const size_t ow_s = ow_ks - op.dilation * k2; + const size_t ow = ow_s / op.stride; + + const bool invalid = k1_invalid || (ow_ks < op.dilation * k2 || ow_s % op.stride != 0 || ow >= op.w_out); + *patches_i = invalid ? 
zero : image_out[oh * op.w_out + ow]; + patches_i += op.h_in * op.w_in; + } } } } @@ -104,31 +100,29 @@ __device__ void transpose_filters( const size_t *strides, // 4d filters strides T *filters_tr // 5d (Groups, ChanIn, ChanOut/Groups, KernelSize, KernelSize) ) { - unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i >= op.chan_in * op.chan_out * op.kernel * op.kernel) { - return; - } - + const size_t n = op.chan_in * op.chan_out * op.kernel * op.kernel; const size_t o_per_g = op.chan_out / op.groups; - unsigned int idx = i; - const size_t k2 = idx % op.kernel; - idx /= op.kernel; - const size_t k1 = idx % op.kernel; - idx /= op.kernel; - const size_t c = idx % op.chan_in; - idx /= op.chan_in; - const size_t o = idx % op.chan_out; - const size_t og = o % o_per_g; - const size_t g = o / o_per_g; - - auto i_no = o * strides[0] + c * strides[1] + k1 * strides[2] + k2 * strides[3]; - filters_tr += k2; - filters_tr += k1 * op.kernel; - filters_tr += og * (op.kernel * op.kernel); - filters_tr += c * (o_per_g * op.kernel * op.kernel); - filters_tr += g * (op.chan_in * o_per_g * op.kernel * op.kernel); - *filters_tr = filters[i_no]; + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) { + unsigned int idx = i; + const size_t k2 = idx % op.kernel; + idx /= op.kernel; + const size_t k1 = idx % op.kernel; + idx /= op.kernel; + const size_t c = idx % op.chan_in; + idx /= op.chan_in; + const size_t o = idx % op.chan_out; + const size_t og = o % o_per_g; + const size_t g = o / o_per_g; + + auto i_no = o * strides[0] + c * strides[1] + k1 * strides[2] + k2 * strides[3]; + T *filters_tr_i = filters_tr + k2; + filters_tr_i += k1 * op.kernel; + filters_tr_i += og * (op.kernel * op.kernel); + filters_tr_i += c * (o_per_g * op.kernel * op.kernel); + filters_tr_i += g * (op.chan_in * o_per_g * op.kernel * op.kernel); + *filters_tr_i = filters[i_no]; + } } template @@ -138,41 +132,37 @@ __device__ void sum_transposed_filters( T *filters, // 4d (ChanOut, ChanIn, KernelSize, KernelSize) const size_t *strides // 4d filter strides ) { - unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; - auto numel = op.chan_out * op.chan_in * op.kernel * op.kernel; - if (i >= numel) { - return; - } - + const size_t n = op.chan_out * op.chan_in * op.kernel * op.kernel; const size_t o_per_g = op.chan_out / op.groups; - unsigned int idx = i; - const size_t k2 = idx % op.kernel; - idx /= op.kernel; - const size_t k1 = idx % op.kernel; - idx /= op.kernel; - const size_t c = idx % op.chan_in; - idx /= op.chan_in; - const size_t o = idx % op.chan_out; - const size_t og = o % o_per_g; - const size_t g = o / o_per_g; - - auto i_tr = c * (op.chan_out * op.kernel * op.kernel) + o * (op.kernel * op.kernel) + k1 * (op.kernel) + k2; - auto i_no = o * strides[0] + c * strides[1] + k1 * strides[2] + k2 * strides[3]; - - filters_tr += k2; - filters_tr += k1 * op.kernel; - filters_tr += og * (op.kernel * op.kernel); - filters_tr += c * (o_per_g * op.kernel * op.kernel); - filters_tr += g * (op.chan_in * o_per_g * op.kernel * op.kernel); - - T tmp = 0.0; - for (int b = 0; b < op.batch; b++) { - tmp += *filters_tr; - filters_tr += numel; + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) { + unsigned int idx = i; + const size_t k2 = idx % op.kernel; + idx /= op.kernel; + const size_t k1 = idx % op.kernel; + idx /= op.kernel; + const size_t c = idx % op.chan_in; + idx /= op.chan_in; + const size_t o = idx % op.chan_out; + const size_t og = o 
% o_per_g; + const size_t g = o / o_per_g; + + auto i_no = o * strides[0] + c * strides[1] + k1 * strides[2] + k2 * strides[3]; + + const T *filters_tr_i = filters_tr + k2; + filters_tr_i += k1 * op.kernel; + filters_tr_i += og * (op.kernel * op.kernel); + filters_tr_i += c * (o_per_g * op.kernel * op.kernel); + filters_tr_i += g * (op.chan_in * o_per_g * op.kernel * op.kernel); + + T tmp = 0.0; + for (int b = 0; b < op.batch; b++) { + tmp += *filters_tr_i; + filters_tr_i += n; + } + + filters[i_no] += tmp; } - - filters[i_no] += tmp; } #define CONV_OP(TYPENAME, UNFOLD_INPUT, UNFOLD_OUTPUT, TR_FILTERS, SUM_TR_FILTERS) \ diff --git a/src/tensor_ops/conv2d/mod.rs b/src/tensor_ops/conv2d/mod.rs index 8f00c99bf..2c34b8b9c 100644 --- a/src/tensor_ops/conv2d/mod.rs +++ b/src/tensor_ops/conv2d/mod.rs @@ -166,8 +166,17 @@ impl impl TryConv2D for ( - Tensor<(>::Output, H, W), E, D, T>, - Tensor<(OutChan, InpChan, Kernel, Kernel), E, D>, + Tensor<(InpChan, H, W), E, D, T>, + Tensor< + ( + OutChan, + >::Output, + Kernel, + Kernel, + ), + E, + D, + >, ) where InpChan: Dim, @@ -182,8 +191,8 @@ where E: Dtype, D: Conv2DKernel + crate::tensor_ops::reshape_to::ReshapeKernel, T: Tape, - InpChan: std::ops::Mul, - >::Output: Dim, + InpChan: std::ops::Div, + >::Output: Dim, (H, Kernel): TryConv2D, (W, Kernel): TryConv2D, <(H, Kernel) as TryConv2D>::Convolved: Dim, @@ -220,8 +229,17 @@ where impl TryConv2D for ( - Tensor<(Batch, >::Output, H, W), E, D, T>, - Tensor<(OutChan, InpChan, Kernel, Kernel), E, D>, + Tensor<(Batch, InpChan, H, W), E, D, T>, + Tensor< + ( + OutChan, + >::Output, + Kernel, + Kernel, + ), + E, + D, + >, ) where InpChan: Dim, @@ -237,8 +255,8 @@ where E: Dtype, D: Conv2DKernel, T: Tape, - InpChan: std::ops::Mul, - >::Output: Dim, + InpChan: std::ops::Div, + >::Output: Dim, (H, Kernel): TryConv2D, (W, Kernel): TryConv2D, <(H, Kernel) as TryConv2D>::Convolved: Dim, diff --git a/src/tensor_ops/dropout/dropout.cu b/src/tensor_ops/dropout/dropout.cu index 47089efed..96172bbc2 100644 --- a/src/tensor_ops/dropout/dropout.cu +++ b/src/tensor_ops/dropout/dropout.cu @@ -8,14 +8,12 @@ extern "C" __global__ void FWD( \ const bool *mask, \ TYPENAME *out \ ) { \ - unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; \ - if (i >= numel) { \ - return; \ - } \ TYPENAME zero = 0.0; \ TYPENAME one = 1.0; \ - TYPENAME scalar = mask[i] ? zero : (one / (one - prob)); \ - out[i] = inp[i] * scalar; \ + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \ + TYPENAME scalar = mask[i] ? zero : (one / (one - prob)); \ + out[i] = inp[i] * scalar; \ + } \ } \ extern "C" __global__ void BWD( \ const TYPENAME prob, \ @@ -24,13 +22,11 @@ extern "C" __global__ void BWD( \ TYPENAME *grad_inp, \ const TYPENAME *grad_out \ ) { \ - unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; \ - if (i >= numel) { \ - return; \ - } \ TYPENAME zero = 0.0; \ TYPENAME one = 1.0; \ - grad_inp[i] += mask[i] ? zero : (grad_out[i] / (one - prob)); \ + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \ + grad_inp[i] += mask[i] ? 
zero : (grad_out[i] / (one - prob)); \ + } \ } DROPOUT(__half, dropout_fwd_f16, dropout_bwd_f16); diff --git a/src/tensor_ops/pool2d/pool2d.cu b/src/tensor_ops/pool2d/pool2d.cu index a84f0f320..0a58a36af 100644 --- a/src/tensor_ops/pool2d/pool2d.cu +++ b/src/tensor_ops/pool2d/pool2d.cu @@ -79,40 +79,37 @@ __device__ void pool2d_fwd( const T *inp, // 4d (Batch, Channels, Height, Width) T *out // 4d (Batch, Channels, HeightOut, WidthOut) ) { - unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; const size_t numel = op.batch * op.chan * op.h_out * op.w_out; - if (i >= numel) { - return; - } - - unsigned int idx = i; - const size_t ow = idx % op.w_out; - idx /= op.w_out; - const size_t oh = idx % op.h_out; - idx /= op.h_out; - const size_t c = idx % op.chan; - idx /= op.chan; - const size_t b = idx % op.batch; - idx /= op.batch; + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { + unsigned int idx = i; + const size_t ow = idx % op.w_out; + idx /= op.w_out; + const size_t oh = idx % op.h_out; + idx /= op.h_out; + const size_t c = idx % op.chan; + idx /= op.chan; + const size_t b = idx % op.batch; + idx /= op.batch; + + T tmp = init(op); + for(size_t k1 = 0; k1 < op.kernel; k1++) { + for (size_t k2 = 0; k2 < op.kernel; k2++) { + const size_t y_plus_p = oh * op.stride + op.dilation * k1; + if (y_plus_p < op.padding) { continue; } + const size_t y = y_plus_p - op.padding; + if (y >= op.h_in) { continue; } + const size_t x_plus_p = ow * op.stride + op.dilation * k2; + if (x_plus_p < op.padding) { continue; } + const size_t x = x_plus_p - op.padding; + if (x >= op.w_in) { continue; } - T tmp = init(op); - for(size_t k1 = 0; k1 < op.kernel; k1++) { - for (size_t k2 = 0; k2 < op.kernel; k2++) { - const size_t y_plus_p = oh * op.stride + op.dilation * k1; - if (y_plus_p < op.padding) { continue; } - const size_t y = y_plus_p - op.padding; - if (y >= op.h_in) { continue; } - const size_t x_plus_p = ow * op.stride + op.dilation * k2; - if (x_plus_p < op.padding) { continue; } - const size_t x = x_plus_p - op.padding; - if (x >= op.w_in) { continue; } - - auto inp_i = b * inp_strides[0] + c * inp_strides[1] + y * inp_strides[2] + x * inp_strides[3]; - tmp = accum(op, tmp, inp[inp_i]); + auto inp_i = b * inp_strides[0] + c * inp_strides[1] + y * inp_strides[2] + x * inp_strides[3]; + tmp = accum(op, tmp, inp[inp_i]); + } } + + out[i] = normalize(op, tmp, op.kernel * op.kernel); } - - out[i] = normalize(op, tmp, op.kernel * op.kernel); } template @@ -125,46 +122,43 @@ __device__ void pool2d_bwd( const T *out, // 4d (Batch, Channels, HeightOut, WidthOut) const T *grad_out ) { - unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; const size_t numel = op.batch * op.chan * op.h_in * op.w_in; - if (i >= numel) { - return; - } - - unsigned int idx = i; - const size_t x = idx % op.w_in; - idx /= op.w_in; - const size_t y = idx % op.h_in; - idx /= op.h_in; - const size_t c = idx % op.chan; - idx /= op.chan; - const size_t b = idx % op.batch; - idx /= op.batch; - - const T inp_v = inp[i]; - - T tmp = 0.0; - for(size_t k1 = 0; k1 < op.kernel; k1++) { - for (size_t k2 = 0; k2 < op.kernel; k2++) { - size_t oh = y + op.padding; - if (oh < op.dilation * k1) { continue; } - oh -= op.dilation * k1; - if (oh % op.stride != 0) { continue; } - oh /= op.stride; - if (oh >= op.h_out) { continue; } - - size_t ow = x + op.padding; - if (ow < op.dilation * k2) { continue; } - ow -= op.dilation * k2; - if (ow % op.stride != 0) { continue; } - ow /= op.stride; - if (ow >= 
op.w_out) { continue; } - - auto out_i = b * out_strides[0] + c * out_strides[1] + oh * out_strides[2] + ow * out_strides[3]; - tmp += filter(op, grad_out[out_i], out[out_i], inp_v); + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { + unsigned int idx = i; + const size_t x = idx % op.w_in; + idx /= op.w_in; + const size_t y = idx % op.h_in; + idx /= op.h_in; + const size_t c = idx % op.chan; + idx /= op.chan; + const size_t b = idx % op.batch; + idx /= op.batch; + + const T inp_v = inp[i]; + + T tmp = 0.0; + for(size_t k1 = 0; k1 < op.kernel; k1++) { + for (size_t k2 = 0; k2 < op.kernel; k2++) { + size_t oh = y + op.padding; + if (oh < op.dilation * k1) { continue; } + oh -= op.dilation * k1; + if (oh % op.stride != 0) { continue; } + oh /= op.stride; + if (oh >= op.h_out) { continue; } + + size_t ow = x + op.padding; + if (ow < op.dilation * k2) { continue; } + ow -= op.dilation * k2; + if (ow % op.stride != 0) { continue; } + ow /= op.stride; + if (ow >= op.w_out) { continue; } + + auto out_i = b * out_strides[0] + c * out_strides[1] + oh * out_strides[2] + ow * out_strides[3]; + tmp += filter(op, grad_out[out_i], out[out_i], inp_v); + } } + grad_inp[i] += normalize(op, tmp, op.kernel * op.kernel); } - grad_inp[i] += normalize(op, tmp, op.kernel * op.kernel); } #define POOL_OP(TYPENAME, fwd, bwd) \ diff --git a/src/tensor_ops/reshape_to/cuda_kernel.rs b/src/tensor_ops/reshape_to/cuda_kernel.rs index b2e9661fd..607434f36 100644 --- a/src/tensor_ops/reshape_to/cuda_kernel.rs +++ b/src/tensor_ops/reshape_to/cuda_kernel.rs @@ -119,20 +119,15 @@ extern \"C\" __global__ void reshape_fwd( const $T *inp, $T *out ) { - unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i >= numel) { - return; - } - const size_t *inp_dims = info; const size_t *inp_strides = info + inp_num_dims; const size_t *out_dims = info + 2 * inp_num_dims; const size_t *out_strides = info + 2 * inp_num_dims + out_num_dims; - - unsigned int inp_i = get_strided_index(i, inp_num_dims, inp_dims, inp_strides); - unsigned int out_i = get_strided_index(i, out_num_dims, out_dims, out_strides); - - out[out_i] = inp[inp_i]; + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { + unsigned int inp_i = get_strided_index(i, inp_num_dims, inp_dims, inp_strides); + unsigned int out_i = get_strided_index(i, out_num_dims, out_dims, out_strides); + out[out_i] = inp[inp_i]; + } } "; @@ -153,19 +148,14 @@ extern \"C\" __global__ void reshape_bwd( $T *grad_inp, const $T *grad_out ) { - unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i >= numel) { - return; - } - const size_t *inp_dims = info; const size_t *inp_strides = info + inp_num_dims; const size_t *out_dims = info + 2 * inp_num_dims; const size_t *out_strides = info + 2 * inp_num_dims + out_num_dims; - - unsigned int inp_i = get_strided_index(i, inp_num_dims, inp_dims, inp_strides); - unsigned int out_i = get_strided_index(i, out_num_dims, out_dims, out_strides); - - atomicAdd(grad_inp + inp_i, grad_out[out_i]); + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { + unsigned int inp_i = get_strided_index(i, inp_num_dims, inp_dims, inp_strides); + unsigned int out_i = get_strided_index(i, out_num_dims, out_dims, out_strides); + atomicAdd(grad_inp + inp_i, grad_out[out_i]); + } } "; diff --git a/src/tensor_ops/roll/roll.cu b/src/tensor_ops/roll/roll.cu index 375e73b32..6c5cec1f4 100644 --- 
a/src/tensor_ops/roll/roll.cu +++ b/src/tensor_ops/roll/roll.cu @@ -16,32 +16,29 @@ __device__ void roll_fwd( const T *inp, T *out ) { - unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i >= numel) { - return; + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { + auto idx = i; + const T item = inp[get_strided_index(i, num_dims, dims, inp_strides)]; + + size_t out_i = 0; + for (int d = num_dims - 1; d > op.axis; d--) { + size_t dim_i = idx % dims[d]; + out_i += dim_i * out_strides[d]; + idx /= dims[d]; + } + + size_t dim_i = idx % dims[op.axis]; + size_t new_dim_i = (dim_i + op.amount) % dims[op.axis]; + out_i += new_dim_i * out_strides[op.axis]; + idx /= dims[op.axis]; + + for (int d = op.axis - 1; d >= 0;d--) { + size_t dim_i = idx % dims[d]; + out_i += dim_i * out_strides[d]; + idx /= dims[d]; + } + out[out_i] = item; } - - const T item = inp[get_strided_index(i, num_dims, dims, inp_strides)]; - - size_t out_i = 0; - for (int d = num_dims - 1; d > op.axis; d--) { - size_t dim_i = i % dims[d]; - out_i += dim_i * out_strides[d]; - i /= dims[d]; - } - - size_t dim_i = i % dims[op.axis]; - size_t new_dim_i = (dim_i + op.amount) % dims[op.axis]; - out_i += new_dim_i * out_strides[op.axis]; - i /= dims[op.axis]; - - for (int d = op.axis - 1; d >= 0;d--) { - size_t dim_i = i % dims[d]; - out_i += dim_i * out_strides[d]; - i /= dims[d]; - } - - out[out_i] = item; } template @@ -55,32 +52,30 @@ __device__ void roll_bwd( T *grad_inp, const T *grad_out ) { - unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i >= numel) { - return; + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { + auto idx = i; + const size_t inp_i = get_strided_index(i, num_dims, dims, inp_strides); + + size_t out_i = 0; + for (int d = num_dims - 1; d > op.axis; d--) { + size_t dim_i = idx % dims[d]; + out_i += dim_i * out_strides[d]; + idx /= dims[d]; + } + + size_t dim_i = idx % dims[op.axis]; + size_t new_dim_i = (dim_i + op.amount) % dims[op.axis]; + out_i += new_dim_i * out_strides[op.axis]; + idx /= dims[op.axis]; + + for (int d = op.axis - 1; d >= 0;d--) { + size_t dim_i = idx % dims[d]; + out_i += dim_i * out_strides[d]; + idx /= dims[d]; + } + + atomicAdd(grad_inp + inp_i, grad_out[out_i]); } - - const size_t inp_i = get_strided_index(i, num_dims, dims, inp_strides); - - size_t out_i = 0; - for (int d = num_dims - 1; d > op.axis; d--) { - size_t dim_i = i % dims[d]; - out_i += dim_i * out_strides[d]; - i /= dims[d]; - } - - size_t dim_i = i % dims[op.axis]; - size_t new_dim_i = (dim_i + op.amount) % dims[op.axis]; - out_i += new_dim_i * out_strides[op.axis]; - i /= dims[op.axis]; - - for (int d = op.axis - 1; d >= 0;d--) { - size_t dim_i = i % dims[d]; - out_i += dim_i * out_strides[d]; - i /= dims[d]; - } - - atomicAdd(grad_inp + inp_i, grad_out[out_i]); } #define ROLL(TY, FWD, BWD) \ diff --git a/src/tensor_ops/select_and_gather/gather.cu b/src/tensor_ops/select_and_gather/gather.cu index 747c033a4..56f93e28d 100644 --- a/src/tensor_ops/select_and_gather/gather.cu +++ b/src/tensor_ops/select_and_gather/gather.cu @@ -56,17 +56,12 @@ __device__ void gather_fwd( T *out, const size_t out_num_dims ) { - unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i >= numel) { - return; + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { + unsigned int out_i = i; + unsigned int inp_i = + get_gathered_index(i, inp_num_dims, inp_dims, 
inp_strides, idx, idx_num_dims, idx_dims, idx_strides, out_num_dims); + out[out_i] = inp[inp_i]; } - - unsigned int out_i = i; - unsigned int inp_i = - get_gathered_index(i, inp_num_dims, inp_dims, inp_strides, idx, idx_num_dims, idx_dims, idx_strides, out_num_dims); - - out[out_i] = inp[inp_i]; - // out[out_i] = inp_i; } template @@ -83,16 +78,12 @@ __device__ void gather_bwd( const T *grad_out, const size_t out_num_dims ) { - unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i >= numel) { - return; + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { + unsigned int out_i = i; + unsigned int inp_i = + get_gathered_index(i, inp_num_dims, inp_dims, inp_strides, idx, idx_num_dims, idx_dims, idx_strides, out_num_dims); + atomicAdd(grad_inp + inp_i, grad_out[out_i]); } - - unsigned int out_i = i; - unsigned int inp_i = - get_gathered_index(i, inp_num_dims, inp_dims, inp_strides, idx, idx_num_dims, idx_dims, idx_strides, out_num_dims); - - atomicAdd(grad_inp + inp_i, grad_out[out_i]); } #define GATHER(TYPENAME, FWD, BWD) \ diff --git a/src/tensor_ops/select_and_gather/select.cu b/src/tensor_ops/select_and_gather/select.cu index 21242c0d6..1f9f68a05 100644 --- a/src/tensor_ops/select_and_gather/select.cu +++ b/src/tensor_ops/select_and_gather/select.cu @@ -44,16 +44,12 @@ __device__ void select_fwd( const size_t *out_dims, const size_t *out_strides ) { - unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i >= numel) { - return; + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { + unsigned int out_i = get_strided_index(i, inp_num_dims - 1, out_dims, out_strides); + unsigned int inp_i = + get_selected_index(i, inp_num_dims, inp_dims, inp_strides, idx, idx_num_dims, idx_dims, idx_strides); + out[out_i] = inp[inp_i]; } - - unsigned int out_i = get_strided_index(i, inp_num_dims - 1, out_dims, out_strides); - unsigned int inp_i = - get_selected_index(i, inp_num_dims, inp_dims, inp_strides, idx, idx_num_dims, idx_dims, idx_strides); - - out[out_i] = inp[inp_i]; } template @@ -71,16 +67,12 @@ __device__ void select_bwd( const size_t *out_dims, const size_t *out_strides ) { - unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i >= numel) { - return; + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { + unsigned int out_i = get_strided_index(i, inp_num_dims - 1, out_dims, out_strides); + unsigned int inp_i = + get_selected_index(i, inp_num_dims, inp_dims, inp_strides, idx, idx_num_dims, idx_dims, idx_strides); + atomicAdd(grad_inp + inp_i, grad_out[out_i]); } - - unsigned int out_i = get_strided_index(i, inp_num_dims - 1, out_dims, out_strides); - unsigned int inp_i = - get_selected_index(i, inp_num_dims, inp_dims, inp_strides, idx, idx_num_dims, idx_dims, idx_strides); - - atomicAdd(grad_inp + inp_i, grad_out[out_i]); } #define SELECT(TYPENAME, FWD, BWD) \ diff --git a/src/tensor_ops/slice/slice.cu b/src/tensor_ops/slice/slice.cu index cb4f8325a..8538fa3f0 100644 --- a/src/tensor_ops/slice/slice.cu +++ b/src/tensor_ops/slice/slice.cu @@ -11,13 +11,10 @@ __device__ void slice_fwd( const T *inp, T *out ) { - unsigned int out_i = blockIdx.x * blockDim.x + threadIdx.x; - if (out_i >= numel) { - return; + for (unsigned int out_i = blockIdx.x * blockDim.x + threadIdx.x; out_i < numel; out_i += blockDim.x * gridDim.x) { + unsigned int inp_i = offset + get_strided_index(out_i, num_dims, dims, strides); + out[out_i] = 
inp[inp_i]; } - - unsigned int inp_i = offset + get_strided_index(out_i, num_dims, dims, strides); - out[out_i] = inp[inp_i]; } template @@ -30,14 +27,11 @@ __device__ void slice_bwd( T *grad_inp, const T *out ) { - unsigned int out_i = blockIdx.x * blockDim.x + threadIdx.x; - if (out_i >= numel) { - return; + for (unsigned int out_i = blockIdx.x * blockDim.x + threadIdx.x; out_i < numel; out_i += blockDim.x * gridDim.x) { + unsigned int inp_i = offset + get_strided_index(out_i, num_dims, dims, strides); + // TODO (maybe): use chunk_sum to speed this up + atomicAdd(grad_inp + inp_i, out[out_i]); } - - unsigned int inp_i = offset + get_strided_index(out_i, num_dims, dims, strides); - // TODO (maybe): use chunk_sum to speed this up - atomicAdd(grad_inp + inp_i, out[out_i]); } #define SLICE_FWD(TYPENAME, FN) \ diff --git a/src/tensor_ops/stack/cuda_kernel.rs b/src/tensor_ops/stack/cuda_kernel.rs index ce9ef31b9..d6a783fbc 100644 --- a/src/tensor_ops/stack/cuda_kernel.rs +++ b/src/tensor_ops/stack/cuda_kernel.rs @@ -84,7 +84,8 @@ impl super::StackKernel for Cuda { const BWD_KERNEL: &str = " #include \"cuda_fp16.h\" extern \"C\" __global__ void stack_bwd(const size_t numel, const $Ty *inp, $Ty *out) { - unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i < numel) { out[i] += inp[i]; } + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { + out[i] += inp[i]; + } } "; diff --git a/src/tensor_ops/to_dtype/cuda_kernel.rs b/src/tensor_ops/to_dtype/cuda_kernel.rs index b5597a422..5e6d232c6 100644 --- a/src/tensor_ops/to_dtype/cuda_kernel.rs +++ b/src/tensor_ops/to_dtype/cuda_kernel.rs @@ -16,8 +16,9 @@ typedef int intptr_t; #endif #include \"cuda_fp16.h\" extern \"C\" __global__ void kernel(const size_t n, const $Src *inp, $Dst *out) { - unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i < n) { out[i] = inp[i]; } + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) { + out[i] = inp[i]; + } }"; impl super::ToDtypeKernel for Cuda { diff --git a/src/tensor_ops/upscale2d/upscale2d.cu b/src/tensor_ops/upscale2d/upscale2d.cu index 2953960ea..c3197388c 100644 --- a/src/tensor_ops/upscale2d/upscale2d.cu +++ b/src/tensor_ops/upscale2d/upscale2d.cu @@ -16,29 +16,26 @@ __device__ void nearest_upscale2d_fwd( const T *inp, // 4d (Batch, Channels, Height, Width) T *out // 4d (Batch, Channels, HeightOut, WidthOut) ) { - unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i >= op.batch * op.chan * op.h_out * op.w_out) { - return; - } - + const size_t n = op.batch * op.chan * op.h_out * op.w_out; float h_scale = static_cast(op.h_in)/static_cast(op.h_out); float w_scale = static_cast(op.w_in)/static_cast(op.w_out); - - unsigned int idx = i; - const size_t ow = idx % op.w_out; - idx /= op.w_out; - const size_t oh = idx % op.h_out; - idx /= op.h_out; - const size_t c = idx % op.chan; - idx /= op.chan; - const size_t b = idx % op.batch; - - size_t ih = min(static_cast(h_scale * oh), op.h_in - 1); - size_t iw = min(static_cast(w_scale * ow), op.w_in - 1); - - size_t inp_i = b * inp_strides[0] + c * inp_strides[1] + ih * inp_strides[2] + iw * inp_strides[3]; + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) { + unsigned int idx = i; + const size_t ow = idx % op.w_out; + idx /= op.w_out; + const size_t oh = idx % op.h_out; + idx /= op.h_out; + const size_t c = idx % op.chan; + idx /= op.chan; + const size_t b = idx % op.batch; + + size_t ih = 
min(static_cast(h_scale * oh), op.h_in - 1); + size_t iw = min(static_cast(w_scale * ow), op.w_in - 1); - out[i] = inp[inp_i]; + size_t inp_i = b * inp_strides[0] + c * inp_strides[1] + ih * inp_strides[2] + iw * inp_strides[3]; + + out[i] = inp[inp_i]; + } } template @@ -48,28 +45,26 @@ __device__ void nearest_upscale2d_bwd( T *grad_inp, const T *grad_out // 4d (Batch, Channels, HeightOut, WidthOut) ) { - unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i >= op.batch * op.chan * op.h_out * op.w_out) { - return; - } - + const size_t n = op.batch * op.chan * op.h_out * op.w_out; float h_scale = static_cast(op.h_in)/static_cast(op.h_out); float w_scale = static_cast(op.w_in)/static_cast(op.w_out); - unsigned int idx = i; - const size_t ow = idx % op.w_out; - idx /= op.w_out; - const size_t oh = idx % op.h_out; - idx /= op.h_out; - const size_t c = idx % op.chan; - idx /= op.chan; - const size_t b = idx % op.batch; - - size_t ih = min(static_cast(h_scale * oh), op.h_in - 1); - size_t iw = min(static_cast(w_scale * ow), op.w_in - 1); - - size_t inp_i = b * inp_strides[0] + c * inp_strides[1] + ih * inp_strides[2] + iw * inp_strides[3]; - atomicAdd(grad_inp + inp_i, grad_out[i]); + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) { + unsigned int idx = i; + const size_t ow = idx % op.w_out; + idx /= op.w_out; + const size_t oh = idx % op.h_out; + idx /= op.h_out; + const size_t c = idx % op.chan; + idx /= op.chan; + const size_t b = idx % op.batch; + + size_t ih = min(static_cast(h_scale * oh), op.h_in - 1); + size_t iw = min(static_cast(w_scale * ow), op.w_in - 1); + + size_t inp_i = b * inp_strides[0] + c * inp_strides[1] + ih * inp_strides[2] + iw * inp_strides[3]; + atomicAdd(grad_inp + inp_i, grad_out[i]); + } } template @@ -79,41 +74,37 @@ __device__ void bilinear_upscale2d_fwd( const T *inp, // 4d (Batch, Channels, Height, Width) T *out // 4d (Batch, Channels, HeightOut, WidthOut) ) { - unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i >= op.batch * op.chan * op.h_out * op.w_out) { - return; - } - + const size_t n = op.batch * op.chan * op.h_out * op.w_out; float h_scale = ((float)op.h_in-1)/(op.h_out-1); float w_scale = ((float)op.w_in-1)/(op.w_out-1); - - unsigned int idx = i; - const size_t ow = idx % op.w_out; - idx /= op.w_out; - const size_t oh = idx % op.h_out; - idx /= op.h_out; - const size_t c = idx % op.chan; - idx /= op.chan; - const size_t b = idx % op.batch; - - size_t y0 = min(static_cast(h_scale * oh), op.h_in - 1); - size_t y1 = min(y0 + 1, op.h_in - 1); - size_t x0 = min(static_cast(w_scale * ow), op.w_in - 1); - size_t x1 = min(x0 + 1, op.w_in - 1); - - T hs = h_scale * oh - y0; - T ws = w_scale * ow - x0; - - inp += b * inp_strides[0] + c * inp_strides[1]; - T one = 1.0; - - T ll = inp[y0 * inp_strides[2] + x0 * inp_strides[3]] * (one-hs) * (one-ws); - T lh = inp[y0 * inp_strides[2] + x1 * inp_strides[3]] * (one-hs) * ws; - T hl = inp[y1 * inp_strides[2] + x0 * inp_strides[3]] * hs * (one-ws); - T hh = inp[y1 * inp_strides[2] + x1 * inp_strides[3]] * hs * ws; - - out[i] = ll + lh + hl + hh; + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) { + unsigned int idx = i; + const size_t ow = idx % op.w_out; + idx /= op.w_out; + const size_t oh = idx % op.h_out; + idx /= op.h_out; + const size_t c = idx % op.chan; + idx /= op.chan; + const size_t b = idx % op.batch; + + size_t y0 = min(static_cast(h_scale * oh), op.h_in - 1); + size_t y1 = min(y0 + 1, 
op.h_in - 1); + size_t x0 = min(static_cast(w_scale * ow), op.w_in - 1); + size_t x1 = min(x0 + 1, op.w_in - 1); + + T hs = h_scale * oh - y0; + T ws = w_scale * ow - x0; + + const T *inp_i = inp + b * inp_strides[0] + c * inp_strides[1]; + + T ll = inp_i[y0 * inp_strides[2] + x0 * inp_strides[3]] * (one-hs) * (one-ws); + T lh = inp_i[y0 * inp_strides[2] + x1 * inp_strides[3]] * (one-hs) * ws; + T hl = inp_i[y1 * inp_strides[2] + x0 * inp_strides[3]] * hs * (one-ws); + T hh = inp_i[y1 * inp_strides[2] + x1 * inp_strides[3]] * hs * ws; + + out[i] = ll + lh + hl + hh; + } } template @@ -123,41 +114,37 @@ __device__ void bilinear_upscale2d_bwd( T *grad_inp, // 4d (Batch, Channels, Height, Width) const T *grad_out // 4d (Batch, Channels, HeightOut, WidthOut) ) { - unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i >= op.batch * op.chan * op.h_out * op.w_out) { - return; - } - + const size_t n = op.batch * op.chan * op.h_out * op.w_out; float h_scale = ((float)op.h_in-1)/(op.h_out-1); float w_scale = ((float)op.w_in-1)/(op.w_out-1); - - unsigned int idx = i; - const size_t ow = idx % op.w_out; - idx /= op.w_out; - const size_t oh = idx % op.h_out; - idx /= op.h_out; - const size_t c = idx % op.chan; - idx /= op.chan; - const size_t b = idx % op.batch; - - size_t y0 = min(static_cast(h_scale * oh), op.h_in - 1); - size_t y1 = min(y0 + 1, op.h_in - 1); - size_t x0 = min(static_cast(w_scale * ow), op.w_in - 1); - size_t x1 = min(x0 + 1, op.w_in - 1); - - T hs = h_scale * oh - y0; - T ws = w_scale * ow - x0; - - T go = grad_out[i]; - - grad_inp += b * inp_strides[0] + c * inp_strides[1]; - const T one = 1.0; + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) { + unsigned int idx = i; + const size_t ow = idx % op.w_out; + idx /= op.w_out; + const size_t oh = idx % op.h_out; + idx /= op.h_out; + const size_t c = idx % op.chan; + idx /= op.chan; + const size_t b = idx % op.batch; + + size_t y0 = min(static_cast(h_scale * oh), op.h_in - 1); + size_t y1 = min(y0 + 1, op.h_in - 1); + size_t x0 = min(static_cast(w_scale * ow), op.w_in - 1); + size_t x1 = min(x0 + 1, op.w_in - 1); + + T hs = h_scale * oh - y0; + T ws = w_scale * ow - x0; + + T go = grad_out[i]; + + T *grad_inp_i = grad_inp + b * inp_strides[0] + c * inp_strides[1]; - atomicAdd(grad_inp + y0 * inp_strides[2] + x0 * inp_strides[3], go * (one-hs) * (one-ws)); - atomicAdd(grad_inp + y0 * inp_strides[2] + x1 * inp_strides[3], go * (one-hs) * ws); - atomicAdd(grad_inp + y1 * inp_strides[2] + x0 * inp_strides[3], go * hs * (one-ws)); - atomicAdd(grad_inp + y1 * inp_strides[2] + x1 * inp_strides[3], go * hs * ws); + atomicAdd(grad_inp_i + y0 * inp_strides[2] + x0 * inp_strides[3], go * (one-hs) * (one-ws)); + atomicAdd(grad_inp_i + y0 * inp_strides[2] + x1 * inp_strides[3], go * (one-hs) * ws); + atomicAdd(grad_inp_i + y1 * inp_strides[2] + x0 * inp_strides[3], go * hs * (one-ws)); + atomicAdd(grad_inp_i + y1 * inp_strides[2] + x1 * inp_strides[3], go * hs * ws); + } } #define UPSCALE_OP(TYPENAME, fwd, bwd, fwd_FN, bwd_FN) \ diff --git a/src/tensor_ops/utilities/binary_op_macros.cuh b/src/tensor_ops/utilities/binary_op_macros.cuh index 9878c3239..b79a4d81b 100644 --- a/src/tensor_ops/utilities/binary_op_macros.cuh +++ b/src/tensor_ops/utilities/binary_op_macros.cuh @@ -10,32 +10,25 @@ extern "C" __global__ void FORWARD( \ const TYPENAME *rhs, \ TYPENAME *out \ ) { \ - unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; \ - if (i >= numel) { \ - return; \ - } \ -\ const size_t *dims = 
diff --git a/src/tensor_ops/utilities/binary_op_macros.cuh b/src/tensor_ops/utilities/binary_op_macros.cuh
index 9878c3239..b79a4d81b 100644
--- a/src/tensor_ops/utilities/binary_op_macros.cuh
+++ b/src/tensor_ops/utilities/binary_op_macros.cuh
@@ -10,32 +10,25 @@ extern "C" __global__ void FORWARD( \
     const TYPENAME *rhs, \
     TYPENAME *out \
 ) { \
-    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; \
-    if (i >= numel) { \
-        return; \
-    } \
-\
     const size_t *dims = info; \
     const size_t *lhs_strides = info + num_dims; \
     const size_t *rhs_strides = info + 2 * num_dims; \
-\
-    unsigned int tmp_i = i; \
-    unsigned int lhs_i = 0; \
-    unsigned int rhs_i = 0; \
-    for (int d = num_dims - 1; d >= 0; d--) { \
-        unsigned int i_dim = tmp_i % dims[d]; \
-        lhs_i += i_dim * lhs_strides[d]; \
-        rhs_i += i_dim * rhs_strides[d]; \
-        tmp_i /= dims[d]; \
+    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
+        unsigned int tmp_i = i; \
+        unsigned int lhs_i = 0; \
+        unsigned int rhs_i = 0; \
+        for (int d = num_dims - 1; d >= 0; d--) { \
+            unsigned int i_dim = tmp_i % dims[d]; \
+            lhs_i += i_dim * lhs_strides[d]; \
+            rhs_i += i_dim * rhs_strides[d]; \
+            tmp_i /= dims[d]; \
+        } \
+        TYPENAME x = lhs ? lhs[lhs_i] : out[i]; \
+        TYPENAME y = rhs ? rhs[rhs_i] : out[i]; \
+        TYPENAME fx; \
+        FUNC\
+        out[i] = fx; \
     } \
-\
-    TYPENAME x = lhs ? lhs[lhs_i] : out[i]; \
-    TYPENAME y = rhs ? rhs[rhs_i] : out[i]; \
-    TYPENAME fx; \
-\
-    FUNC\
-\
-    out[i] = fx; \
 } \
 \
 extern "C" __global__ void BACKWARD_LHS( \
@@ -49,33 +42,26 @@ extern "C" __global__ void BACKWARD_LHS( \
     const TYPENAME *rhs, \
     const TYPENAME *grad_out \
 ) { \
-    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; \
-    if (i >= numel) { \
-        return; \
-    } \
-\
     const size_t *dims = info + 0 * num_dims; \
     const size_t *out_strides = info + 1 * num_dims; \
     const size_t *rhs_strides = info + 2 * num_dims; \
-\
-    unsigned int tmp_i = i; \
-    unsigned int out_i = 0; \
-    unsigned int rhs_i = 0; \
-    for (int d = num_dims - 1; d >= 0; d--) { \
-        unsigned int i_dim = tmp_i % dims[d]; \
-        out_i += i_dim * out_strides[d]; \
-        rhs_i += i_dim * rhs_strides[d]; \
-        tmp_i /= dims[d]; \
-    } \
-    unsigned int lhs_i = i / chunk_len; \
     TYPENAME zero = 0.0; \
-    TYPENAME x = lhs ? lhs[lhs_i] : zero; \
-    TYPENAME y = rhs ? rhs[rhs_i] : zero; \
-    TYPENAME go = grad_out[out_i]; \
-\
-    TYPENAME dfdx = (DFDX); \
-\
-    chunk_sum(chunk_len, dfdx * go, grad_lhs); \
+    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
+        unsigned int tmp_i = i; \
+        unsigned int out_i = 0; \
+        unsigned int rhs_i = 0; \
+        for (int d = num_dims - 1; d >= 0; d--) { \
+            unsigned int i_dim = tmp_i % dims[d]; \
+            out_i += i_dim * out_strides[d]; \
+            rhs_i += i_dim * rhs_strides[d]; \
+            tmp_i /= dims[d]; \
+        } \
+        TYPENAME x = lhs ? lhs[i / chunk_len] : zero; \
+        TYPENAME y = rhs ? rhs[rhs_i] : zero; \
+        TYPENAME go = grad_out[out_i]; \
+        TYPENAME dfdx = (DFDX); \
+        chunk_sum(chunk_len, dfdx * go, grad_lhs); \
+    } \
 } \
 \
 extern "C" __global__ void BACKWARD_RHS( \
@@ -89,33 +75,26 @@ extern "C" __global__ void BACKWARD_RHS( \
     const size_t chunk_len, \
     const TYPENAME *grad_out \
 ) { \
-    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; \
-    if (i >= numel) { \
-        return; \
-    } \
     const size_t *dims = info + 3 * num_dims; \
     const size_t *out_strides = info + 4 * num_dims; \
     const size_t *lhs_strides = info + 5 * num_dims; \
-\
-    unsigned int tmp_i = i; \
-    unsigned int lhs_i = 0; \
-    unsigned int out_i = 0; \
-    for (int d = num_dims - 1; d >= 0; d--) { \
-        unsigned int i_dim = tmp_i % dims[d]; \
-        lhs_i += i_dim * lhs_strides[d]; \
-        out_i += i_dim * out_strides[d]; \
-        tmp_i /= dims[d]; \
-    } \
-    unsigned int rhs_i = i / chunk_len; \
-\
     TYPENAME zero = 0.0; \
-    TYPENAME x = lhs ? lhs[lhs_i] : zero; \
-    TYPENAME y = rhs ? rhs[rhs_i] : zero; \
-    TYPENAME go = grad_out[out_i]; \
-\
-    TYPENAME dfdy = (DFDY); \
-\
-    chunk_sum(chunk_len, dfdy * go, grad_rhs); \
+    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
+        unsigned int tmp_i = i; \
+        unsigned int lhs_i = 0; \
+        unsigned int out_i = 0; \
+        for (int d = num_dims - 1; d >= 0; d--) { \
+            unsigned int i_dim = tmp_i % dims[d]; \
+            lhs_i += i_dim * lhs_strides[d]; \
+            out_i += i_dim * out_strides[d]; \
+            tmp_i /= dims[d]; \
+        } \
+        TYPENAME x = lhs ? lhs[lhs_i] : zero; \
+        TYPENAME y = rhs ? rhs[i / chunk_len] : zero; \
+        TYPENAME go = grad_out[out_i]; \
+        TYPENAME dfdy = (DFDY); \
+        chunk_sum(chunk_len, dfdy * go, grad_rhs); \
+    } \
 }
 
 #define BINARY_OP(TYPENAME, FORWARD, BACKWARD_LHS, BACKWARD_RHS, OP_STRUCT, FUNC, DFDX, DFDY) \
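The binary-op macros above recover per-operand offsets from a flat output index by peeling off one coordinate per dimension and scaling it by that operand's stride; a broadcast dimension simply carries stride 0. A minimal sketch of that index math, with a hypothetical strided_index helper standing in for the inlined loop:

// Map a flat, contiguous output index onto a (possibly broadcast) operand laid
// out with the given dims/strides. Mirrors the inner `for (int d = ...)` loop
// in the macros; the helper name is illustrative only.
__device__ size_t strided_index(unsigned int flat_i, const size_t *dims, const size_t *strides, const size_t num_dims) {
    size_t offset = 0;
    for (int d = num_dims - 1; d >= 0; d--) {
        offset += (flat_i % dims[d]) * strides[d]; // stride 0 => broadcast dimension
        flat_i /= dims[d];
    }
    return offset;
}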
diff --git a/src/tensor_ops/utilities/compatibility.cuh b/src/tensor_ops/utilities/compatibility.cuh
index 80bcc900e..7da36244a 100644
--- a/src/tensor_ops/utilities/compatibility.cuh
+++ b/src/tensor_ops/utilities/compatibility.cuh
@@ -5,14 +5,14 @@
 
 // FIXME: the minimum compute capabilities are just guesses since the table is not specific enough
 
-#if __CUDA_ARCH__ < 800
-__device__ __forceinline__ __half __hmax(__half a, __half b) {
-    return __float2half(fmaxf(__half2float(a), __half2float(b)));
-}
-__device__ __forceinline__ __half __hmin(__half a, __half b) {
-    return __float2half(fminf(__half2float(a), __half2float(b)));
-}
-#endif
+// #if __CUDA_ARCH__ < 600
+// __device__ __forceinline__ __half __hmax(__half a, __half b) {
+//     return __float2half(fmaxf(__half2float(a), __half2float(b)));
+// }
+// __device__ __forceinline__ __half __hmin(__half a, __half b) {
+//     return __float2half(fminf(__half2float(a), __half2float(b)));
+// }
+// #endif
 
 #if __CUDA_ARCH__ < 800
 __device__ __forceinline__ __half __hmax_nan(__half a, __half b) {
diff --git a/src/tensor_ops/utilities/cuda_utils.cuh b/src/tensor_ops/utilities/cuda_utils.cuh
index 5915107f8..e81571bee 100644
--- a/src/tensor_ops/utilities/cuda_utils.cuh
+++ b/src/tensor_ops/utilities/cuda_utils.cuh
@@ -94,30 +94,15 @@ __device__ void chunk_sum(
     }
 }
 
-extern "C" __global__ void fill_with_f16(__half *buf, __half value, const size_t numel) {
-    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
-    if (i >= numel) {
-        return;
-    }
-    buf[i] = value;
-}
-
-extern "C" __global__ void fill_with_f32(float *buf, float value, const size_t numel) {
-    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
-    if (i >= numel) {
-        return;
-    }
-    buf[i] = value;
-}
-
-extern "C" __global__ void fill_with_f64(double *buf, double value, const size_t numel) {
-    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
-    if (i >= numel) {
-        return;
+template<typename T>
+__device__ void fill_with(T *buf, T value, const size_t numel) {
+    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
+        buf[i] = value;
     }
-    buf[i] = value;
 }
-
+extern "C" __global__ void fill_with_f16(__half *buf, __half value, const size_t numel) { fill_with(buf, value, numel); }
+extern "C" __global__ void fill_with_f32(float *buf, float value, const size_t numel) { fill_with(buf, value, numel); }
+extern "C" __global__ void fill_with_f64(double *buf, double value, const size_t numel) { fill_with(buf, value, numel); }
 
 __device__ __forceinline__ bool isnang(float a) { return isnan(a); }
 __device__ __forceinline__ bool isnang(double a) { return isnan(a); }
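The fill_with change above shares one templated __device__ body across dtypes while keeping thin extern "C" __global__ wrappers, so each kernel still exports an unmangled symbol the host can launch by name. A sketch of the same pattern applied to a hypothetical axpy helper (names are illustrative, not part of this patch):

// Shared grid-stride body; the per-dtype wrappers below only forward their arguments.
template<typename T>
__device__ void axpy(T *y, const T *x, T alpha, const size_t numel) {
    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
        y[i] += alpha * x[i];
    }
}
extern "C" __global__ void axpy_f32(float *y, const float *x, float alpha, const size_t numel) { axpy(y, x, alpha, numel); }
extern "C" __global__ void axpy_f64(double *y, const double *x, double alpha, const size_t numel) { axpy(y, x, alpha, numel); }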
diff --git a/src/tensor_ops/utilities/unary_op_macros.cuh b/src/tensor_ops/utilities/unary_op_macros.cuh
index 4fafc7a13..4671f51a5 100644
--- a/src/tensor_ops/utilities/unary_op_macros.cuh
+++ b/src/tensor_ops/utilities/unary_op_macros.cuh
@@ -7,12 +7,10 @@ extern "C" __global__ void FORWARD( \
     const TYPENAME *inp, \
     TYPENAME *out \
 ) { \
-    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; \
-    if (i >= numel) { \
-        return; \
+    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
+        TYPENAME x = inp ? inp[i] : out[i]; \
+        FUNC \
     } \
-    TYPENAME x = inp ? inp[i] : out[i]; \
-    FUNC \
 } \
 \
 extern "C" __global__ void BACKWARD( \
@@ -23,17 +21,14 @@ extern "C" __global__ void BACKWARD( \
     const TYPENAME *out, \
     const TYPENAME *grad_out \
 ) { \
-    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; \
-    if (i >= numel) { \
-        return; \
-    } \
-    \
     TYPENAME zero = 0.0; \
-    TYPENAME x = inp ? inp[i] : zero; \
-    TYPENAME y = out ? out[i] : zero; \
-    TYPENAME dx; \
-    DERIVATIVE \
-    grad_inp[i] += dx * grad_out[i]; \
+    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
+        TYPENAME x = inp ? inp[i] : zero; \
+        TYPENAME y = out ? out[i] : zero; \
+        TYPENAME dx; \
+        DERIVATIVE \
+        grad_inp[i] += dx * grad_out[i]; \
+    } \
 }
 
 #define UNARY_OP(TYPENAME, FORWARD, BACKWARD, OP_STRUCT, FUNC, DERIVATIVE) \