Commit
Merge branch 'main' into conv-transpose-dilated
coreylowman committed Jul 3, 2023
2 parents 71fcada + 780b347 commit bbe3b92
Showing 30 changed files with 715 additions and 786 deletions.
8 changes: 4 additions & 4 deletions build.rs
@@ -123,7 +123,7 @@ mod cuda {
.arg("--query-gpu=compute_cap")
.arg("--format=csv")
.output()
.unwrap();
.expect("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.");
let out = std::str::from_utf8(&out.stdout).unwrap();
let mut lines = out.lines();
assert_eq!(lines.next().unwrap(), "compute_cap");
@@ -136,7 +136,7 @@ mod cuda {
let out = std::process::Command::new("nvcc")
.arg("--list-gpu-code")
.output()
.unwrap();
.expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
let out = std::str::from_utf8(&out.stdout).unwrap();

let out = out.lines().collect::<Vec<&str>>();
@@ -188,12 +188,12 @@ mod cuda {
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped())
.spawn()
.unwrap()
.expect("nvcc failed to start. Ensure that you have CUDA installed and that `nvcc` is in your PATH.")
})
.collect::<Vec<_>>();

for (kernel_path, child) in kernel_paths.iter().zip(children.into_iter()) {
let output = child.wait_with_output().unwrap();
let output = child.wait_with_output().expect("nvcc failed to run. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
assert!(
output.status.success(),
"nvcc error while compiling {kernel_path:?}:\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
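The build.rs changes above replace bare `.unwrap()` calls with `.expect()` messages that tell the user what actually went wrong when the CUDA tooling is missing. A minimal standalone sketch of that pattern follows; it is not the actual build script, and it assumes `nvidia-smi` is the tool being probed:

use std::process::Command;

fn main() {
    // Prefer `.expect()` with an actionable message over `.unwrap()` when
    // shelling out to external tools: a missing CUDA install then fails
    // with a hint instead of a bare panic.
    let out = Command::new("nvidia-smi")
        .arg("--query-gpu=compute_cap")
        .arg("--format=csv")
        .output()
        .expect("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.");
    let stdout = std::str::from_utf8(&out.stdout).expect("`nvidia-smi` emitted non-UTF-8 output");
    println!("{stdout}");
}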
52 changes: 43 additions & 9 deletions src/nn/conv.rs
@@ -32,6 +32,7 @@ impl<
where
E: Dtype,
D: Device<E>,
Const<{ I / G }>: Sized,
Conv2D<I, O, K, S, P, L, G, E, D>: BuildModule<D, E>,
{
type Built = Conv2D<I, O, K, S, P, L, G, E, D>;
@@ -62,6 +63,7 @@ where
///
/// See [conv animations](https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md) for helpful
/// visualization of all of these parameters.

#[derive(Debug, Clone)]
pub struct Conv2D<
const IN_CHAN: usize,
@@ -73,8 +75,10 @@ pub struct Conv2D<
const GROUPS: usize,
E: Dtype,
D: Storage<E>,
> {
pub weight: Tensor<Rank4<OUT_CHAN, IN_CHAN, KERNEL_SIZE, KERNEL_SIZE>, E, D>,
> where
Const<{ IN_CHAN / GROUPS }>: Sized,
{
pub weight: Tensor<Rank4<OUT_CHAN, { IN_CHAN / GROUPS }, KERNEL_SIZE, KERNEL_SIZE>, E, D>,
}

impl<
@@ -89,6 +93,7 @@ impl<
D,
> TensorCollection<E, D> for Conv2D<I, O, K, S, P, L, G, E, D>
where
Const<{ I / G }>: Sized,
E: Dtype + Float + SampleUniform,
D: Device<E>,
{
@@ -112,9 +117,8 @@ where
}
}

#[cfg(feature = "nightly")]
impl<
const C: usize,
const I: usize,
const O: usize,
const K: usize,
const S: usize,
@@ -124,19 +128,21 @@ impl<
E,
D,
Img,
> Module<Img> for Conv2D<C, O, K, S, P, L, G, E, D>
> Module<Img> for Conv2D<I, O, K, S, P, L, G, E, D>
where
Const<{ I / G }>: Sized,
E: Dtype,
D: Device<E>,
(Img, Tensor<Rank4<O, C, K, K>, E, D>): TryConv2D<Const<S>, Const<P>, Const<L>, Const<G>>,
(Img, Tensor<Rank4<O, { I / G }, K, K>, E, D>):
TryConv2D<Const<S>, Const<P>, Const<L>, Const<G>>,
{
type Output = <(Img, Tensor<Rank4<O, C, K, K>, E, D>) as TryConv2D<
type Output = <(Img, Tensor<Rank4<O, { I / G }, K, K>, E, D>) as TryConv2D<
Const<S>,
Const<P>,
Const<L>,
Const<G>,
>>::Convolved;
type Error = <(Img, Tensor<Rank4<O, C, K, K>, E, D>) as TryConv2D<
type Error = <(Img, Tensor<Rank4<O, { I / G }, K, K>, E, D>) as TryConv2D<
Const<S>,
Const<P>,
Const<L>,
@@ -159,10 +165,11 @@ impl<
E: Dtype,
D: Storage<E>,
> NonMutableModule for Conv2D<I, O, K, S, P, L, G, E, D>
where
Const<{ I / G }>: Sized,
{
}

#[cfg(feature = "nightly")]
#[cfg(test)]
mod tests {
use crate::{
@@ -189,6 +196,33 @@ mod tests {
let _: Tensor<Rank3<2, 6, 6>, _, _, _> = dev.build_module::<Conv2D<3, 2, 3, 2, 2>, TestDtype>().forward(x.clone());
}

#[test]
fn test_grouped_forward_sizes() {
let dev: TestDevice = Default::default();

let x = dev.zeros::<Rank3<16, 10, 10>>();

let m = dev.build_module::<Conv2D<16, 32, 3, 1, 0, 1, 1>, TestDtype>();
let _: Tensor<Rank4<32, 16, 3, 3>, _, _> = m.weight;
let _: Tensor<Rank3<32, 8, 8>, _, _> = m.forward(x.clone());

let m = dev.build_module::<Conv2D<16, 32, 3, 1, 0, 1, 2>, TestDtype>();
let _: Tensor<Rank4<32, 8, 3, 3>, _, _> = m.weight;
let _: Tensor<Rank3<32, 8, 8>, _, _> = m.forward(x.clone());

let m = dev.build_module::<Conv2D<16, 32, 3, 1, 0, 1, 4>, TestDtype>();
let _: Tensor<Rank4<32, 4, 3, 3>, _, _> = m.weight;
let _: Tensor<Rank3<32, 8, 8>, _, _> = m.forward(x.clone());

let m = dev.build_module::<Conv2D<16, 32, 3, 1, 0, 1, 8>, TestDtype>();
let _: Tensor<Rank4<32, 2, 3, 3>, _, _> = m.weight;
let _: Tensor<Rank3<32, 8, 8>, _, _> = m.forward(x.clone());

let m = dev.build_module::<Conv2D<16, 32, 3, 1, 0, 1, 16>, TestDtype>();
let _: Tensor<Rank4<32, 1, 3, 3>, _, _> = m.weight;
let _: Tensor<Rank3<32, 8, 8>, _, _> = m.forward(x);
}

#[rustfmt::skip]
#[test]
fn test_forward_4d_sizes() {
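The conv.rs changes above generalize the Conv2D weight tensor to grouped convolution: its second dimension shrinks from IN_CHAN to IN_CHAN / GROUPS, guarded by the `Const<{ IN_CHAN / GROUPS }>: Sized` bound (hence the nightly gate added to src/nn/mod.rs below). A small freestanding sketch of the shape arithmetic that the new test_grouped_forward_sizes test exercises; the helper function is hypothetical and not part of the crate:

// Hypothetical helper, not part of dfdx: computes the weight shape
// [OUT_CHAN, IN_CHAN / GROUPS, KERNEL_SIZE, KERNEL_SIZE] used by the
// grouped Conv2D above.
fn grouped_conv_weight_shape(in_chan: usize, out_chan: usize, kernel: usize, groups: usize) -> [usize; 4] {
    assert_eq!(in_chan % groups, 0, "IN_CHAN must be divisible by GROUPS");
    [out_chan, in_chan / groups, kernel, kernel]
}

fn main() {
    // Mirrors Conv2D<16, 32, 3, 1, 0, 1, 4> from the test above: a [32, 4, 3, 3] weight.
    assert_eq!(grouped_conv_weight_shape(16, 32, 3, 4), [32, 4, 3, 3]);
    // With GROUPS = 1 the shape reduces to the ungrouped [32, 16, 3, 3] case.
    assert_eq!(grouped_conv_weight_shape(16, 32, 3, 1), [32, 16, 3, 3]);
}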
1 change: 1 addition & 0 deletions src/nn/mod.rs
@@ -188,6 +188,7 @@ mod add_into;
mod batchnorm1d;
mod batchnorm2d;
mod bias2d;
#[cfg(feature = "nightly")]
mod conv;
mod convtrans;
mod dropout;
51 changes: 23 additions & 28 deletions src/optim/adam/adam.cu
@@ -25,42 +25,37 @@ __device__ void adam_update(
T* moment2,
const T* grad
) {
unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;

if (i >= numel) {
return;
}

T beta1 = cfg.beta1;
T beta2 = cfg.beta2;
T lr = cfg.lr;
T weight_decay = cfg.weight_decay;
T eps = cfg.eps;

T p = param[i];
T g = grad[i];
T m = moment1[i];
T v = moment2[i];
T one = 1.0;
T t = t_int;

if (cfg.weight_decay_type == L2) {
g += weight_decay * p;
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
T p = param[i];
T g = grad[i];
T m = moment1[i];
T v = moment2[i];

if (cfg.weight_decay_type == L2) {
g += weight_decay * p;
}

m = m * beta1 + g * (one - beta1);
v = v * beta2 + g * g * (one - beta2);
T m_hat = m * one / (one - powg(beta1, t));
T v_hat = v * one / (one - powg(beta2, t));
g = lr * m_hat / (sqrtg(v_hat) + eps);

if (cfg.weight_decay_type == Decoupled) {
g += (weight_decay * lr) * p;
}

moment1[i] = m;
moment2[i] = v;
param[i] -= g;
}

m = m * beta1 + g * (one - beta1);
v = v * beta2 + g * g * (one - beta2);
T m_hat = m * one / (one - powg(beta1, t));
T v_hat = v * one / (one - powg(beta2, t));
g = lr * m_hat / (sqrtg(v_hat) + eps);

if (cfg.weight_decay_type == Decoupled) {
g += (weight_decay * lr) * p;
}

moment1[i] = m;
moment2[i] = v;
param[i] -= g;
}

#define ADAM(TYPENAME, FN) \
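The Adam kernel above (and rmsprop.cu and sgd.cu below) is refactored from a one-thread-per-element bounds check into a grid-stride loop, so a launch with a fixed grid size covers tensors with more elements than launched threads. A minimal standalone CUDA sketch of the pattern, separate from the dfdx kernels:

#include <cstdio>

// Each thread starts at its global index and advances by the total number of
// launched threads (blockDim.x * gridDim.x), so any grid size covers `numel`.
__global__ void scale(float *data, const size_t numel, const float factor) {
    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
        data[i] *= factor;
    }
}

int main() {
    const size_t numel = 1 << 20;
    float *data;
    cudaMallocManaged(&data, numel * sizeof(float));
    for (size_t i = 0; i < numel; ++i) data[i] = 1.0f;
    // 128 blocks of 256 threads is far fewer than one thread per element;
    // the grid-stride loop still visits every index.
    scale<<<128, 256>>>(data, numel, 2.0f);
    cudaDeviceSynchronize();
    printf("data[0] = %f\n", data[0]); // prints 2.000000
    cudaFree(data);
    return 0;
}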
85 changes: 41 additions & 44 deletions src/optim/rmsprop/rmsprop.cu
@@ -27,58 +27,55 @@ __device__ void rmsprop_update(
T* grad_avg,
const T* grad
) {
unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;

if (i >= numel) {
return;
}

T lr = cfg.lr;
T alpha = cfg.alpha;
T eps = cfg.eps;
T momentum_ = cfg.momentum;
T weight_decay = cfg.weight_decay;

T p = param[i];
T g = grad[i];
T s_avg = square_avg[i];
T g_avg = grad_avg[i];
T m = momentum[i];
T one = 1.0;

if (cfg.weight_decay_type == L2) {
g += weight_decay * p;
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
T p = param[i];
T g = grad[i];
T s_avg = square_avg[i];
T g_avg = grad_avg[i];
T m = momentum[i];


if (cfg.weight_decay_type == L2) {
g += weight_decay * p;
}

s_avg += (one - alpha) * (g * g - s_avg);

T avg;

if (cfg.centered) {
// ga = a * ga + (1 - a) * g
g_avg += (one - alpha) * (g - g_avg);
avg = sqrtg(s_avg - g_avg * g_avg + eps);
} else {
avg = sqrtg(s_avg + eps);
};

g /= avg;

if (cfg.has_momentum) {
m = m * momentum_ + g;
g = m * lr;
} else {
g *= lr;
}

if (cfg.weight_decay_type == Decoupled) {
g += weight_decay * lr * p;
}

square_avg[i] = s_avg;
grad_avg[i] = g_avg;
momentum[i] = m;
param[i] -= g;
}

s_avg += (one - alpha) * (g * g - s_avg);

T avg;

if (cfg.centered) {
// ga = a * ga + (1 - a) * g
g_avg += (one - alpha) * (g - g_avg);
avg = sqrtg(s_avg - g_avg * g_avg + eps);
} else {
avg = sqrtg(s_avg + eps);
};

g /= avg;

if (cfg.has_momentum) {
m = m * momentum_ + g;
g = m * lr;
} else {
g *= lr;
}

if (cfg.weight_decay_type == Decoupled) {
g += weight_decay * lr * p;
}

square_avg[i] = s_avg;
grad_avg[i] = g_avg;
momentum[i] = m;
param[i] -= g;
}

#define RMSPROP(TYPENAME, FN) \
54 changes: 25 additions & 29 deletions src/optim/sgd/sgd.cu
@@ -28,40 +28,36 @@ __device__ void sgd_update(
T* velocity,
const T* grad
) {
unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;

if (i >= numel) {
return;
}

T weight_decay = cfg.weight_decay;
T lr = cfg.lr;
T momentum = cfg.momentum;

T p = param[i];
T g = grad[i];
T v = velocity[i];

if (cfg.weight_decay_type == L2) {
g += weight_decay * p;
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
T p = param[i];
T g = grad[i];
T v = velocity[i];

if (cfg.weight_decay_type == L2) {
g += weight_decay * p;
}

if (cfg.momentum_type == Classic) {
v = g + momentum * v;
g = v * lr;
} else if (cfg.momentum_type == Nesterov) {
v = g + momentum * v;
g = (g + momentum * v) * lr;
} else {
g *= lr;
}

if (cfg.weight_decay_type == Decoupled) {
g += weight_decay * lr * p;
}

velocity[i] = v;
param[i] -= g;
}

if (cfg.momentum_type == Classic) {
v = g + momentum * v;
g = v * lr;
} else if (cfg.momentum_type == Nesterov) {
v = g + momentum * v;
g = (g + momentum * v) * lr;
} else {
g *= lr;
}

if (cfg.weight_decay_type == Decoupled) {
g += weight_decay * lr * p;
}

velocity[i] = v;
param[i] -= g;
}

#define SGD(TYPENAME, FN) \