From 275afae57c2b14c747e722ff9952164d23f80aad Mon Sep 17 00:00:00 2001
From: Corey Lowman
Date: Tue, 29 Aug 2023 16:50:12 -0700
Subject: [PATCH] Adding comments to nn layers

---
 dfdx-nn-core/src/lib.rs             | 28 +++++++++++++++--
 dfdx-nn/src/abs.rs                  |  1 +
 dfdx-nn/src/add_into.rs             | 18 +++++++++++
 dfdx-nn/src/batch_norm1d.rs         | 42 +++++++++++++++++++++++++
 dfdx-nn/src/batch_norm2d.rs         | 36 ++++++++++++++++++++++
 dfdx-nn/src/bias1d.rs               | 20 ++++++++++++
 dfdx-nn/src/bias2d.rs               | 20 ++++++++++++
 dfdx-nn/src/conv2d.rs               | 35 +++++++++++++++++++++
 dfdx-nn/src/conv_trans2d.rs         | 17 ++++++++++
 dfdx-nn/src/cos.rs                  |  2 +-
 dfdx-nn/src/dropout.rs              | 27 ++++++++++++++++
 dfdx-nn/src/embedding.rs            | 32 +++++++++++++++++++
 dfdx-nn/src/exp.rs                  |  2 +-
 dfdx-nn/src/flatten2d.rs            |  1 +
 dfdx-nn/src/gelu.rs                 |  4 +--
 dfdx-nn/src/generalized_add.rs      | 17 ++++++++++
 dfdx-nn/src/layer_norm1d.rs         | 20 ++++++++++++
 dfdx-nn/src/leaky_relu.rs           |  1 +
 dfdx-nn/src/linear.rs               | 19 ++++++++++++
 dfdx-nn/src/ln.rs                   |  2 +-
 dfdx-nn/src/log_softmax.rs          |  2 +-
 dfdx-nn/src/matmul.rs               | 17 +++++++++-
 dfdx-nn/src/multi_head_attention.rs | 14 +++++++++
 dfdx-nn/src/pool_2d_avg.rs          |  8 +++++
 dfdx-nn/src/pool_2d_max.rs          |  8 +++++
 dfdx-nn/src/pool_2d_min.rs          |  8 +++++
 dfdx-nn/src/pool_global_avg.rs      | 15 +++++++++
 dfdx-nn/src/pool_global_max.rs      | 15 +++++++++
 dfdx-nn/src/pool_global_min.rs      | 15 +++++++++
 dfdx-nn/src/prelu.rs                |  2 ++
 dfdx-nn/src/prelu1d.rs              |  2 ++
 dfdx-nn/src/relu.rs                 |  2 +-
 dfdx-nn/src/reshape.rs              | 10 ++++++
 dfdx-nn/src/residual_add.rs         | 16 ++++++++++
 dfdx-nn/src/sigmoid.rs              |  2 +-
 dfdx-nn/src/sin.rs                  |  2 +-
 dfdx-nn/src/softmax.rs              |  2 +-
 dfdx-nn/src/split_into.rs           | 20 ++++++++++++
 dfdx-nn/src/sqrt.rs                 |  1 +
 dfdx-nn/src/square.rs               |  2 +-
 dfdx-nn/src/tanh.rs                 |  2 +-
 dfdx-nn/src/transformer.rs          | 48 +++++++++++++++++++++++++++
 42 files changed, 541 insertions(+), 16 deletions(-)

diff --git a/dfdx-nn-core/src/lib.rs b/dfdx-nn-core/src/lib.rs
index 8913baf9..aea56722 100644
--- a/dfdx-nn-core/src/lib.rs
+++ b/dfdx-nn-core/src/lib.rs
@@ -6,7 +6,9 @@ use dfdx::{
     shapes::HasShape,
 };
 
+/// Mutable & Immutable forward of `Input` that produces [Module::Output].
 pub trait Module<Input> {
+    /// The type that this unit produces given `Input`.
     type Output;
     type Error: std::fmt::Debug;
 
@@ -25,6 +27,9 @@ pub trait Module<Input> {
     }
 }
 
+/// An error indicating that a parameter was not used in gradient
+/// computation, and was therefore not present in [Gradients]
+/// during an update.
 #[derive(Debug)]
 pub enum OptimizerUpdateError {
     UnusedTensors(Vec<UniqueId>),
@@ -43,6 +48,7 @@ impl std::fmt::Display for OptimizerUpdateError {
 #[cfg(feature = "std")]
 impl std::error::Error for OptimizerUpdateError {}
 
+/// Something that can update both tensors and an [UpdateParams]. At minimum [Optimizer::update_tensor()] must be implemented.
 pub trait Optimizer<M, E: Dtype, D: Device<E>>: Sized {
     fn update_tensor(
         &mut self,
@@ -71,6 +77,7 @@ pub trait Optimizer<M, E: Dtype, D: Device<E>>: Sized {
     }
 }
 
+/// Something that can be constructed on a device as a certain dtype.
 pub trait BuildOnDevice<E: Dtype, D: Device<E>>: Clone {
     type Built: Clone + std::fmt::Debug;
     fn build_on_device(&self, device: &D) -> Self::Built {
@@ -79,6 +86,7 @@ pub trait BuildOnDevice<E: Dtype, D: Device<E>>: Clone {
     fn try_build_on_device(&self, device: &D) -> Result<Self::Built, D::Err>;
 }
 
+/// Something that can have all of its parameters reset to a specific state (may be random or not random).
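+///
+/// A minimal usage sketch (illustrative only; assumes the [Linear] layer from `dfdx-nn`
+/// and the [BuildModuleExt::build_module_ext()] helper defined further below):
+/// ```rust
+/// # use dfdx::prelude::*;
+/// # use dfdx_nn::*;
+/// # let dev: Cpu = Default::default();
+/// let mut model = dev.build_module_ext::<f32>(LinearConstConfig::<2, 5>::default());
+/// // Re-initialize every parameter in place, e.g. between training runs.
+/// model.reset_params();
+/// ```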
 pub trait ResetParams<E: Dtype, D: Device<E>> {
     fn reset_params(&mut self) {
         self.try_reset_params().unwrap()
@@ -86,6 +94,7 @@ pub trait ResetParams<E: Dtype, D: Device<E>> {
     fn try_reset_params(&mut self) -> Result<(), D::Err>;
 }
 
+/// Something that can have its params updated with an [Optimizer] and a set of [Gradients].
 pub trait UpdateParams<E: Dtype, D: Device<E>> {
     fn update_params<M, Optim: Optimizer<M, E, D>>(
         &mut self,
@@ -104,6 +113,7 @@ pub trait UpdateParams<E: Dtype, D: Device<E>> {
     ) -> Result<(), D::Err>;
 }
 
+/// Something that can allocate a [Gradients] object or zero out the [Gradients] object.
 pub trait ZeroGrads<E: Dtype, D: Device<E>> {
     fn zero_grads(&self, grads: &mut Gradients<E, D>) {
         self.try_zero_grads(grads).unwrap()
@@ -121,6 +131,7 @@ pub trait ZeroGrads<E: Dtype, D: Device<E>> {
     }
 }
 
+/// Something that can be saved to a .safetensors file.
 pub trait SaveSafeTensors {
     fn save_safetensors<P: AsRef<std::path::Path>>(
         &self,
@@ -149,6 +160,7 @@ pub trait SaveSafeTensors {
     );
 }
 
+/// Something that can be loaded from a .safetensors file.
 pub trait LoadSafeTensors {
     fn load_safetensors<P: AsRef<std::path::Path>>(
         &mut self,
@@ -241,6 +253,7 @@ unit_safetensors!(i64);
 unit_safetensors!(isize);
 unit_safetensors!(usize);
 
+/// Extension method that calls [BuildOnDevice] and then [ResetParams].
 pub trait BuildModuleExt<M>: Sized {
     fn build_module_ext<E: Dtype>(&self, m: M) -> M::Built
     where
         M: BuildOnDevice<E, Self>,
         M::Built: ResetParams<E, Self>,
         Self: Device<E>,
     {
-        let mut module = m.build_on_device(self);
-        module.reset_params();
-        module
+        self.try_build_module_ext(m).unwrap()
+    }
+
+    fn try_build_module_ext<E: Dtype>(&self, m: M) -> Result<M::Built, Self::Err>
+    where
+        M: BuildOnDevice<E, Self>,
+        M::Built: ResetParams<E, Self>,
+        Self: Device<E>,
+    {
+        let mut module = m.try_build_on_device(self)?;
+        module.try_reset_params()?;
+        Ok(module)
+    }
 }
 impl<M, D> BuildModuleExt<M> for D {}
diff --git a/dfdx-nn/src/abs.rs b/dfdx-nn/src/abs.rs
index f5a6fac3..98629c74 100644
--- a/dfdx-nn/src/abs.rs
+++ b/dfdx-nn/src/abs.rs
@@ -1,5 +1,6 @@
 use dfdx::prelude::{Device, Dtype, Shape, Tape, Tensor};
 
+/// Calls [dfdx::tensor_ops::abs()].
 #[derive(Default, Debug, Clone, Copy, crate::CustomModule)]
 pub struct Abs;
 impl<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>> crate::Module<Tensor<S, E, D, T>> for Abs {
diff --git a/dfdx-nn/src/add_into.rs b/dfdx-nn/src/add_into.rs
index 3bd7f3ea..ca1c069a 100644
--- a/dfdx-nn/src/add_into.rs
+++ b/dfdx-nn/src/add_into.rs
@@ -5,6 +5,24 @@ use dfdx::{
     tensor_ops::{Device, TryAdd},
 };
 
+/// Add inputs together into a single tensor. `T` should be a tuple
+/// where every element of the tuple has the same output type.
+///
+/// This provides a utility for networks where multiple inputs are needed.
+///
+/// Generics:
+/// - `T`: the tuple of modules whose outputs are added together
+///
+/// # Examples
+/// ```rust
+/// # use dfdx::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// type Model = AddInto<(Linear<2, 5>, Linear<3, 5>)>;
+/// let model = dev.build_module::<Model, f32>();
+/// let a: Tensor<Rank1<2>, f32, _> = dev.zeros();
+/// let b: Tensor<Rank1<3>, f32, _> = dev.zeros();
+/// let _: Tensor<Rank1<5>, f32, _> = model.forward((a, b));
+/// ```
 #[derive(
     Debug, Default, Clone, ResetParams, ZeroGrads, UpdateParams, LoadSafeTensors, SaveSafeTensors,
 )]
diff --git a/dfdx-nn/src/batch_norm1d.rs b/dfdx-nn/src/batch_norm1d.rs
index b130aff6..d15313cd 100644
--- a/dfdx-nn/src/batch_norm1d.rs
+++ b/dfdx-nn/src/batch_norm1d.rs
@@ -1,10 +1,43 @@
 use crate::{LoadSafeTensors, SaveSafeTensors, UpdateParams, ZeroGrads};
 use dfdx::prelude::*;
 
+/// Batch normalization for sequences as described in
+/// [Batch Normalization: Accelerating Deep Network Training
+/// by Reducing Internal Covariate Shift](https://arxiv.org/abs/1502.03167)
+///
+/// Generics:
+///
+/// - `C` the size of the dimension to reduce. Both for 2d tensors (of the form <BATCH, DIM>)
+/// as well as 3d tensors (of the form <BATCH, DIM, SEQ>), this is the 1st dimension.
+///
+/// # Training vs Inference
+///
+/// BatchNorm1D supports the following cases (see sections below for more details):
+/// 1. **Training**: [crate::Module::forward_mut()] and [OwnedTape] on the input tensor
+/// 2. **Inference**: [crate::Module::forward()] and [NoneTape] on the input tensor.
+///
+/// Examples:
+/// ```rust
+/// # use dfdx::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// type Model = BatchNorm1D<3>;
+/// let bn = dev.build_module::<Model, f32>();
+/// let _ = bn.forward(dev.zeros::<Rank2<4, 3>>());
+/// let _ = bn.forward(dev.zeros::<Rank3<4, 3, 2>>());
+/// ```
+///
+/// ### Training
+/// - Running statistics: updated with momentum
+/// - Normalization: calculated using batch stats
+///
+/// ### Inference
+/// - Running statistics: **not** updated
+/// - Normalization: calculated using running stats
 #[derive(Default, Clone, Copy, Debug)]
 #[repr(transparent)]
 pub struct BatchNorm1DConfig<C: Dim>(pub C);
 
+/// Compile time sugar alias around [BatchNorm1DConfig].
 pub type BatchNorm1DConstConfig<const C: usize> = BatchNorm1DConfig<Const<C>>;
 
 impl<C: Dim, E: Dtype, D: Device<E>> crate::BuildOnDevice<E, D> for BatchNorm1DConfig<C> {
@@ -21,20 +54,29 @@ impl<C: Dim, E: Dtype, D: Device<E>> crate::BuildOnDevice<E, D> for BatchNorm1DC
     }
 }
 
+/// See [BatchNorm1DConfig].
 #[derive(Clone, Debug, UpdateParams, ZeroGrads, SaveSafeTensors, LoadSafeTensors)]
 pub struct BatchNorm1D<C: Dim, Elem: Dtype, Dev: Device<Elem>> {
+    /// Scale for affine transform. Defaults to 1.0.
     #[param]
     #[serialize]
     pub scale: Tensor<(C,), Elem, Dev>,
+    /// Bias for affine transform. Defaults to 0.0.
     #[param]
     #[serialize]
     pub bias: Tensor<(C,), Elem, Dev>,
+    /// Spatial mean that is updated during training. Defaults to 0.0.
     #[serialize]
     pub running_mean: Tensor<(C,), Elem, Dev>,
+    /// Spatial variance that is updated during training. Defaults to 1.0.
     #[serialize]
     pub running_var: Tensor<(C,), Elem, Dev>,
+    /// Added to variance before taking sqrt for numerical stability. Defaults to 1e-5.
     #[serialize]
     pub epsilon: f64,
+    /// Controls exponential moving average of running stats. Defaults to 0.1.
+    ///
+    /// `running_stat * (1.0 - momentum) + stat * momentum`.
     #[serialize]
     pub momentum: f64,
 }
diff --git a/dfdx-nn/src/batch_norm2d.rs b/dfdx-nn/src/batch_norm2d.rs
index 54934538..78bbacb0 100644
--- a/dfdx-nn/src/batch_norm2d.rs
+++ b/dfdx-nn/src/batch_norm2d.rs
@@ -1,10 +1,45 @@
 use crate::{LoadSafeTensors, SaveSafeTensors, UpdateParams, ZeroGrads};
 use dfdx::prelude::*;
 
+/// Batch normalization for images as described in
+/// [Batch Normalization: Accelerating Deep Network Training
+/// by Reducing Internal Covariate Shift](https://arxiv.org/abs/1502.03167)
+///
+/// Generics:
+///
+/// - `C` the size of the spatial dimension to reduce. For 3d tensors this is the 0th
+///   dimension. For 4d tensors, this is the 1st dimension.
+///
+/// # Training vs Inference
+///
+/// BatchNorm2D supports the following cases (see sections below for more details):
+/// 1. **Training**: [crate::Module::forward_mut()] and [OwnedTape] on the input tensor
+/// 2. **Inference**: [crate::Module::forward()] and [NoneTape] on the input tensor.
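+///
+/// A sketch of the training case (illustrative; assumes the same `bn` module as the
+/// example below, plus the [ZeroGrads::alloc_grads()] helper):
+/// ```rust
+/// # use dfdx::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// # type Model = BatchNorm2D<3>;
+/// # let mut bn = dev.build_module::<Model, f32>();
+/// let grads = bn.alloc_grads();
+/// // forward_mut with a taped input normalizes with batch statistics
+/// // and updates the running statistics.
+/// let _ = bn.forward_mut(dev.zeros::<Rank3<3, 2, 2>>().trace(grads));
+/// ```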
+///
+/// *NOTE: `forward_mut()` with [NoneTape], and `forward()` with [OwnedTape], will fail to compile.*
+///
+/// Examples:
+/// ```rust
+/// # use dfdx::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// type Model = BatchNorm2D<3>;
+/// let bn = dev.build_module::<Model, f32>();
+/// let _ = bn.forward(dev.zeros::<Rank3<3, 2, 2>>());
+/// let _ = bn.forward(dev.zeros::<Rank4<4, 3, 2, 2>>());
+/// ```
+///
+/// ### Training
+/// - Running statistics: updated with momentum
+/// - Normalization: calculated using batch stats
+///
+/// ### Inference
+/// - Running statistics: **not** updated
+/// - Normalization: calculated using running stats
 #[derive(Default, Clone, Copy, Debug)]
 #[repr(transparent)]
 pub struct BatchNorm2DConfig<C: Dim>(pub C);
 
+/// Compile time sugar alias around [BatchNorm2DConfig].
 pub type BatchNorm2DConstConfig<const C: usize> = BatchNorm2DConfig<Const<C>>;
 
 impl<C: Dim, E: Dtype, D: Device<E>> crate::BuildOnDevice<E, D> for BatchNorm2DConfig<C> {
@@ -21,6 +56,7 @@ impl<C: Dim, E: Dtype, D: Device<E>> crate::BuildOnDevice<E, D> for BatchNorm2DC
     }
 }
 
+/// See [BatchNorm2DConfig].
 #[derive(Clone, Debug, UpdateParams, ZeroGrads, SaveSafeTensors, LoadSafeTensors)]
 pub struct BatchNorm2D<C: Dim, Elem: Dtype, Dev: Device<Elem>> {
     #[param]
diff --git a/dfdx-nn/src/bias1d.rs b/dfdx-nn/src/bias1d.rs
index c3196343..a58b4960 100644
--- a/dfdx-nn/src/bias1d.rs
+++ b/dfdx-nn/src/bias1d.rs
@@ -6,10 +6,29 @@ use dfdx::{
 
 use crate::*;
 
+/// Adds a learnable 1d bias to 2d and 3d inputs.
+///
+/// Example:
+/// ```rust
+/// # use dfdx::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// const NUM_CHANS: usize = 5;
+/// type Model = Bias1D<NUM_CHANS>;
+/// let model = dev.build_module::<Model, f32>();
+///
+/// // 2d input
+/// let x: Tensor<Rank2<3, NUM_CHANS>, f32, _> = dev.sample_normal();
+/// model.forward(x);
+///
+/// // 3d input
+/// let x: Tensor<Rank3<10, 3, NUM_CHANS>, f32, _> = dev.sample_normal();
+/// model.forward(x);
+/// ```
 #[derive(Default, Clone, Copy, Debug)]
 #[repr(transparent)]
 pub struct Bias1DConfig<I: Dim>(pub I);
 
+/// Compile time sugar alias around [Bias1DConfig].
 pub type Bias1DConstConfig<const I: usize> = Bias1DConfig<Const<I>>;
 
 impl<I: Dim, E: Dtype, D: Device<E>> BuildOnDevice<E, D> for Bias1DConfig<I> {
@@ -21,6 +40,7 @@ impl<I: Dim, E: Dtype, D: Device<E>> BuildOnDevice<E, D> for Bias1DConfig<I> {
     }
 }
 
+/// See [Bias1DConfig].
 #[derive(Clone, Debug, UpdateParams, ZeroGrads, SaveSafeTensors, LoadSafeTensors)]
 pub struct Bias1D<I: Dim, Elem: Dtype, Dev: Device<Elem>> {
     #[param]
diff --git a/dfdx-nn/src/bias2d.rs b/dfdx-nn/src/bias2d.rs
index 591517e7..72425645 100644
--- a/dfdx-nn/src/bias2d.rs
+++ b/dfdx-nn/src/bias2d.rs
@@ -6,10 +6,29 @@ use dfdx::{
 
 use crate::*;
 
+/// Adds a learnable 1d bias to 3d and 4d inputs.
+///
+/// Example:
+/// ```rust
+/// # use dfdx::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// const NUM_CHANS: usize = 5;
+/// type Model = Bias2D<NUM_CHANS>;
+/// let model = dev.build_module::<Model, f32>();
+///
+/// // 3d input
+/// let x: Tensor<Rank3<NUM_CHANS, 2, 3>, f32, _> = dev.sample_normal();
+/// model.forward(x);
+///
+/// // 4d input
+/// let x: Tensor<Rank4<10, NUM_CHANS, 2, 3>, f32, _> = dev.sample_normal();
+/// model.forward(x);
+/// ```
 #[derive(Default, Clone, Copy, Debug)]
 #[repr(transparent)]
 pub struct Bias2DConfig<I: Dim>(pub I);
 
+/// Compile time sugar alias around [Bias2DConfig].
 pub type Bias2DConstConfig<const I: usize> = Bias2DConfig<Const<I>>;
 
 impl<I: Dim, E: Dtype, D: Device<E>> BuildOnDevice<E, D> for Bias2DConfig<I> {
@@ -21,6 +40,7 @@ impl<I: Dim, E: Dtype, D: Device<E>> BuildOnDevice<E, D> for Bias2DConfig<I> {
     }
 }
 
+/// See [Bias2DConfig].
 #[derive(Clone, Debug, UpdateParams, ZeroGrads, SaveSafeTensors, LoadSafeTensors)]
 pub struct Bias2D<I: Dim, Elem: Dtype, Dev: Device<Elem>> {
     #[param]
diff --git a/dfdx-nn/src/conv2d.rs b/dfdx-nn/src/conv2d.rs
index c980c7e0..11fe0025 100644
--- a/dfdx-nn/src/conv2d.rs
+++ b/dfdx-nn/src/conv2d.rs
@@ -6,6 +6,39 @@ use dfdx::{
 
 use crate::*;
 
+/// **Requires Nightly** Performs *unbiased* 2d convolutions on 3d and 4d images.
+///
+/// **Pytorch Equivalent**: `torch.nn.Conv2d(..., bias=False)`
+///
+/// Example usage:
+/// ```rust
+/// # use dfdx_nn::Conv2DConfig;
+/// # use dfdx::shapes::Const;
+/// // compile time channels/kernel
+/// let m: Conv2DConfig<Const<3>, Const<5>, Const<3>> = Default::default();
+/// // runtime channels/kernel
+/// let m: Conv2DConfig<usize, usize, usize> = Conv2DConfig {
+///     in_chan: 3,
+///     out_chan: 5,
+///     kernel_size: 3,
+///     ..Default::default()
+/// };
+/// ```
+///
+/// To create a biased conv, combine with [crate::Bias2D].
+///
+/// Generics:
+/// - `InChan`: The number of input channels in an image.
+/// - `OutChan`: The number of channels in the output of the layer.
+/// - `KernelSize`: The size of the kernel applied to both width and height of the images.
+/// - `Stride`: How far to move the kernel each step. Defaults to `Const<1>`.
+/// - `Padding`: How much zero padding to add around the images. Defaults to `Const<0>`.
+/// - `Dilation`: Controls the spacing between kernel points. Defaults to `Const<1>`.
+/// - `Groups`: Controls the connections between inputs and outputs.
+///   `InChan` and `OutChan` must both be divisible by `Groups`.
+///
+/// See [conv animations](https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md) for helpful
+/// visualization of all of these parameters.
 #[derive(Debug, Default, Clone, Copy)]
 pub struct Conv2DConfig<
     InChan: Dim,
@@ -25,6 +58,7 @@ pub struct Conv2DConfig<
     pub groups: Groups,
 }
 
+/// Compile time sugar alias around [Conv2DConfig].
 pub type Conv2DConstConfig<
     const IN_CHAN: usize,
     const OUT_CHAN: usize,
@@ -70,6 +104,7 @@ where
     }
 }
 
+/// The module built with [Conv2DConfig]. See [Conv2DConfig] for usage.
 #[derive(Debug, Clone, UpdateParams, ZeroGrads, SaveSafeTensors, LoadSafeTensors)]
 pub struct Conv2D<InChan, OutChan, KernelSize, Stride, Padding, Dilation, Groups, Elem, Dev>
 where
diff --git a/dfdx-nn/src/conv_trans2d.rs b/dfdx-nn/src/conv_trans2d.rs
index 7dc17cba..c671218d 100644
--- a/dfdx-nn/src/conv_trans2d.rs
+++ b/dfdx-nn/src/conv_trans2d.rs
@@ -6,6 +6,21 @@ use dfdx::{
 
 use crate::*;
 
+/// **Requires Nightly** Performs *unbiased* 2d deconvolutions on 3d and 4d images.
+///
+/// **Pytorch Equivalent**: `torch.nn.ConvTranspose2d(..., bias=False)`
+///
+/// To create a biased conv, combine with [crate::Bias2D].
+///
+/// Generics:
+/// - `InChan`: The number of input channels in an image.
+/// - `OutChan`: The number of channels in the output of the layer.
+/// - `KernelSize`: The size of the kernel applied to both width and height of the images.
+/// - `Stride`: How far to move the kernel each step. Defaults to `Const<1>`.
+/// - `Padding`: How much zero padding to add around the images. Defaults to `Const<0>`.
+/// - `Dilation`: Controls the spacing between kernel points. Defaults to `Const<1>`.
+/// - `Groups`: Controls the connections between inputs and outputs. Defaults to `Const<1>`.
+///   `InChan` and `OutChan` must both be divisible by `Groups`.
 #[derive(Debug, Default, Clone, Copy)]
 pub struct ConvTrans2DConfig<
     InChan: Dim,
@@ -25,6 +40,7 @@ pub struct ConvTrans2DConfig<
     pub groups: Groups,
 }
 
+/// Compile time sugar alias around [ConvTrans2DConfig].
 pub type ConvTrans2DConstConfig<
     const IN_CHAN: usize,
     const OUT_CHAN: usize,
@@ -66,6 +82,7 @@ where
     }
 }
 
+/// See [ConvTrans2DConfig].
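+///
+/// An illustrative construction sketch, mirroring the [Conv2DConfig] example
+/// (hypothetical sizes; the field names come from the config struct above):
+/// ```rust
+/// # use dfdx_nn::ConvTrans2DConfig;
+/// # use dfdx::shapes::Const;
+/// // compile time channels/kernel
+/// let m: ConvTrans2DConfig<Const<3>, Const<5>, Const<3>> = Default::default();
+/// // runtime channels/kernel
+/// let m: ConvTrans2DConfig<usize, usize, usize> = ConvTrans2DConfig {
+///     in_chan: 3,
+///     out_chan: 5,
+///     kernel_size: 3,
+///     ..Default::default()
+/// };
+/// ```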
 #[derive(Debug, Clone, UpdateParams, ZeroGrads, SaveSafeTensors, LoadSafeTensors)]
 pub struct ConvTrans2D<InChan, OutChan, KernelSize, Stride, Padding, Dilation, Groups, Elem, Dev>
 where
diff --git a/dfdx-nn/src/cos.rs b/dfdx-nn/src/cos.rs
index 27dfedea..6f16a566 100644
--- a/dfdx-nn/src/cos.rs
+++ b/dfdx-nn/src/cos.rs
@@ -1,8 +1,8 @@
 use dfdx::prelude::{Device, Dtype, Shape, Tape, Tensor};
 
+/// Calls [dfdx::tensor_ops::cos()].
 #[derive(Default, Debug, Clone, Copy, crate::CustomModule)]
 pub struct Cos;
-
 impl<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>> crate::Module<Tensor<S, E, D, T>> for Cos {
     type Output = Tensor<S, E, D, T>;
     type Error = D::Err;
diff --git a/dfdx-nn/src/dropout.rs b/dfdx-nn/src/dropout.rs
index 5ba8fb94..92929210 100644
--- a/dfdx-nn/src/dropout.rs
+++ b/dfdx-nn/src/dropout.rs
@@ -6,6 +6,21 @@ use dfdx::{
 
 use crate::*;
 
+/// Calls [dfdx::tensor_ops::dropout()] with `p = 1.0 / N` in [Module::forward_mut()], and does nothing in [Module::forward()].
+///
+/// Generics:
+/// - `N`: `p` is set as `1.0 / N`
+///
+/// Examples:
+/// ```rust
+/// # use dfdx::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// let mut dropout: DropoutOneIn<2> = Default::default();
+/// let grads = dropout.alloc_grads();
+/// let x: Tensor<Rank2<2, 5>, f32, _> = dev.ones();
+/// let r = dropout.forward_mut(x.trace(grads));
+/// assert_eq!(r.array(), [[2.0, 0.0, 2.0, 0.0, 2.0], [0.0, 2.0, 0.0, 2.0, 2.0]]);
+/// ```
 #[derive(Clone, Debug, Default, CustomModule)]
 pub struct DropoutOneIn<const N: usize>;
 
@@ -28,6 +43,18 @@ impl<const N: usize, S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>> Module<Te
     }
 }
 
+/// Calls [dfdx::tensor_ops::dropout()] with `p` in [Module::forward_mut()], and does nothing in [Module::forward()].
+///
+/// Examples:
+/// ```rust
+/// # use dfdx::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// let mut dropout = Dropout { p: 0.5 };
+/// let grads = dropout.alloc_grads();
+/// let x: Tensor<Rank2<2, 5>, f32, _> = dev.ones();
+/// let r = dropout.forward_mut(x.trace(grads));
+/// assert_eq!(r.array(), [[2.0, 0.0, 2.0, 0.0, 2.0], [0.0, 2.0, 0.0, 2.0, 2.0]]);
+/// ```
 #[derive(Clone, Debug, CustomModule)]
 pub struct Dropout {
     pub p: f64,
diff --git a/dfdx-nn/src/embedding.rs b/dfdx-nn/src/embedding.rs
index 9e0f6ed4..4971084a 100644
--- a/dfdx-nn/src/embedding.rs
+++ b/dfdx-nn/src/embedding.rs
@@ -7,12 +7,43 @@ use dfdx::{
 
 use crate::*;
 
+/// An embedding layer.
+///
+/// **Pytorch Equivalent**: `torch.nn.Embedding(...)`
+///
+/// Initializes [Self::weight] from the Standard Normal distribution.
+///
+/// Generics:
+/// - `Vocab`: The size of the vocabulary; input integer values must be between
+///   0 and `Vocab`.
+/// - `Model`: The "output" size of the vectors being selected, i.e. the embedding dimension.
+///
+/// # Examples
+/// `Embedding<5, 2>` can act on a sequence of SEQ integer ids (each with a value between 0 and 4),
+/// and produces a SEQ x 2 tensor of (usually f32) elements whose rows are rows of [Self::weight].
+///
+/// ```rust
+/// # use dfdx::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// type Model = Embedding<7, 2>;
+/// let mut model = dev.build_module::<Model, f32>();
+/// // single sequence of ids
+/// let inputs: Tensor<Rank1<5>, usize, _> = dev.zeros();
+/// let _: Tensor<(Const<5>, Const<2>,), f32, _> = model.forward(inputs);
+/// // Dynamic sequence of ids
+/// let inputs: Tensor<(usize, ), usize, _> = dev.zeros_like(&(5, ));
+/// let _: Tensor<(usize, Const<2>,), f32, _> = model.forward(inputs);
+/// // batched sequence of ids
+/// let inputs: Tensor<Rank2<10, 5>, usize, _> = dev.zeros();
+/// let _: Tensor<(Const<10>, Const<5>, Const<2>), f32, _> = model.forward(inputs);
+/// ```
 #[derive(Default, Clone, Copy, Debug)]
 pub struct EmbeddingConfig<Vocab: Dim, Model: Dim> {
     pub vocab: Vocab,
     pub model: Model,
 }
 
+/// Compile time sugar alias around [EmbeddingConfig].
 pub type EmbeddingConstConfig<const VOCAB: usize, const MODEL: usize> =
     EmbeddingConfig<Const<VOCAB>, Const<MODEL>>;
 
 impl<Vocab: Dim, Model: Dim, E: Dtype, D: Device<E>> BuildOnDevice<E, D> for EmbeddingConfig<Vocab, Model> {
@@ -25,6 +56,7 @@ impl<Vocab: Dim, Model: Dim, E: Dtype, D: Device<E>> BuildOnDevice<E, D> for EmbeddingCo
     }
 }
 
+/// See [EmbeddingConfig].
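+///
+/// A small illustrative sketch of the built layer (assumes the same `Model` alias as the
+/// [EmbeddingConfig] example above):
+/// ```rust
+/// # use dfdx::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// # type Model = Embedding<7, 2>;
+/// let model = dev.build_module::<Model, f32>();
+/// // weight is (Vocab, Model) = (7, 2), sampled from a standard normal
+/// assert_eq!(model.weight.shape(), &(Const::<7>, Const::<2>));
+/// ```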
 #[derive(Clone, Debug, UpdateParams, ZeroGrads, SaveSafeTensors, LoadSafeTensors)]
 pub struct Embedding<Vocab: Dim, Model: Dim, Elem: Dtype, Dev: Device<Elem>> {
     #[param]
diff --git a/dfdx-nn/src/exp.rs b/dfdx-nn/src/exp.rs
index 669fd2e5..9d1c87a1 100644
--- a/dfdx-nn/src/exp.rs
+++ b/dfdx-nn/src/exp.rs
@@ -1,8 +1,8 @@
 use dfdx::prelude::{Device, Dtype, Shape, Tape, Tensor};
 
+/// Calls [dfdx::tensor_ops::exp()].
 #[derive(Default, Debug, Clone, Copy, crate::CustomModule)]
 pub struct Exp;
-
 impl<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>> crate::Module<Tensor<S, E, D, T>> for Exp {
     type Output = Tensor<S, E, D, T>;
     type Error = D::Err;
diff --git a/dfdx-nn/src/flatten2d.rs b/dfdx-nn/src/flatten2d.rs
index 0f207529..4b0fd36a 100644
--- a/dfdx-nn/src/flatten2d.rs
+++ b/dfdx-nn/src/flatten2d.rs
@@ -8,6 +8,7 @@ use dfdx::{
     tensor_ops::{Device, ReshapeTo},
 };
 
+/// **Requires Nightly** Flattens 3d tensors to 1d, and 4d tensors to 2d.
 #[derive(Debug, Default, Clone, Copy, CustomModule)]
 pub struct Flatten2D;
 
diff --git a/dfdx-nn/src/gelu.rs b/dfdx-nn/src/gelu.rs
index a2eba08f..7ca97a6f 100644
--- a/dfdx-nn/src/gelu.rs
+++ b/dfdx-nn/src/gelu.rs
@@ -1,8 +1,8 @@
 use dfdx::prelude::{Device, Dtype, Shape, Tape, Tensor};
 
+/// Calls [dfdx::tensor_ops::fast_gelu()].
 #[derive(Default, Debug, Clone, Copy, crate::CustomModule)]
 pub struct FastGeLU;
-
 impl<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>> crate::Module<Tensor<S, E, D, T>>
     for FastGeLU
 {
@@ -13,9 +13,9 @@ impl<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>> crate::Module<Tensor<S, E
     }
 }
 
+/// Calls [dfdx::tensor_ops::accurate_gelu()].
 #[derive(Default, Debug, Clone, Copy, crate::CustomModule)]
 pub struct AccurateGeLU;
-
 impl<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>> crate::Module<Tensor<S, E, D, T>>
     for AccurateGeLU
 {
diff --git a/dfdx-nn/src/generalized_add.rs b/dfdx-nn/src/generalized_add.rs
index 3f0e5cad..33705436 100644
--- a/dfdx-nn/src/generalized_add.rs
+++ b/dfdx-nn/src/generalized_add.rs
@@ -5,6 +5,23 @@ use dfdx::{
     tensor_ops::{Device, TryAdd},
 };
 
+/// A residual connection around two modules: `T(x) + U(x)`,
+/// as introduced in [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385).
+///
+/// # Generics
+/// - `T`: The underlying module to do a skip connection around.
+/// - `U`: The underlying residual module.
+///
+/// # Examples
+/// ```rust
+/// # use dfdx::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// type Model = GeneralizedAdd<ReLU, Square>;
+/// let model = dev.build_module::<Model, f32>();
+/// let x = dev.tensor([-2.0, -1.0, 0.0, 1.0, 2.0]);
+/// let y = model.forward(x);
+/// assert_eq!(y.array(), [4.0, 1.0, 0.0, 2.0, 6.0]);
+/// ```
 #[derive(
     Default, Clone, Debug, ResetParams, ZeroGrads, UpdateParams, LoadSafeTensors, SaveSafeTensors,
 )]
diff --git a/dfdx-nn/src/layer_norm1d.rs b/dfdx-nn/src/layer_norm1d.rs
index 9a8f497c..b8f5e221 100644
--- a/dfdx-nn/src/layer_norm1d.rs
+++ b/dfdx-nn/src/layer_norm1d.rs
@@ -1,10 +1,29 @@
 use crate::*;
 use dfdx::prelude::*;
 
+/// Implements layer normalization as described in [Layer Normalization](https://arxiv.org/abs/1607.06450).
+///
+/// This calls [normalize()] on the last axis of the input to normalize to 0 mean and unit std dev, and then does an element-wise
+/// affine transform using learnable parameters [Self::gamma] and [Self::beta].
+///
+/// [Self::epsilon] is passed to [normalize()] and added to the variance for numerical stability. It defaults to `1e-5`.
+///
+/// Generics:
+/// - `M` The size of the affine transform tensors.
+///
+/// # Examples
+/// ```rust
+/// # use dfdx::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// type Model = LayerNorm1D<5>;
+/// let model = dev.build_module::<Model, f32>();
+/// let _: Tensor<Rank1<5>, f32, _> = model.forward(dev.zeros::<Rank1<5>>());
+/// ```
 #[derive(Default, Clone, Copy, Debug)]
 #[repr(transparent)]
 pub struct LayerNorm1DConfig<M: Dim>(pub M);
 
+/// Compile time sugar alias around [LayerNorm1DConfig].
 pub type LayerNorm1DConstConfig<const M: usize> = LayerNorm1DConfig<Const<M>>;
 
 impl<M: Dim, E: Dtype, D: Device<E>> crate::BuildOnDevice<E, D> for LayerNorm1DConfig<M> {
@@ -18,6 +37,7 @@ impl<M: Dim, E: Dtype, D: Device<E>> crate::BuildOnDevice<E, D> for LayerNorm1DC
     }
 }
 
+/// See [LayerNorm1DConfig].
 #[derive(Clone, Debug, UpdateParams, ZeroGrads, SaveSafeTensors, LoadSafeTensors)]
 pub struct LayerNorm1D<M: Dim, Elem: Dtype, Dev: Device<Elem>> {
     #[param]
diff --git a/dfdx-nn/src/leaky_relu.rs b/dfdx-nn/src/leaky_relu.rs
index cb8a86f7..cbc8c866 100644
--- a/dfdx-nn/src/leaky_relu.rs
+++ b/dfdx-nn/src/leaky_relu.rs
@@ -3,6 +3,7 @@ use dfdx::{
     tensor_ops::TryPReLU,
 };
 
+/// ReLU but maintains a small gradient if the input values are negative.
 #[derive(Debug, Clone, Copy, crate::CustomModule)]
 pub struct LeakyReLU(pub f64);
 
diff --git a/dfdx-nn/src/linear.rs b/dfdx-nn/src/linear.rs
index 66ec7bc8..a601273a 100644
--- a/dfdx-nn/src/linear.rs
+++ b/dfdx-nn/src/linear.rs
@@ -2,6 +2,24 @@ use crate::{Bias1DConfig, MatMulConfig, Sequential};
 
 use dfdx::shapes::{Const, Dim};
 
+/// A linear transformation of the form `weight * x + bias`, where `weight` is a matrix, `x` is a vector or matrix,
+/// and `bias` is a vector.
+///
+/// Generics:
+/// - `I` The "input" size of vectors & matrices.
+/// - `O` The "output" size of vectors & matrices.
+///
+/// Example:
+/// ```rust
+/// # use dfdx::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// let arch: LinearConstConfig<5, 2> = Default::default();
+/// let model = dev.build_module_ext::<f32>(arch);
+/// // single item forward
+/// let _: Tensor<Rank1<2>, f32, _> = model.forward(dev.zeros::<Rank1<5>>());
+/// // batched forward
+/// let _: Tensor<Rank2<10, 2>, f32, _> = model.forward(dev.zeros::<Rank2<10, 5>>());
+/// ```
 #[derive(Default, Debug, Clone, Copy, Sequential)]
 #[built(Linear)]
 pub struct LinearConfig<I: Dim, O: Dim> {
@@ -9,6 +27,7 @@ pub struct LinearConfig<I: Dim, O: Dim> {
     pub bias: Bias1DConfig<O>,
 }
 
+/// Compile time sugar alias around [LinearConfig].
 pub type LinearConstConfig<const I: usize, const O: usize> = LinearConfig<Const<I>, Const<O>>;
 
 impl<I: Dim, O: Dim> LinearConfig<I, O> {
diff --git a/dfdx-nn/src/ln.rs b/dfdx-nn/src/ln.rs
index 4864857b..bb22ecf2 100644
--- a/dfdx-nn/src/ln.rs
+++ b/dfdx-nn/src/ln.rs
@@ -1,8 +1,8 @@
 use dfdx::prelude::{Device, Dtype, Shape, Tape, Tensor};
 
+/// Calls [dfdx::tensor_ops::ln()].
 #[derive(Default, Debug, Clone, Copy, crate::CustomModule)]
 pub struct Ln;
-
 impl<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>> crate::Module<Tensor<S, E, D, T>> for Ln {
     type Output = Tensor<S, E, D, T>;
     type Error = D::Err;
diff --git a/dfdx-nn/src/log_softmax.rs b/dfdx-nn/src/log_softmax.rs
index 5ac83529..be6e20a6 100644
--- a/dfdx-nn/src/log_softmax.rs
+++ b/dfdx-nn/src/log_softmax.rs
@@ -1,8 +1,8 @@
 use dfdx::prelude::{Device, Dtype, Shape, Tape, Tensor};
 
+/// Calls [dfdx::tensor_ops::log_softmax()].
 #[derive(Default, Debug, Clone, Copy, crate::CustomModule)]
 pub struct LogSoftmax;
-
 impl<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>> crate::Module<Tensor<S, E, D, T>>
     for LogSoftmax
 {
diff --git a/dfdx-nn/src/matmul.rs b/dfdx-nn/src/matmul.rs
index 94d80833..9b003ee3 100644
--- a/dfdx-nn/src/matmul.rs
+++ b/dfdx-nn/src/matmul.rs
@@ -6,12 +6,27 @@ use dfdx::{
 };
 use rand_distr::Uniform;
 
+/// Performs matrix multiplication of the form `x * W^T`, where `x` is the input and `W` is the weight matrix.
+/// `x` can be 1d, 2d, or 3d.
+///
+/// Examples:
+/// ```rust
+/// # use dfdx::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// type Model = MatMulConstConfig<5, 2>;
+/// let model = dev.build_module::<Model, f32>();
+/// // single item forward
+/// let _: Tensor<Rank1<2>, f32, _> = model.forward(dev.zeros::<Rank1<5>>());
+/// // batched forward
+/// let _: Tensor<Rank2<10, 2>, f32, _> = model.forward(dev.zeros::<Rank2<10, 5>>());
+/// ```
 #[derive(Clone, Copy, Debug, Default)]
 pub struct MatMulConfig<I: Dim, O: Dim> {
     pub inp: I,
     pub out: O,
 }
 
+/// Compile time sugar alias around [MatMulConfig].
 pub type MatMulConstConfig<const I: usize, const O: usize> = MatMulConfig<Const<I>, Const<O>>;
 
 impl<I: Dim, O: Dim, E: Dtype, D: Device<E>> BuildOnDevice<E, D> for MatMulConfig<I, O> {
@@ -23,6 +38,7 @@ impl<I: Dim, O: Dim, E: Dtype, D: Device<E>> BuildOnDevice<E, D> for MatMulConfi
     }
 }
 
+/// See [MatMulConfig].
 #[derive(Clone, Debug, UpdateParams, ZeroGrads, SaveSafeTensors, LoadSafeTensors)]
 pub struct MatMul<I: Dim, O: Dim, Elem: Dtype, Dev: Device<Elem>> {
     #[param]
     #[serialize]
     pub weight: Tensor<(O, I), Elem, Dev>,
 }
 
-// NOTE: others can simply #[derive(ResetParams)]
 impl<I: Dim, O: Dim, E, D: Device<E>> ResetParams<E, D> for MatMul<I, O, E, D>
 where
     E: Dtype + num_traits::Float + rand_distr::uniform::SampleUniform,
diff --git a/dfdx-nn/src/multi_head_attention.rs b/dfdx-nn/src/multi_head_attention.rs
index e57c0fef..8c31159a 100644
--- a/dfdx-nn/src/multi_head_attention.rs
+++ b/dfdx-nn/src/multi_head_attention.rs
@@ -2,6 +2,20 @@ use crate::*;
 use dfdx::{shapes::*, tensor::*, tensor_ops::*};
 use num_traits::Float;
 
+/// A multi-head attention layer.
+///
+/// Generics:
+/// - `Embed`: The size of query vectors.
+/// - `NumHeads`: The number of heads to split query/key/value into.
+/// - *Optional* `K`: The size of key vectors. Defaults to `Embed`.
+/// - *Optional* `V`: The size of value vectors. Defaults to `Embed`.
+///
+/// **Pytorch equivalent**: `torch.nn.MultiheadAttention(Embed, NumHeads, batch_first=True)`
+///
+/// Examples
+/// - `MultiHeadAttention<8, 2>` is an attention layer with 2 heads, where the token, key, and value vectors all have 8 dimensions.
+/// - `MultiHeadAttention<8, 2, 6, 4>` is an attention layer whose key and value dimensions (6 and 4)
+///   differ from the embed dimension.
 #[derive(Default, Debug, Copy, Clone, CustomModule)]
 #[built(MultiHeadAttention)]
 pub struct MultiHeadAttentionConfig<Embed: Dim, NumHeads: Dim, K: Dim = Embed, V: Dim = Embed> {
diff --git a/dfdx-nn/src/pool_2d_avg.rs b/dfdx-nn/src/pool_2d_avg.rs
index 29711981..b947dc5b 100644
--- a/dfdx-nn/src/pool_2d_avg.rs
+++ b/dfdx-nn/src/pool_2d_avg.rs
@@ -4,6 +4,14 @@ use dfdx::{
     tensor_ops::TryPool2D,
 };
 
+/// Average pool with 2d kernel that operates on images (3d) and batches of images (4d).
+/// Each patch reduces to the average of the values in the patch.
+///
+/// Generics:
+/// - `KernelSize`: The size of the kernel applied to both width and height of the images.
+/// - `Stride`: How far to move the kernel each step. Defaults to `1`.
+/// - `Padding`: How much zero padding to add around the images. Defaults to `0`.
+/// - `Dilation`: How dilated the kernel should be. Defaults to `1`.
 #[derive(Debug, Default, Clone, CustomModule)]
 pub struct AvgPool2D<
     KernelSize: Dim,
diff --git a/dfdx-nn/src/pool_2d_max.rs b/dfdx-nn/src/pool_2d_max.rs
index 46ed5dc8..32c83d1e 100644
--- a/dfdx-nn/src/pool_2d_max.rs
+++ b/dfdx-nn/src/pool_2d_max.rs
@@ -4,6 +4,14 @@ use dfdx::{
     tensor_ops::TryPool2D,
 };
 
+/// Max pool with 2d kernel that operates on images (3d) and batches of images (4d).
+/// Each patch reduces to the maximum value in that patch.
+///
+/// Generics:
+/// - `KernelSize`: The size of the kernel applied to both width and height of the images.
+/// - `Stride`: How far to move the kernel each step. Defaults to `1`.
+/// - `Padding`: How much zero padding to add around the images. Defaults to `0`.
+/// - `Dilation`: How dilated the kernel should be. Defaults to `1`.
 #[derive(Debug, Default, Clone, CustomModule)]
 pub struct MaxPool2D<
     KernelSize: Dim,
diff --git a/dfdx-nn/src/pool_2d_min.rs b/dfdx-nn/src/pool_2d_min.rs
index 41600918..4fd34776 100644
--- a/dfdx-nn/src/pool_2d_min.rs
+++ b/dfdx-nn/src/pool_2d_min.rs
@@ -4,6 +4,14 @@ use dfdx::{
     tensor_ops::TryPool2D,
 };
 
+/// Minimum pool with 2d kernel that operates on images (3d) and batches of images (4d).
+/// Each patch reduces to the minimum of the values in the patch.
+///
+/// Generics:
+/// - `KernelSize`: The size of the kernel applied to both width and height of the images.
+/// - `Stride`: How far to move the kernel each step. Defaults to `1`.
+/// - `Padding`: How much zero padding to add around the images. Defaults to `0`.
+/// - `Dilation`: How dilated the kernel should be. Defaults to `1`.
 #[derive(Debug, Default, Clone, CustomModule)]
 pub struct MinPool2D<
     KernelSize: Dim,
diff --git a/dfdx-nn/src/pool_global_avg.rs b/dfdx-nn/src/pool_global_avg.rs
index ffaddfe8..6134c61d 100644
--- a/dfdx-nn/src/pool_global_avg.rs
+++ b/dfdx-nn/src/pool_global_avg.rs
@@ -1,5 +1,20 @@
 use dfdx::prelude::{Device, Dim, Dtype, MeanTo, Tape, Tensor};
 
+/// Applies average pooling over an entire image, fully reducing the height and width
+/// dimensions:
+/// - Reduces 3d (C, H, W) to 1d (C, )
+/// - Reduces 4d (B, C, H, W) to 2d (B, C)
+///
+/// **Pytorch equivalent**: `torch.nn.AdaptiveAvgPool2d(1)` followed by a flatten.
+///
+/// Examples:
+/// ```rust
+/// # use dfdx::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// let m: AvgPoolGlobal = Default::default();
+/// let _: Tensor<Rank1<5>, f32, _> = m.forward(dev.zeros::<Rank3<5, 16, 8>>());
+/// let _: Tensor<Rank2<10, 5>, f32, _> = m.forward(dev.zeros::<Rank4<10, 5, 16, 8>>());
+/// ```
 #[derive(Default, Debug, Clone, Copy, crate::CustomModule)]
 pub struct AvgPoolGlobal;
 
diff --git a/dfdx-nn/src/pool_global_max.rs b/dfdx-nn/src/pool_global_max.rs
index 96d28b5a..5db7425d 100644
--- a/dfdx-nn/src/pool_global_max.rs
+++ b/dfdx-nn/src/pool_global_max.rs
@@ -1,5 +1,20 @@
 use dfdx::prelude::{Device, Dim, Dtype, MaxTo, Tape, Tensor};
 
+/// Applies max pooling over an entire image, fully reducing the height and width
+/// dimensions:
+/// - Reduces 3d (C, H, W) to 1d (C, )
+/// - Reduces 4d (B, C, H, W) to 2d (B, C)
+///
+/// **Pytorch equivalent**: `torch.nn.AdaptiveMaxPool2d(1)` followed by a flatten.
+///
+/// Examples:
+/// ```rust
+/// # use dfdx::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// let m: MaxPoolGlobal = Default::default();
+/// let _: Tensor<Rank1<5>, f32, _> = m.forward(dev.zeros::<Rank3<5, 16, 8>>());
+/// let _: Tensor<Rank2<10, 5>, f32, _> = m.forward(dev.zeros::<Rank4<10, 5, 16, 8>>());
+/// ```
 #[derive(Default, Debug, Clone, Copy, crate::CustomModule)]
 pub struct MaxPoolGlobal;
 
diff --git a/dfdx-nn/src/pool_global_min.rs b/dfdx-nn/src/pool_global_min.rs
index 90eaffa2..1046d579 100644
--- a/dfdx-nn/src/pool_global_min.rs
+++ b/dfdx-nn/src/pool_global_min.rs
@@ -1,5 +1,20 @@
 use dfdx::prelude::{Device, Dim, Dtype, MinTo, Tape, Tensor};
 
+/// Applies min pooling over an entire image, fully reducing the height and width
+/// dimensions:
+/// - Reduces 3d (C, H, W) to 1d (C, )
+/// - Reduces 4d (B, C, H, W) to 2d (B, C)
+///
+/// **Pytorch equivalent**: PyTorch has no built-in adaptive *min* pool; the closest is negating
+/// the input, applying `torch.nn.AdaptiveMaxPool2d(1)`, negating again, and then flattening.
+///
+/// Examples:
+/// ```rust
+/// # use dfdx::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// let m: MinPoolGlobal = Default::default();
+/// let _: Tensor<Rank1<5>, f32, _> = m.forward(dev.zeros::<Rank3<5, 16, 8>>());
+/// let _: Tensor<Rank2<10, 5>, f32, _> = m.forward(dev.zeros::<Rank4<10, 5, 16, 8>>());
+/// ```
 #[derive(Default, Debug, Clone, Copy, crate::CustomModule)]
 pub struct MinPoolGlobal;
 
diff --git a/dfdx-nn/src/prelu.rs b/dfdx-nn/src/prelu.rs
index bb468d27..85fa7324 100644
--- a/dfdx-nn/src/prelu.rs
+++ b/dfdx-nn/src/prelu.rs
@@ -6,6 +6,7 @@ use dfdx::{
 
 use crate::*;
 
+/// Calls [dfdx::tensor_ops::prelu()] with a learnable value.
 #[derive(Debug, Clone, Copy)]
 pub struct PReLUConfig(pub f64);
 
@@ -23,6 +24,7 @@ impl<E: Dtype, D: Device<E>> BuildOnDevice<E, D> for PReLUConfig {
     }
 }
 
+/// See [PReLUConfig].
 #[derive(Clone, Debug, UpdateParams, ZeroGrads, SaveSafeTensors, LoadSafeTensors)]
 pub struct PReLU<Elem: Dtype, Dev: Device<Elem>> {
     #[param]
diff --git a/dfdx-nn/src/prelu1d.rs b/dfdx-nn/src/prelu1d.rs
index f7b528de..10dbf106 100644
--- a/dfdx-nn/src/prelu1d.rs
+++ b/dfdx-nn/src/prelu1d.rs
@@ -6,6 +6,7 @@ use dfdx::{
 
 use crate::*;
 
+/// Calls [dfdx::tensor_ops::prelu()] with learnable values along the second dimension.
 #[derive(Debug, Clone, Copy)]
 pub struct PReLU1DConfig<C: Dim> {
     pub a: f64,
@@ -31,6 +32,7 @@ impl<C: Dim, E: Dtype, D: Device<E>> BuildOnDevice<E, D> for PReLU1DConfig<C> {
     }
 }
 
+/// See [PReLU1DConfig].
 #[derive(Clone, Debug, UpdateParams, ZeroGrads, SaveSafeTensors, LoadSafeTensors)]
 pub struct PReLU1D<C: Dim, Elem: Dtype, Dev: Device<Elem>> {
     #[param]
diff --git a/dfdx-nn/src/relu.rs b/dfdx-nn/src/relu.rs
index e486e020..917341a9 100644
--- a/dfdx-nn/src/relu.rs
+++ b/dfdx-nn/src/relu.rs
@@ -1,8 +1,8 @@
 use dfdx::prelude::{Device, Dtype, Shape, Tape, Tensor};
 
+/// Calls [dfdx::tensor_ops::relu()].
 #[derive(Default, Debug, Clone, Copy, crate::CustomModule)]
 pub struct ReLU;
-
 impl<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>> crate::Module<Tensor<S, E, D, T>> for ReLU {
     type Output = Tensor<S, E, D, T>;
     type Error = D::Err;
diff --git a/dfdx-nn/src/reshape.rs b/dfdx-nn/src/reshape.rs
index 87db1930..46333311 100644
--- a/dfdx-nn/src/reshape.rs
+++ b/dfdx-nn/src/reshape.rs
@@ -5,6 +5,16 @@ use dfdx::{
     tensor_ops::{Device, ReshapeTo},
 };
 
+/// Reshapes input tensors to a configured shape.
+///
+/// Example usage:
+/// ```rust
+/// # use dfdx::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// let model: Reshape<Rank2<5, 24>> = Default::default();
+/// let x: Tensor<Rank4<5, 4, 3, 2>, f32, _> = dev.sample_normal();
+/// let _: Tensor<Rank2<5, 24>, f32, _> = model.forward(x);
+/// ```
 #[derive(Default, Debug, Clone, Copy, CustomModule)]
 pub struct Reshape<S: Shape>(pub S);
 
diff --git a/dfdx-nn/src/residual_add.rs b/dfdx-nn/src/residual_add.rs
index 72570b21..a593a21a 100644
--- a/dfdx-nn/src/residual_add.rs
+++ b/dfdx-nn/src/residual_add.rs
@@ -7,6 +7,22 @@ use dfdx::{
 
 use crate::Module;
 
+/// A residual connection around `T`: `T(x) + x`,
+/// as introduced in [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385).
+///
+/// # Generics
+/// - `T`: The underlying module to do a skip connection around.
+///
+/// # Examples
+/// ```rust
+/// # use dfdx::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// type Model = ResidualAdd<ReLU>;
+/// let model = dev.build_module::<Model, f32>();
+/// let x = dev.tensor([-2.0, -1.0, 0.0, 1.0, 2.0]);
+/// let y = model.forward(x);
+/// assert_eq!(y.array(), [-2.0, -1.0, 0.0, 2.0, 4.0]);
+/// ```
 #[derive(
     Default, Clone, Debug, ResetParams, ZeroGrads, UpdateParams, SaveSafeTensors, LoadSafeTensors,
 )]
diff --git a/dfdx-nn/src/sigmoid.rs b/dfdx-nn/src/sigmoid.rs
index a8b1605c..35ee38a6 100644
--- a/dfdx-nn/src/sigmoid.rs
+++ b/dfdx-nn/src/sigmoid.rs
@@ -1,8 +1,8 @@
 use dfdx::prelude::{Device, Dtype, Shape, Tape, Tensor};
 
+/// Calls [dfdx::tensor_ops::sigmoid()].
 #[derive(Default, Debug, Clone, Copy, crate::CustomModule)]
 pub struct Sigmoid;
-
 impl<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>> crate::Module<Tensor<S, E, D, T>>
     for Sigmoid
 {
diff --git a/dfdx-nn/src/sin.rs b/dfdx-nn/src/sin.rs
index 197b3c14..988eb186 100644
--- a/dfdx-nn/src/sin.rs
+++ b/dfdx-nn/src/sin.rs
@@ -1,8 +1,8 @@
 use dfdx::prelude::{Device, Dtype, Shape, Tape, Tensor};
 
+/// Calls [dfdx::tensor_ops::sin()].
 #[derive(Default, Debug, Clone, Copy, crate::CustomModule)]
 pub struct Sin;
-
 impl<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>> crate::Module<Tensor<S, E, D, T>> for Sin {
     type Output = Tensor<S, E, D, T>;
     type Error = D::Err;
diff --git a/dfdx-nn/src/softmax.rs b/dfdx-nn/src/softmax.rs
index db0f9e18..3e6c3830 100644
--- a/dfdx-nn/src/softmax.rs
+++ b/dfdx-nn/src/softmax.rs
@@ -1,8 +1,8 @@
 use dfdx::prelude::{Device, Dtype, Shape, Tape, Tensor};
 
+/// Calls [dfdx::tensor_ops::softmax()] on the last axis of the input.
 #[derive(Default, Debug, Clone, Copy, crate::CustomModule)]
 pub struct Softmax;
-
 impl<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>> crate::Module<Tensor<S, E, D, T>>
     for Softmax
 {
diff --git a/dfdx-nn/src/split_into.rs b/dfdx-nn/src/split_into.rs
index 41407111..2486c27b 100644
--- a/dfdx-nn/src/split_into.rs
+++ b/dfdx-nn/src/split_into.rs
@@ -2,6 +2,26 @@ use crate::*;
 
 use dfdx::{shapes::Dtype, tensor::WithEmptyTape, tensor_ops::Device};
 
+/// Splits input into multiple heads. `T` should be a tuple,
+/// where every element of the tuple accepts the same input type.
+///
+/// This provides a utility for multi headed structures where
+/// the tape needs to be moved around a number of times.
+///
+/// Each head's operations will be stored in its output's tape, while the operations stored in the
+/// input tape will be saved in the first output's tape.
+///
+/// # Generics
+/// - `T`: the tuple of modules to split the input into.
+///
+/// # Examples
+/// ```rust
+/// # use dfdx::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// type Model = SplitInto<(Linear<5, 3>, Linear<5, 7>)>;
+/// let model = dev.build_module::<Model, f32>();
+/// let _: (Tensor<Rank1<3>, f32, _>, Tensor<Rank1<7>, f32, _>) = model.forward(dev.zeros::<Rank1<5>>());
+/// ```
 #[derive(
     Debug, Default, Clone, ResetParams, ZeroGrads, UpdateParams, LoadSafeTensors, SaveSafeTensors,
 )]
diff --git a/dfdx-nn/src/sqrt.rs b/dfdx-nn/src/sqrt.rs
index 1ee11879..6890a0af 100644
--- a/dfdx-nn/src/sqrt.rs
+++ b/dfdx-nn/src/sqrt.rs
@@ -1,5 +1,6 @@
 use dfdx::prelude::{Device, Dtype, Shape, Tape, Tensor};
 
+/// Calls [dfdx::tensor_ops::sqrt()].
 #[derive(Default, Debug, Clone, Copy, crate::CustomModule)]
 pub struct Sqrt;
 
diff --git a/dfdx-nn/src/square.rs b/dfdx-nn/src/square.rs
index 44af5534..38485319 100644
--- a/dfdx-nn/src/square.rs
+++ b/dfdx-nn/src/square.rs
@@ -1,8 +1,8 @@
 use dfdx::prelude::{Device, Dtype, Shape, Tape, Tensor};
 
+/// Calls [dfdx::tensor_ops::square()].
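+///
+/// An illustrative sketch (hypothetical values):
+/// ```rust
+/// # use dfdx::prelude::*;
+/// # use dfdx_nn::*;
+/// # let dev: Cpu = Default::default();
+/// let t = dev.tensor([-2.0, -1.0, 0.0, 1.0, 2.0]);
+/// // element-wise square: x * x
+/// let r = Square.forward(t);
+/// assert_eq!(r.array(), [4.0, 1.0, 0.0, 1.0, 4.0]);
+/// ```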
 #[derive(Default, Debug, Clone, Copy, crate::CustomModule)]
 pub struct Square;
-
 impl<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>> crate::Module<Tensor<S, E, D, T>> for Square {
     type Output = Tensor<S, E, D, T>;
     type Error = D::Err;
diff --git a/dfdx-nn/src/tanh.rs b/dfdx-nn/src/tanh.rs
index da603cbb..0da70869 100644
--- a/dfdx-nn/src/tanh.rs
+++ b/dfdx-nn/src/tanh.rs
@@ -1,8 +1,8 @@
 use dfdx::prelude::{Device, Dtype, Shape, Tape, Tensor};
 
+/// Calls [dfdx::tensor_ops::tanh()].
 #[derive(Default, Debug, Clone, Copy, crate::CustomModule)]
 pub struct Tanh;
-
 impl<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>> crate::Module<Tensor<S, E, D, T>> for Tanh {
     type Output = Tensor<S, E, D, T>;
     type Error = D::Err;
diff --git a/dfdx-nn/src/transformer.rs b/dfdx-nn/src/transformer.rs
index a643f79a..5afb0235 100644
--- a/dfdx-nn/src/transformer.rs
+++ b/dfdx-nn/src/transformer.rs
@@ -15,6 +15,19 @@ pub struct FeedForwardConfig<Model: Dim, F: Dim> {
     pub l2: LinearConfig<F, Model>,
 }
 
+/// A single transformer encoder block.
+///
+/// Generics
+/// - `Model`: The size of query/key/value tensors. Given to [MultiHeadAttention].
+/// - `NumHeads`: The number of heads in [MultiHeadAttention].
+/// - `F`: The size of the hidden layer in the feedforward network.
+///
+/// **Pytorch equivalent**:
+/// ```python
+/// encoder = torch.nn.TransformerEncoderLayer(
+///     Model, NumHeads, dim_feedforward=F, batch_first=True, dropout=0.0
+/// )
+/// ```
 #[derive(Clone, Debug, Sequential)]
 #[built(EncoderBlock)]
 pub struct EncoderBlockConfig<Model: Dim, NumHeads: Dim, F: Dim> {
@@ -41,6 +54,20 @@ impl<Model: Dim, NumHeads: Dim, F: Dim> EncoderBlockConfig<Model, NumHeads, F> {
     }
 }
 
+/// A transformer decoder block. It differs from the encoder block in that its
+/// attention also accepts an additional sequence coming from the encoder.
+///
+/// Generics
+/// - `Model`: The size of query/key/value tensors. Given to [MultiHeadAttention].
+/// - `NumHeads`: The number of heads in [MultiHeadAttention].
+/// - `F`: The size of the hidden layer in the feedforward network.
+///
+/// **Pytorch equivalent**:
+/// ```python
+/// decoder = torch.nn.TransformerDecoderLayer(
+///     Model, NumHeads, dim_feedforward=F, batch_first=True, dropout=0.0
+/// )
+/// ```
 #[derive(Clone, Debug, CustomModule)]
 #[built(DecoderBlock)]
 pub struct DecoderBlockConfig<Model: Dim, NumHeads: Dim, F: Dim> {
@@ -106,6 +133,27 @@ where
     }
 }
 
+/// Transformer architecture as described in
+/// [Attention is all you need](https://arxiv.org/abs/1706.03762).
+///
+/// It is composed of an [EncoderBlockConfig] and a [DecoderBlockConfig].
+///
+/// Generics:
+/// - `Model`: Size of the input features to the encoder/decoder.
+/// - `NumHeads`: Number of heads for [MultiHeadAttention].
+/// - `F`: Feedforward hidden dimension for both encoder/decoder.
+///
+/// **Pytorch equivalent**:
+/// ```python
+/// torch.nn.Transformer(
+///     d_model=Model,
+///     nhead=NumHeads,
+///     num_encoder_layers=cfg.encoder.len(),
+///     num_decoder_layers=cfg.decoder.len(),
+///     dim_feedforward=F,
+///     batch_first=True,
+/// )
+/// ```
 #[derive(Clone, Debug, CustomModule)]
 #[built(Transformer)]
 pub struct TransformerConfig<Model: Dim, NumHeads: Dim, F: Dim> {