Adding documentation to dtypes module and amp #834

Merged: 1 commit, Jul 27, 2023
5 changes: 5 additions & 0 deletions src/dtypes/amp.rs
@@ -1,5 +1,10 @@
use rand::{distributions::Distribution, Rng};

/// Wrapper type around the storage type. Use like `AMP<f16>` or `AMP<bf16>`.
///
/// This causes some tensor operations to cast the type to a higher precision
/// and then back. For example, calling `sum` on an `AMP<f16>` tensor will cast it to
/// `f32`, sum it, and then cast it back to `f16`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(transparent)]
pub struct AMP<F>(pub F);
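To make the casting behavior described in the doc comment concrete, here is a minimal usage sketch. It assumes the `f16` feature is enabled and that `AMP<f16>` works with the usual dfdx device constructors (`Cpu::default`, `sample_normal`, `sum`); treat the exact calls as illustrative rather than part of this diff.

```rust
use dfdx::prelude::*;
use dfdx::dtypes::{AMP, f16};

fn main() {
    let dev = Cpu::default();

    // A tensor stored as f16, but wrapped in AMP so that reductions
    // upcast internally before casting back down.
    let t: Tensor<Rank2<3, 4>, AMP<f16>, _> = dev.sample_normal();

    // Per the doc comment above: sum() casts to f32, accumulates,
    // and casts the result back to f16.
    let s = t.sum();
    println!("{:?}", s.array());
}
```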
12 changes: 12 additions & 0 deletions src/dtypes/mod.rs
@@ -1,13 +1,24 @@
//! Module for data type related traits and structs. Contains things like [Unit], [Dtype], and [AMP].
//!
//! When the `f16` feature is enabled, this exports the [f16] type.

GitHub Actions / cargo-check warning (line 3 in src/dtypes/mod.rs): unresolved link to `f16`
//!
//! # AMP
//!
//! [AMP](https://pytorch.org/docs/stable/amp.html) is a technique for mixed precision training.
//! In dfdx it is provided as a data type; you can use it like any other dtype, e.g. [`AMP<f16>`] or [`AMP<bf16>`].

mod amp;

pub use amp::AMP;

#[cfg(feature = "f16")]
pub use half::f16;

/// Represents a type for which the all-zero bit pattern is valid.
#[cfg(not(feature = "cuda"))]
pub trait SafeZeros {}

/// Represents a type for which the all-zero bit pattern is valid.
#[cfg(feature = "cuda")]
pub trait SafeZeros: cudarc::driver::ValidAsZeroBits + cudarc::driver::DeviceRepr {}

@@ -100,6 +111,7 @@
type Dtype: Dtype;
}

/// Marker trait for types that are **not** [AMP].
pub trait NotMixedPrecision {}
impl NotMixedPrecision for f32 {}
impl NotMixedPrecision for f64 {}
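As a sketch of what the marker trait enables (the import paths and the `Dtype` bound here are assumptions about how the crate re-exports these items, not part of this diff), a function can require a non-AMP dtype like so:

```rust
use dfdx::dtypes::{Dtype, NotMixedPrecision};

// A bound like this only accepts plain dtypes such as f32 or f64; AMP<_>
// wrappers are rejected at compile time because they do not implement
// NotMixedPrecision.
fn full_precision_only<E: Dtype + NotMixedPrecision>(x: E) -> E {
    x
}

fn main() {
    let _ = full_precision_only(1.0f32); // compiles
    // full_precision_only(AMP(1.0f32)); // would not compile
}
```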
7 changes: 4 additions & 3 deletions src/lib.rs
@@ -13,12 +13,12 @@
//!
//! # Shapes & Tensors
//!
//! *See [shapes] and [tensor] for more information.*
//! *See [dtypes], [shapes], and [tensor] for more information.*
//!
//! At its core a [`tensor::Tensor`] is just an nd-array. Just like
//! Rust arrays, there are two parts:
//! 1. Shape
//! 2. Dtype
//! 1. Shape ([shapes])
//! 2. Dtype ([dtypes])
//!
//! dfdx represents shapes as **tuples** of dimensions ([`shapes::Dim`]),
//! where a dimension can either be known at:
@@ -31,6 +31,7 @@
//! - `(usize,)` - 1d shape with a runtime known dimension
//! - `(usize, Const<5>)` - 2d shape with both types of dimensions
//! - `(Const<3>, usize, Const<5>)` - 3d shape!
//! - `Rank3<3, 5, 7>` - Equivalent to `(Const<3>, Const<5>, Const<7>)`
//!
//! Here are some comparisons between representing nd arrays in rust vs dfdx:
//!
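The comparison table itself is collapsed in this view; separately from it, here is a brief sketch of the shape styles listed above. The constructor names (`zeros`, `zeros_like`) follow the usual dfdx tensor API and should be treated as assumptions here.

```rust
use dfdx::prelude::*;

fn main() {
    let dev = Cpu::default();

    // Fully compile-time shape via the Rank2 alias, i.e. (Const<3>, Const<5>).
    let a: Tensor<Rank2<3, 5>, f32, _> = dev.zeros();

    // Mixed shape: the first dimension is only known at runtime.
    let batch: usize = 8;
    let b: Tensor<(usize, Const<5>), f32, _> = dev.zeros_like(&(batch, Const));

    println!("{:?} {:?}", a.shape(), b.shape());
}
```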
@@ -58,7 +59,7 @@
//! There are two options for this currently, with more planned to be added in the future:
//!
//! 1. [tensor::Cpu] - for tensors stored on the heap
//! 2. [tensor::Cuda] - for tensors stored in GPU memory

GitHub Actions / cargo-check warning (line 62 in src/lib.rs): unresolved link to `tensor::Cuda`
//!
//! Both devices implement [Default]; you can also create them with a specific seed
//! and ordinal.
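For example, a minimal sketch of both construction styles. `Cpu::default()` follows directly from the `[Default]` statement above; the seeded constructor name is an assumption based on common dfdx examples and may differ by version.

```rust
use dfdx::prelude::*;

fn main() {
    // Default construction, as the docs above state.
    let dev = Cpu::default();
    let a: Tensor<Rank1<4>, f32, _> = dev.sample_uniform();

    // Seeded construction for reproducible runs; the constructor name
    // (a seed_from_u64-style call) is assumed here.
    let seeded = Cpu::seed_from_u64(42);
    let b: Tensor<Rank1<4>, f32, _> = seeded.sample_uniform();

    println!("{:?} {:?}", a.array(), b.array());
}
```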
@@ -84,8 +85,8 @@
//! | Unary Operations | `a.sqrt()` | `a.sqrt()` | `a.sqrt()` |
//! | Binary Operations | `a + b` | `a + b` | `a + b` |
//! | gemm/gemv | [tensor_ops::matmul] | `a @ b` | `a @ b` |
//! | 2d Convolution | [tensor_ops::TryConv2D] | - | `torch.conv2d` |

GitHub Actions / cargo-check warning (line 88 in src/lib.rs): unresolved link to `tensor_ops::TryConv2D`
//! | 2d Transposed Convolution | [tensor_ops::TryConvTrans2D] | - | `torch.conv_transpose2d` |

GitHub Actions / cargo-check warning (line 89 in src/lib.rs): unresolved link to `tensor_ops::TryConvTrans2D`
//! | Slicing | [tensor_ops::slice] | `a[...]` | `a[...]` |
//! | Select | [tensor_ops::SelectTo] | `a[...]` | `torch.select` |
//! | Gather | [tensor_ops::GatherTo] | `np.take` | `torch.gather` |
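The table continues beyond this point in the source file. As a quick sketch of the call-site style it refers to (method names such as `matmul`, `relu`, and the `+` operator are the usual dfdx tensor-op surface; treat the exact signatures as assumptions):

```rust
use dfdx::prelude::*;

fn main() {
    let dev = Cpu::default();
    let a: Tensor<Rank2<2, 3>, f32, _> = dev.sample_normal();
    let b: Tensor<Rank2<3, 4>, f32, _> = dev.sample_normal();

    // gemm, i.e. `a @ b` in the numpy/pytorch columns of the table.
    let c: Tensor<Rank2<2, 4>, f32, _> = a.matmul(b);

    // Unary and binary ops keep the same surface syntax across all three.
    let d = c.clone() + c;
    let e = d.relu();
    println!("{:?}", e.array());
}
```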
2 changes: 1 addition & 1 deletion src/tensor_ops/adam/mod.rs
@@ -10,7 +10,7 @@ use crate::{

use super::WeightDecay;

/// Configuration of hyperparameters for [Adam].
/// Configuration of hyperparameters for [crate::optim::Adam].
///
/// Changing all default parameters:
/// ```rust
4 changes: 2 additions & 2 deletions src/tensor_ops/optim.rs
@@ -28,13 +28,13 @@ pub(super) fn weight_decay_to_cuda(wd: Option<WeightDecay>) -> (WeightDecayType,
}
}

/// Momentum used for [super::Sgd] and others
/// Momentum used for [crate::optim::Sgd] and others
#[derive(Debug, Clone, Copy)]
pub enum Momentum {
/// Momentum that is applied to the velocity of a parameter directly.
Classic(f64),

/// Momentum that is applied to both velocity and gradients. See [super::Sgd] nesterov paper for more.
/// Momentum that is applied to both velocity and gradients. See the Nesterov paper referenced in [crate::optim::Sgd] for more.
Nesterov(f64),
}
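A small sketch of how the two variants might be selected when building an SGD configuration. The `SgdConfig` field names and the `Debug` output used here are assumptions based on the surrounding `tensor_ops::sgd` module, not part of this diff.

```rust
use dfdx::prelude::*;

fn main() {
    // Classic momentum: applied to the parameter velocity only.
    let classic = SgdConfig {
        lr: 1e-2,
        momentum: Some(Momentum::Classic(0.9)),
        weight_decay: None,
    };

    // Nesterov momentum: also folded into the gradient term.
    let nesterov = SgdConfig {
        lr: 1e-2,
        momentum: Some(Momentum::Nesterov(0.9)),
        weight_decay: None,
    };

    println!("{:?}\n{:?}", classic, nesterov);
}
```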

4 changes: 2 additions & 2 deletions src/tensor_ops/prelu.rs
@@ -5,8 +5,8 @@ use super::{cmp::*, BroadcastTo, ChooseFrom, Device, TryMul};
/// [Parametric Rectified Linear Unit (PReLU)](https://pytorch.org/docs/stable/generated/torch.nn.PReLU.html). `max(0, lhs) + rhs*min(0, lhs)`
///
/// In other words, for each element i:
/// - if lhs[i] < 0, use `lhs[i] * rhs[i]`
/// - if lhs[i] >= 0, use `lhs[i]`
/// - if `lhs[i] < 0`, use `lhs[i] * rhs[i]`
/// - if `lhs[i] >= 0`, use `lhs[i]`
///
///
/// Examples:
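The crate's own examples are collapsed in this view. Separately from those, here is a small sketch of the element-wise rule stated above; the `prelu` method name follows the surrounding `tensor_ops` module and should be treated as an assumption.

```rust
use dfdx::prelude::*;

fn main() {
    let dev = Cpu::default();
    let lhs = dev.tensor([-1.0f32, -0.5, 0.0, 2.0]);
    let rhs = dev.tensor([0.25f32; 4]);

    // Element-wise: negative inputs are scaled by rhs, others pass through,
    // giving [-0.25, -0.125, 0.0, 2.0].
    let out = lhs.prelu(rhs);
    println!("{:?}", out.array());
}
```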
3 changes: 2 additions & 1 deletion src/tensor_ops/rmsprop/mod.rs
@@ -10,7 +10,7 @@ use crate::{

use super::WeightDecay;

/// Configuration of hyperparameters for [RMSprop].
/// Configuration of hyperparameters for [crate::optim::RMSprop].
#[derive(Debug, Clone, Copy)]
pub struct RMSpropConfig {
/// Learning rate. Defaults to `1e-2`.
@@ -59,6 +59,7 @@ pub trait RMSpropKernel<E: Dtype>: Storage<E> {
}

impl RMSpropConfig {
/// Update a single tensor using RMSprop.
pub fn try_update<S: Shape, E: Dtype, D: RMSpropKernel<E>>(
&self,
param: &mut Tensor<S, E, D>,
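A sketch of constructing the configuration whose docs this hunk touches. It assumes `RMSpropConfig` implements `Default` (matching the "Defaults to" notes in its field docs) and that `lr` is a public field; the truncated `try_update` call is deliberately not shown since its full signature is not visible here.

```rust
use dfdx::prelude::*;

fn main() {
    // Start from the documented defaults and override the learning rate.
    let cfg = RMSpropConfig {
        lr: 1e-3,
        ..Default::default()
    };
    println!("{:?}", cfg);
}
```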
3 changes: 2 additions & 1 deletion src/tensor_ops/sgd/mod.rs
@@ -10,7 +10,7 @@ use crate::{

use super::optim::{Momentum, WeightDecay};

/// Configuration of hyperparameters for [Sgd].
/// Configuration of hyperparameters for [crate::optim::Sgd].
///
/// Using different learning rate:
/// ```rust
@@ -94,6 +94,7 @@ pub trait SgdKernel<E: Dtype>: Storage<E> {
}

impl SgdConfig {
/// Updates a single tensor using SGD.
pub fn try_update<S: Shape, E: Dtype, D: SgdKernel<E>>(
&self,
param: &mut Tensor<S, E, D>,