diff --git a/src/dtypes/amp.rs b/src/dtypes/amp.rs
index f5fd0cc5d..b87f3f953 100644
--- a/src/dtypes/amp.rs
+++ b/src/dtypes/amp.rs
@@ -1,5 +1,10 @@
 use rand::{distributions::Distribution, Rng};
 
+/// Wrapper type around the storage type. Use like `AMP<f16>` or `AMP<bf16>`.
+///
+/// This causes some tensor operations to cast the type to a higher precision
+/// and then back. For example, calling sum on an `AMP<f16>` tensor will cast it to
+/// `f32`, sum it, and then cast it back to `f16`.
 #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
 #[repr(transparent)]
 pub struct AMP<F>(pub F);
diff --git a/src/dtypes/mod.rs b/src/dtypes/mod.rs
index 3a17d6032..f259263b6 100644
--- a/src/dtypes/mod.rs
+++ b/src/dtypes/mod.rs
@@ -1,3 +1,12 @@
+//! Module for data type related traits and structs. Contains things like [Unit], [Dtype], and [AMP].
+//!
+//! When the `f16` feature is enabled, this exports the [f16] type.
+//!
+//! # AMP
+//!
+//! [AMP](https://pytorch.org/docs/stable/amp.html) is a technique for mixed precision training.
+//! In dfdx it is a data type; you can use it like any normal dtype, e.g. [`AMP<f16>`] or [`AMP<bf16>`].
+
 mod amp;
 pub use amp::AMP;
 
@@ -5,9 +14,11 @@ pub use amp::AMP;
 #[cfg(feature = "f16")]
 pub use half::f16;
 
+/// Represents a type for which the all-zeros bit pattern is a valid value.
 #[cfg(not(feature = "cuda"))]
 pub trait SafeZeros {}
 
+/// Represents a type for which the all-zeros bit pattern is a valid value.
 #[cfg(feature = "cuda")]
 pub trait SafeZeros: cudarc::driver::ValidAsZeroBits + cudarc::driver::DeviceRepr {}
 
@@ -100,6 +111,7 @@ pub trait HasDtype {
     type Dtype: Dtype;
 }
 
+/// Marker trait for types that are **not** [AMP].
 pub trait NotMixedPrecision {}
 impl NotMixedPrecision for f32 {}
 impl NotMixedPrecision for f64 {}
diff --git a/src/lib.rs b/src/lib.rs
index 026502a13..c5b5e5e6b 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -13,12 +13,12 @@
 //!
 //! # Shapes & Tensors
 //!
-//! *See [shapes] and [tensor] for more information.*
+//! *See [dtypes], [shapes], and [tensor] for more information.*
 //!
 //! At its core a [`tensor::Tensor`] is just a nd-array. Just like
 //! rust arrays there are two parts:
-//! 1. Shape
-//! 2. Dtype
+//! 1. Shape ([shapes])
+//! 2. Dtype ([dtypes])
 //!
 //! dfdx represents shapes as **tuples** of dimensions ([`shapes::Dim`]),
 //! where a dimension can either be known at:
@@ -31,6 +31,7 @@
 //! - `(usize,)` - 1d shape with a runtime known dimension
 //! - `(usize, Const<5>)` - 2d shape with both types of dimensions
 //! - `(Const<3>, usize, Const<5>)` - 3d shape!
+//! - `Rank3<3, 5, 7>` - Equivalent to `(Const<3>, Const<5>, Const<7>)`
 //!
 //! Here are some comparisons between representing nd arrays in rust vs dfdx:
 //!
diff --git a/src/tensor_ops/adam/mod.rs b/src/tensor_ops/adam/mod.rs
index 0188f8aeb..1e95777af 100644
--- a/src/tensor_ops/adam/mod.rs
+++ b/src/tensor_ops/adam/mod.rs
@@ -10,7 +10,7 @@ use crate::{
 
 use super::WeightDecay;
 
-/// Configuration of hyperparameters for [Adam].
+/// Configuration of hyperparameters for [crate::optim::Adam].
 ///
 /// Changing all default parameters:
 /// ```rust
diff --git a/src/tensor_ops/optim.rs b/src/tensor_ops/optim.rs
index 1c905aa35..469c22c14 100644
--- a/src/tensor_ops/optim.rs
+++ b/src/tensor_ops/optim.rs
@@ -28,13 +28,13 @@ pub(super) fn weight_decay_to_cuda(wd: Option<WeightDecay>) -> (WeightDecayType,
     }
 }
 
-/// Momentum used for [super::Sgd] and others
+/// Momentum used for [crate::optim::Sgd] and others
 #[derive(Debug, Clone, Copy)]
 pub enum Momentum {
     /// Momentum that is applied to the velocity of a parameter directly.
     Classic(f64),
 
-    /// Momentum that is applied to both velocity and gradients. See [super::Sgd] nesterov paper for more.
+    /// Momentum that is applied to both velocity and gradients. See the Nesterov paper referenced in [crate::optim::Sgd] for more.
     Nesterov(f64),
 }
 
diff --git a/src/tensor_ops/prelu.rs b/src/tensor_ops/prelu.rs
index 485b9766b..b6cc38ac9 100644
--- a/src/tensor_ops/prelu.rs
+++ b/src/tensor_ops/prelu.rs
@@ -5,8 +5,8 @@ use super::{cmp::*, BroadcastTo, ChooseFrom, Device, TryMul};
 /// [Parametric Rectified Linear Unit (PReLU)](https://pytorch.org/docs/stable/generated/torch.nn.PReLU.html). `max(0, lhs) + rhs*min(0, lhs)`
 ///
 /// In other words, for each element i:
-/// - if lhs[i] < 0, use `lhs[i] * rhs[i]`
-/// - if lhs[i] >= 0, use `lhs[i]`
+/// - if `lhs[i] < 0`, use `lhs[i] * rhs[i]`
+/// - if `lhs[i] >= 0`, use `lhs[i]`
 ///
 ///
 /// Examples:
diff --git a/src/tensor_ops/rmsprop/mod.rs b/src/tensor_ops/rmsprop/mod.rs
index b4095a2a7..1899cd5e7 100644
--- a/src/tensor_ops/rmsprop/mod.rs
+++ b/src/tensor_ops/rmsprop/mod.rs
@@ -10,7 +10,7 @@ use crate::{
 
 use super::WeightDecay;
 
-/// Configuration of hyperparameters for [RMSprop].
+/// Configuration of hyperparameters for [crate::optim::RMSprop].
 #[derive(Debug, Clone, Copy)]
 pub struct RMSpropConfig {
     /// Learning rate. Defaults to `1e-2`.
@@ -59,6 +59,7 @@ pub trait RMSpropKernel<E: Dtype>: Storage<E> {
 }
 
 impl RMSpropConfig {
+    /// Updates a single tensor using RMSprop.
     pub fn try_update<S: Shape, E: Dtype, D: RMSpropKernel<E>>(
         &self,
         param: &mut Tensor<S, E, D>,
diff --git a/src/tensor_ops/sgd/mod.rs b/src/tensor_ops/sgd/mod.rs
index be27f1393..112f248a7 100644
--- a/src/tensor_ops/sgd/mod.rs
+++ b/src/tensor_ops/sgd/mod.rs
@@ -10,7 +10,7 @@ use crate::{
 
 use super::optim::{Momentum, WeightDecay};
 
-/// Configuration of hyperparameters for [Sgd].
+/// Configuration of hyperparameters for [crate::optim::Sgd].
 ///
 /// Using different learning rate:
 /// ```rust
@@ -94,6 +94,7 @@ pub trait SgdKernel<E: Dtype>: Storage<E> {
 }
 
 impl SgdConfig {
+    /// Updates a single tensor using SGD.
     pub fn try_update<S: Shape, E: Dtype, D: SgdKernel<E>>(
         &self,
         param: &mut Tensor<S, E, D>,
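
Below the patch, a few usage sketches for the APIs the new docs describe. First, the `AMP` wrapper: a minimal sketch of treating `AMP<f16>` as an ordinary dtype, assuming the `f16` feature is enabled and that `AutoDevice`, the `zeros` constructor, and the `sum` reduction behave as in the crate's existing doc examples; the shape here is arbitrary.

```rust
use dfdx::dtypes::{f16, AMP};
use dfdx::prelude::*;

fn main() {
    let dev = AutoDevice::default();
    // AMP<f16> is used exactly like any other dtype parameter.
    let x: Tensor<Rank2<3, 5>, AMP<f16>, _> = dev.zeros();
    // Per the amp.rs docs above, reductions such as `sum` accumulate in f32
    // and cast the result back to f16.
    let s = x.sum::<Rank0, _>();
    println!("{:?}", s.array());
}
```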
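For the new `Rank3<3, 5, 7>` bullet in the lib.rs shape list, a sketch showing that the alias and the explicit tuple of `Const` dimensions name the same shape (again assuming the usual `AutoDevice`/`zeros` API from the crate docs):

```rust
use dfdx::prelude::*;

fn main() {
    let dev = AutoDevice::default();
    // `Rank3<3, 5, 7>` is shorthand for the tuple of compile-time dims below.
    let a: Tensor<Rank3<3, 5, 7>, f32, _> = dev.zeros();
    let b: Tensor<(Const<3>, Const<5>, Const<7>), f32, _> = dev.zeros();
    // Both tensors have the same concrete shape at runtime.
    assert_eq!(a.shape().concrete(), b.shape().concrete());
}
```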
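And for the `Momentum`/`SgdConfig` doc links, a sketch of selecting Nesterov momentum through the config struct. The field names follow the `SgdConfig` doc example already referenced in sgd/mod.rs; the hyperparameter values are arbitrary.

```rust
use dfdx::{optim::*, prelude::*};

fn main() {
    // `Momentum::Classic` applies momentum to the velocity directly;
    // `Momentum::Nesterov` also applies it to the gradients.
    let _cfg = SgdConfig {
        lr: 1e-2,
        momentum: Some(Momentum::Nesterov(0.9)),
        weight_decay: Some(WeightDecay::L2(1e-4)),
    };
}
```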