Skip to content

Commit

Permalink
Adding comments to nn layers
Browse files Browse the repository at this point in the history
  • Loading branch information
coreylowman committed Aug 29, 2023
1 parent 1dfae77 commit 275afae
Show file tree
Hide file tree
Showing 42 changed files with 541 additions and 16 deletions.
28 changes: 25 additions & 3 deletions dfdx-nn-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ use dfdx::{
shapes::HasShape,
};

/// Mutable & Immutable forward of `Input` that produces [Module::Output].
pub trait Module<X> {
/// The type that this unit produces given `Input`.
type Output;
type Error: std::fmt::Debug;

Expand All @@ -25,6 +27,9 @@ pub trait Module<X> {
}
}

/// An error indicating that a parameter was not used in gradient
/// computation, and was therefore not present in [Gradients]
/// during an update.
#[derive(Debug)]
pub enum OptimizerUpdateError<Err> {
UnusedTensors(Vec<UniqueId>),
Expand All @@ -43,6 +48,7 @@ impl<Err: std::fmt::Display> std::fmt::Display for OptimizerUpdateError<Err> {
#[cfg(feature = "std")]
impl<Err: std::fmt::Debug + std::fmt::Display> std::error::Error for OptimizerUpdateError<Err> {}

/// Something that can update both tensors and a [UpdateParams]. At minimum [Optimizer::update_tensor()] must be implemented.
pub trait Optimizer<M, E: Dtype, D: Device<E>>: Sized {
fn update_tensor<S: Shape>(
&mut self,
Expand Down Expand Up @@ -71,6 +77,7 @@ pub trait Optimizer<M, E: Dtype, D: Device<E>>: Sized {
}
}

/// Something that can be constructed on a device as a certain dtype.
pub trait BuildOnDevice<E: Dtype, D: Device<E>>: Clone {
type Built: Clone + std::fmt::Debug;
fn build_on_device(&self, device: &D) -> Self::Built {
Expand All @@ -79,13 +86,15 @@ pub trait BuildOnDevice<E: Dtype, D: Device<E>>: Clone {
fn try_build_on_device(&self, device: &D) -> Result<Self::Built, D::Err>;
}

/// Something that can have all of its parameters reset to a specific state (may be random or not random).
pub trait ResetParams<E: Dtype, D: Device<E>> {
fn reset_params(&mut self) {
self.try_reset_params().unwrap()
}
fn try_reset_params(&mut self) -> Result<(), D::Err>;
}

/// Something that can have it's params updated with an [Optimizer] and a set of [Gradients].
pub trait UpdateParams<E: Dtype, D: Device<E>> {
fn update_params<M, Optim: Optimizer<M, E, D>>(
&mut self,
Expand All @@ -104,6 +113,7 @@ pub trait UpdateParams<E: Dtype, D: Device<E>> {
) -> Result<(), D::Err>;
}

/// Something that can allocate a [Gradients] object or zero out the [Gradients] object.
pub trait ZeroGrads<E: Dtype, D: Device<E>> {
fn zero_grads(&self, grads: &mut Gradients<E, D>) {
self.try_zero_grads(grads).unwrap()
Expand All @@ -121,6 +131,7 @@ pub trait ZeroGrads<E: Dtype, D: Device<E>> {
}
}

/// Something that can be saved to a .safetensors file.
pub trait SaveSafeTensors {
fn save_safetensors<P: AsRef<std::path::Path>>(
&self,
Expand Down Expand Up @@ -149,6 +160,7 @@ pub trait SaveSafeTensors {
);
}

/// Something that can be loaded from a .safetensors file.
pub trait LoadSafeTensors {
fn load_safetensors<P: AsRef<std::path::Path>>(
&mut self,
Expand Down Expand Up @@ -241,16 +253,26 @@ unit_safetensors!(i64);
unit_safetensors!(isize);
unit_safetensors!(usize);

/// Extension method that calls [BuildOnDevice] and then [ResetParams].
pub trait BuildModuleExt<M>: Sized {
fn build_module_ext<E: Dtype>(&self, m: M) -> M::Built
where
M: BuildOnDevice<E, Self>,
M::Built: ResetParams<E, Self>,
Self: Device<E>,
{
let mut module = m.build_on_device(self);
module.reset_params();
module
self.try_build_module_ext(m).unwrap()
}

fn try_build_module_ext<E: Dtype>(&self, m: M) -> Result<M::Built, Self::Err>
where
M: BuildOnDevice<E, Self>,
M::Built: ResetParams<E, Self>,
Self: Device<E>,
{
let mut module = m.try_build_on_device(self)?;
module.try_reset_params()?;
Ok(module)
}
}
impl<D, M> BuildModuleExt<M> for D {}
1 change: 1 addition & 0 deletions dfdx-nn/src/abs.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use dfdx::prelude::{Device, Dtype, Shape, Tape, Tensor};

/// Calls [dfdx::tensor_ops::abs()]
#[derive(Default, Debug, Clone, Copy, crate::CustomModule)]
pub struct Abs;
impl<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>> crate::Module<Tensor<S, E, D, T>> for Abs {
Expand Down
18 changes: 18 additions & 0 deletions dfdx-nn/src/add_into.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,24 @@ use dfdx::{
tensor_ops::{Device, TryAdd},
};

/// Add inputs together into a single tensor. `T` should be a tuple
//// where every element of the tuple has the same output type
///
/// This provides a utility for networks where multiple inputs are needed
///
/// Generics:
/// - `T` the module to add the outputs together of
///
/// # Examples
/// ```rust
/// # use dfdx::prelude::*;
/// # let dev: Cpu = Default::default();
/// type Model = AddInto<(Linear<2, 5>, Linear<3, 5>)>;
/// let model = dev.build_module::<Model, f32>();
/// let a: Tensor<Rank1<2>, f32, _> = dev.zeros();
/// let b: Tensor<Rank1<3>, f32, _> = dev.zeros();
/// let _: Tensor<Rank1<5>, f32, _> = model.forward((a, b));
/// ```
#[derive(
Debug, Default, Clone, ResetParams, ZeroGrads, UpdateParams, LoadSafeTensors, SaveSafeTensors,
)]
Expand Down
42 changes: 42 additions & 0 deletions dfdx-nn/src/batch_norm1d.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,43 @@
use crate::{LoadSafeTensors, SaveSafeTensors, UpdateParams, ZeroGrads};
use dfdx::prelude::*;

/// Batch normalization for sequences as described in
/// [Batch Normalization: Accelerating Deep Network Training
/// by Reducing Internal Covariate Shift](https://arxiv.org/abs/1502.03167)
///
/// Generics:
///
/// - `C` the size of the dimension to reduce. Both for 2d tensors (of the form <BATCH_SIZE, C>)
/// as well as 3d tensors (of the form <BATCH_SIZE, C, SEQUENCE_LENGTH>), this is the 1st dimension.
///
/// # Training vs Inference
///
/// BatchNorm1D supports the following cases (see sections below for more details):
/// 1. **Training**: [crate::Module::forward_mut()] and [OwnedTape] on the input tensor
/// 2. **Inference**: [crate::Module::forward()] and [NoneTape] on the input tensor.
///
/// Examples:
/// ```rust
/// # use dfdx::prelude::*;
/// # let dev: Cpu = Default::default();
/// type Model = BatchNorm1D<3>;
/// let bn = dev.build_module::<Model, f32>();
/// let _ = bn.forward(dev.zeros::<Rank2<4, 3>>());
/// let _ = bn.forward(dev.zeros::<Rank3<4, 3, 2>>());
/// ```
///
/// ### Training
/// - Running statistics: updated with momentum
/// - Normalization: calculated using batch stats
///
/// ### Inference
/// - Running statistics: **not** updated
/// - Normalization: calculated using running stats
#[derive(Default, Clone, Copy, Debug)]
#[repr(transparent)]
pub struct BatchNorm1DConfig<C: Dim>(pub C);

/// Compile time sugar alias around [BatchNorm1DConfig]
pub type BatchNorm1DConstConfig<const C: usize> = BatchNorm1DConfig<Const<C>>;

impl<C: Dim, E: Dtype, D: Device<E>> crate::BuildOnDevice<E, D> for BatchNorm1DConfig<C> {
Expand All @@ -21,20 +54,29 @@ impl<C: Dim, E: Dtype, D: Device<E>> crate::BuildOnDevice<E, D> for BatchNorm1DC
}
}

/// See [BatchNorm1DConfig].
#[derive(Clone, Debug, UpdateParams, ZeroGrads, SaveSafeTensors, LoadSafeTensors)]
pub struct BatchNorm1D<C: Dim, Elem: Dtype, Dev: Device<Elem>> {
/// Scale for affine transform. Defaults to 1.0
#[param]
#[serialize]
pub scale: Tensor<(C,), Elem, Dev>,
/// Bias for affine transform. Defaults to 0.0
#[param]
#[serialize]
pub bias: Tensor<(C,), Elem, Dev>,
/// Spatial mean that is updated during training. Defaults to 0.0
#[serialize]
pub running_mean: Tensor<(C,), Elem, Dev>,
/// Spatial variance that is updated during training. Defaults to 1.0
#[serialize]
pub running_var: Tensor<(C,), Elem, Dev>,
/// Added to variance before taking sqrt for numerical stability. Defaults to 1e-5
#[serialize]
pub epsilon: f64,
/// Controls exponential moving average of running stats. Defaults to 0.1
///
/// `running_stat * (1.0 - momentum) + stat * momentum`.
#[serialize]
pub momentum: f64,
}
Expand Down
36 changes: 36 additions & 0 deletions dfdx-nn/src/batch_norm2d.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,45 @@
use crate::{LoadSafeTensors, SaveSafeTensors, UpdateParams, ZeroGrads};
use dfdx::prelude::*;

/// Batch normalization for images as described in
/// [Batch Normalization: Accelerating Deep Network Training
/// by Reducing Internal Covariate Shift](https://arxiv.org/abs/1502.03167)
///
/// Generics:
///
/// - `C` the size of the spatial dimension to reduce. For 3d tensors this is the 0th
/// dimension. For 4d tensors, this is the 1st dimension.
///
/// # Training vs Inference
///
/// BatchNorm2D supports the following cases (see sections below for more details):
/// 1. **Training**: [ModuleMut] and [OwnedTape] on the input tensor
/// 2. **Inference**: [Module] and [NoneTape] on the input tensor.
///
/// *NOTE: ModuleMut/NoneTape, and Module/OwnedTape will fail to compile.*
///
/// Examples:
/// ```rust
/// # use dfdx::prelude::*;
/// # let dev: Cpu = Default::default();
/// type Model = BatchNorm2D<3>;
/// let bn = dev.build_module::<Model, f32>();
/// let _ = bn.forward(dev.zeros::<Rank3<3, 2, 2>>());
/// let _ = bn.forward(dev.zeros::<Rank4<4, 3, 2, 2>>());
/// ```
///
/// ### Training
/// - Running statistics: updated with momentum
/// - Normalization: calculated using batch stats
///
/// ### Inference
/// - Running statistics: **not** updated
/// - Normalization: calculated using running stats
#[derive(Default, Clone, Copy, Debug)]
#[repr(transparent)]
pub struct BatchNorm2DConfig<C: Dim>(pub C);

/// Compile time sugar alias around [BatchNorm2DConfig]
pub type BatchNorm2DConstConfig<const C: usize> = BatchNorm2DConfig<Const<C>>;

impl<C: Dim, E: Dtype, D: Device<E>> crate::BuildOnDevice<E, D> for BatchNorm2DConfig<C> {
Expand All @@ -21,6 +56,7 @@ impl<C: Dim, E: Dtype, D: Device<E>> crate::BuildOnDevice<E, D> for BatchNorm2DC
}
}

/// See [BatchNorm2DConfig]
#[derive(Clone, Debug, UpdateParams, ZeroGrads, SaveSafeTensors, LoadSafeTensors)]
pub struct BatchNorm2D<C: Dim, Elem: Dtype, Dev: Device<Elem>> {
#[param]
Expand Down
20 changes: 20 additions & 0 deletions dfdx-nn/src/bias1d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,29 @@ use dfdx::{

use crate::*;

/// Adds a learnable 1d bias to 2d and 3d inputs.
///
/// Example:
/// ```rust
/// # use dfdx::prelude::*;
/// # let dev: Cpu = Default::default();
/// const NUM_CHANS: usize = 5;
/// type Model = Bias1D<NUM_CHANS>;
/// let model = dev.build_module::<Model, f32>();
///
/// // 3d input
/// let x: Tensor<Rank3<NUM_CHANS, 2, 3>, f32, _> = dev.sample_normal();
/// model.forward(x);
///
/// // 4d input
/// let x: Tensor<Rank4<10, NUM_CHANS, 2, 3>, f32, _> = dev.sample_normal();
/// model.forward(x);
/// ```
#[derive(Default, Clone, Copy, Debug)]
#[repr(transparent)]
pub struct Bias1DConfig<I: Dim>(pub I);

/// Compile time sugar alias around [Bias1DConfig]
pub type Bias1DConstConfig<const I: usize> = Bias1DConfig<Const<I>>;

impl<I: Dim, E: Dtype, D: Device<E>> BuildOnDevice<E, D> for Bias1DConfig<I> {
Expand All @@ -21,6 +40,7 @@ impl<I: Dim, E: Dtype, D: Device<E>> BuildOnDevice<E, D> for Bias1DConfig<I> {
}
}

/// See [Bias1DConfig]
#[derive(Clone, Debug, UpdateParams, ZeroGrads, SaveSafeTensors, LoadSafeTensors)]
pub struct Bias1D<I: Dim, Elem: Dtype, Dev: Device<Elem>> {
#[param]
Expand Down
20 changes: 20 additions & 0 deletions dfdx-nn/src/bias2d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,29 @@ use dfdx::{

use crate::*;

/// Adds a learnable 1d bias to 3d and 4d inputs.
///
/// Example:
/// ```rust
/// # use dfdx::prelude::*;
/// # let dev: Cpu = Default::default();
/// const NUM_CHANS: usize = 5;
/// type Model = Bias2D<NUM_CHANS>;
/// let model = dev.build_module::<Model, f32>();
///
/// // 3d input
/// let x: Tensor<Rank3<NUM_CHANS, 2, 3>, f32, _> = dev.sample_normal();
/// model.forward(x);
///
/// // 4d input
/// let x: Tensor<Rank4<10, NUM_CHANS, 2, 3>, f32, _> = dev.sample_normal();
/// model.forward(x);
/// ```
#[derive(Default, Clone, Copy, Debug)]
#[repr(transparent)]
pub struct Bias2DConfig<I: Dim>(pub I);

/// Compile time sugar alias around [Bias2DConfig]
pub type Bias2DConstConfig<const I: usize> = Bias2DConfig<Const<I>>;

impl<I: Dim, E: Dtype, D: Device<E>> BuildOnDevice<E, D> for Bias2DConfig<I> {
Expand All @@ -21,6 +40,7 @@ impl<I: Dim, E: Dtype, D: Device<E>> BuildOnDevice<E, D> for Bias2DConfig<I> {
}
}

/// See [Bias2DConfig]
#[derive(Clone, Debug, UpdateParams, ZeroGrads, SaveSafeTensors, LoadSafeTensors)]
pub struct Bias2D<I: Dim, Elem: Dtype, Dev: Device<Elem>> {
#[param]
Expand Down
35 changes: 35 additions & 0 deletions dfdx-nn/src/conv2d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,39 @@ use dfdx::{

use crate::*;

/// **Requires Nightly** Performs *unbiased* 2d convolutions on 3d and 4d images.
///
/// **Pytorch Equivalent**: `torch.nn.Conv2d(..., bias=False)`
///
/// Example usage:
/// ```rust
/// # use dfdx_nn::Conv2DConfig;
/// # use dfdx::shapes::Const;
/// // compile time channels/kernel
/// let m: Conv2DConfig<Const<3>, Const<5>, Const<3>> = Default::default();
/// // runtime channels/kernel
/// let m: Conv2DConfig<usize, usize, usize> = Conv2DConfig {
/// in_chan: 3,
/// out_chan: 5,
/// kernel_size: 3,
/// ..Default::default()
/// };
/// ```
///
/// To create a biased conv, combine with [crate::Bias2D].
///
/// Generics:
/// - `InChan`: The number of input channels in an image.
/// - `OutChan`: The number of channels in the output of the layer.
/// - `KernelSize`: The size of the kernel applied to both width and height of the images.
/// - `Stride`: How far to move the kernel each step. Defaults to `Const<1>`
/// - `Padding`: How much zero padding to add around the images. Defaults to `Const<0>`.
/// - `Dilation`: Controls the spacing between kernel points. Defaults to `Const<1>`.
/// - `Groups`: Controls the connections between inputs and outputs.
/// `InChan` and `OutChan` must both be divisible by `Groups`.
///
/// See [conv animations](https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md) for helpful
/// visualization of all of these parameters.
#[derive(Debug, Default, Clone, Copy)]
pub struct Conv2DConfig<
InChan: Dim,
Expand All @@ -25,6 +58,7 @@ pub struct Conv2DConfig<
pub groups: Groups,
}

/// Compile time sugar alias around [Conv2DConfig]
pub type Conv2DConstConfig<
const IN_CHAN: usize,
const OUT_CHAN: usize,
Expand Down Expand Up @@ -70,6 +104,7 @@ where
}
}

/// The module built with [Conv2DConfig]. See [Conv2DConfig] for usage.
#[derive(Debug, Clone, UpdateParams, ZeroGrads, SaveSafeTensors, LoadSafeTensors)]
pub struct Conv2D<InChan, OutChan, KernelSize, Stride, Padding, Dilation, Groups, Elem, Dev>
where
Expand Down
Loading

0 comments on commit 275afae

Please sign in to comment.