Adding documentation to dtypes module and amp
coreylowman committed Jul 27, 2023
1 parent 0b49672 commit a49f69d
Showing 8 changed files with 30 additions and 10 deletions.
5 changes: 5 additions & 0 deletions src/dtypes/amp.rs
@@ -1,5 +1,10 @@
use rand::{distributions::Distribution, Rng};

/// Wrapper type around the storage type. Use like `AMP<f16>` or `AMP<bf16>`.
///
/// This causes some tensor operations to cast the type to a higher precision
/// and then back. For example, calling `sum` on an `AMP<f16>` tensor will cast it to
/// `f32`, sum it, and then cast the result back to `f16`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(transparent)]
pub struct AMP<F>(pub F);
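To illustrate the behavior described in that doc comment, here is a minimal plain-Rust sketch (not dfdx code) of the cast-up/reduce/cast-back pattern, assuming only the `half` crate for the `f16` type:

```rust
use half::f16;

/// Sum f16 values the way the AMP doc comment describes:
/// cast each element up to f32, accumulate, then cast the result back down.
fn amp_style_sum(values: &[f16]) -> f16 {
    let total: f32 = values.iter().map(|v| v.to_f32()).sum();
    f16::from_f32(total)
}

fn main() {
    let xs = [f16::from_f32(0.1), f16::from_f32(0.2), f16::from_f32(0.3)];
    // Accumulating in f32 avoids the precision loss of summing directly in f16.
    println!("{}", amp_style_sum(&xs));
}
```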
12 changes: 12 additions & 0 deletions src/dtypes/mod.rs
@@ -1,13 +1,24 @@
//! Module for data type related traits and structs. Contains things like [Unit], [Dtype], and [AMP].
//!
//! When the `f16` feature is enabled, this exports the [f16] type.

Check warning (GitHub Actions / cargo-check) on line 3 in src/dtypes/mod.rs: unresolved link to `f16`
//!
//! # AMP
//!
//! [AMP](https://pytorch.org/docs/stable/amp.html) is a technique for mixed precision training.
//! It is a data type in dfdx, so you can use it like any normal dtype, e.g. [`AMP<f16>`] or [`AMP<bf16>`].

mod amp;

pub use amp::AMP;

#[cfg(feature = "f16")]
pub use half::f16;

/// Represents a type for which an all-zero bit pattern is valid.
#[cfg(not(feature = "cuda"))]
pub trait SafeZeros {}

/// Represents a type for which an all-zero bit pattern is valid.
#[cfg(feature = "cuda")]
pub trait SafeZeros: cudarc::driver::ValidAsZeroBits + cudarc::driver::DeviceRepr {}

@@ -100,6 +111,7 @@ pub trait HasDtype {
type Dtype: Dtype;
}

/// Marker trait for types that are **not** [AMP].
pub trait NotMixedPrecision {}
impl NotMixedPrecision for f32 {}
impl NotMixedPrecision for f64 {}
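As a rough sketch of what "use it like any normal dtype" means, the snippet below allocates an `AMP<f16>` tensor and reduces it. It assumes the `f16` feature is enabled and that the import paths and the `zeros`/`sum` calls behave as in the rest of the dfdx docs; treat it as illustrative rather than exact:

```rust
use dfdx::dtypes::{f16, AMP};
use dfdx::prelude::*;

fn main() {
    let dev: Cpu = Default::default();
    // AMP<f16> goes exactly where f32 would normally go in the dtype slot.
    let t: Tensor<Rank1<4>, AMP<f16>, _> = dev.zeros();
    // Per the module docs, this reduction accumulates in f32 and casts back to f16.
    let _s: Tensor<(), AMP<f16>, _> = t.sum();
}
```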
7 changes: 4 additions & 3 deletions src/lib.rs
@@ -13,12 +13,12 @@
//!
//! # Shapes & Tensors
//!
//! *See [shapes] and [tensor] for more information.*
//! *See [dtypes], [shapes], and [tensor] for more information.*
//!
//! At its core a [`tensor::Tensor`] is just an nd-array. Just like
//! Rust arrays, there are two parts:
//! 1. Shape
//! 2. Dtype
//! 1. Shape ([shapes])
//! 2. Dtype ([dtypes])
//!
//! dfdx represents shapes as **tuples** of dimensions ([`shapes::Dim`]),
//! where a dimension can either be known at:
@@ -31,6 +31,7 @@
//! - `(usize,)` - 1d shape with a runtime known dimension
//! - `(usize, Const<5>)` - 2d shape with both types of dimensions
//! - `(Const<3>, usize, Const<5>)` - 3d shape!
//! - `Rank3<3, 5, 7>` - Equivalent to `(Const<3>, Const<5>, Const<7>)`
//!
//! Here are some comparisons between representing nd arrays in rust vs dfdx:
//!
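A short sketch of the shape forms listed above, assuming dfdx's prelude exports `Const`, `Rank3`, and the `ZerosTensor` constructors (exact constructor names may vary slightly between versions):

```rust
use dfdx::prelude::*;

fn main() {
    let dev: Cpu = Default::default();
    // Fully compile-time shape: Rank3<3, 5, 7> == (Const<3>, Const<5>, Const<7>).
    let _a: Tensor<Rank3<3, 5, 7>, f32, _> = dev.zeros();
    // Mixed shape: the first dimension is only known at runtime.
    let b: Tensor<(usize, Const<5>), f32, _> = dev.zeros_like(&(3, Const::<5>));
    assert_eq!(b.shape().0, 3);
}
```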
2 changes: 1 addition & 1 deletion src/tensor_ops/adam/mod.rs
@@ -10,7 +10,7 @@ use crate::{

use super::WeightDecay;

/// Configuration of hyperparameters for [Adam].
/// Configuration of hyperparameters for [crate::optim::Adam].
///
/// Changing all default parameters:
/// ```rust
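The doc example is truncated in this view. As a hedged sketch of what configuring Adam looks like, assuming `AdamConfig` exposes `lr`, `betas`, `eps`, and `weight_decay` fields and is re-exported through the prelude:

```rust
use dfdx::prelude::*;

fn main() {
    // Values here are illustrative, not recommendations.
    let cfg = AdamConfig {
        lr: 1e-3,
        betas: [0.9, 0.999],
        eps: 1e-8,
        weight_decay: Some(WeightDecay::L2(1e-4)),
    };
    println!("{cfg:?}");
}
```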
4 changes: 2 additions & 2 deletions src/tensor_ops/optim.rs
@@ -28,13 +28,13 @@ pub(super) fn weight_decay_to_cuda(wd: Option<WeightDecay>) -> (WeightDecayType,
}
}

/// Momentum used for [super::Sgd] and others
/// Momentum used for [crate::optim::Sgd] and others
#[derive(Debug, Clone, Copy)]
pub enum Momentum {
/// Momentum that is applied to the velocity of a parameter directly.
Classic(f64),

/// Momentum that is applied to both velocity and gradients. See [super::Sgd] nesterov paper for more.
/// Momentum that is applied to both velocity and gradients. See the Nesterov paper referenced in [crate::optim::Sgd] for more.
Nesterov(f64),
}

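A minimal sketch of selecting one of these momentum variants when building an SGD config; it assumes `SgdConfig` has `lr`, `momentum`, and `weight_decay` fields and that both types are re-exported through the prelude:

```rust
use dfdx::prelude::*;

fn main() {
    // Classic vs. Nesterov momentum are just enum variants on the config.
    let cfg = SgdConfig {
        lr: 1e-2,
        momentum: Some(Momentum::Nesterov(0.9)),
        weight_decay: None,
    };
    println!("{cfg:?}");
}
```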
4 changes: 2 additions & 2 deletions src/tensor_ops/prelu.rs
@@ -5,8 +5,8 @@ use super::{cmp::*, BroadcastTo, ChooseFrom, Device, TryMul};
/// [Parametric Rectified Linear Unit (PReLU)](https://pytorch.org/docs/stable/generated/torch.nn.PReLU.html). `max(0, lhs) + rhs*min(0, lhs)`
///
/// In other words, for each element `i`:
/// - if lhs[i] < 0, use `lhs[i] * rhs[i]`
/// - if lhs[i] >= 0, use `lhs[i]`
/// - if `lhs[i] < 0`, use `lhs[i] * rhs[i]`
/// - if `lhs[i] >= 0`, use `lhs[i]`
///
///
/// Examples:
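A plain-Rust sketch of that elementwise rule (this is just the math from the doc comment, not the dfdx tensor op itself):

```rust
/// out[i] = lhs[i] if lhs[i] >= 0, otherwise lhs[i] * rhs[i].
fn prelu_elementwise(lhs: &[f32], rhs: &[f32]) -> Vec<f32> {
    lhs.iter()
        .zip(rhs)
        .map(|(&x, &a)| if x >= 0.0 { x } else { x * a })
        .collect()
}

fn main() {
    let out = prelu_elementwise(&[-1.0, 2.0], &[0.25, 0.25]);
    assert_eq!(out, vec![-0.25, 2.0]);
}
```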
3 changes: 2 additions & 1 deletion src/tensor_ops/rmsprop/mod.rs
@@ -10,7 +10,7 @@ use crate::{

use super::WeightDecay;

/// Configuration of hyperparameters for [RMSprop].
/// Configuration of hyperparameters for [crate::optim::RMSprop].
#[derive(Debug, Clone, Copy)]
pub struct RMSpropConfig {
/// Learning rate. Defaults to `1e-2`.
@@ -59,6 +59,7 @@ pub trait RMSpropKernel<E: Dtype>: Storage<E> {
}

impl RMSpropConfig {
/// Update a single tensor using RMSprop.
pub fn try_update<S: Shape, E: Dtype, D: RMSpropKernel<E>>(
&self,
param: &mut Tensor<S, E, D>,
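For orientation, here is a plain-Rust sketch of the basic (non-centered, momentum-free) RMSprop step that `try_update` applies per element; the real kernel also handles momentum, centering, and weight decay, so this is only a simplified illustration:

```rust
/// One RMSprop step for a single parameter:
/// v <- alpha * v + (1 - alpha) * g^2,   p <- p - lr * g / (sqrt(v) + eps)
fn rmsprop_step(p: &mut f32, v: &mut f32, g: f32, lr: f32, alpha: f32, eps: f32) {
    *v = alpha * *v + (1.0 - alpha) * g * g;
    *p -= lr * g / (v.sqrt() + eps);
}
```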
3 changes: 2 additions & 1 deletion src/tensor_ops/sgd/mod.rs
@@ -10,7 +10,7 @@ use crate::{

use super::optim::{Momentum, WeightDecay};

/// Configuration of hyperparameters for [Sgd].
/// Configuration of hyperparameters for [crate::optim::Sgd].
///
/// Using different learning rate:
/// ```rust
@@ -94,6 +94,7 @@ pub trait SgdKernel<E: Dtype>: Storage<E> {
}

impl SgdConfig {
/// Updates a single tensor using SGD.
pub fn try_update<S: Shape, E: Dtype, D: SgdKernel<E>>(
&self,
param: &mut Tensor<S, E, D>,
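And similarly, a plain-Rust sketch of the per-element SGD step that `try_update` performs, with and without classic momentum; the real kernel also supports Nesterov momentum and weight decay, so this is a simplified illustration:

```rust
/// Plain SGD: p <- p - lr * g. With classic momentum, the velocity is
/// updated first and used in place of the raw gradient.
fn sgd_step(p: &mut f32, v: &mut f32, g: f32, lr: f32, momentum: Option<f32>) {
    match momentum {
        Some(m) => {
            *v = m * *v + g;
            *p -= lr * *v;
        }
        None => *p -= lr * g,
    }
}
```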
