Add Prodigy optimizer #895

Closed
wants to merge 8 commits
2 changes: 1 addition & 1 deletion dfdx-core/Cargo.toml
@@ -35,7 +35,7 @@ num-traits = { workspace = true }
safetensors = { workspace = true, optional = true }
memmap2 = { workspace = true, optional = true }
half = { version = "2.3.1", optional = true, features = ["num-traits", "rand_distr"] }
gemm = { version = "0.16.14", default-features = false, optional = true, features = ["rayon"] }
gemm = { version = "0.17.1", default-features = false, optional = true, features = ["rayon"] }
rayon = { version = "1.7.0", optional = true }
libm = { workspace = true }
wgpu = { version = "0.18.0", features = ["glsl", "spirv"], optional = true }
1 change: 1 addition & 0 deletions dfdx-core/src/data/collate.rs
@@ -1,4 +1,4 @@
use std::{mem::MaybeUninit, vec::Vec};

/// Collates `Self` into some other type.
/// Generally similar to an unzip method;
@@ -55,6 +55,7 @@
impl<'a, A, B> Collate for Vec<&'a (A, B)> {
type Collated = (Vec<&'a A>, Vec<&'a B>);
fn collated(self) -> Self::Collated {
#[allow(clippy::map_identity)]
self.into_iter().map(|(a, b)| (a, b)).unzip()
}
}
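
The `#[allow(clippy::map_identity)]` appears to be needed because the closure turns each `&(A, B)` item into an `(&A, &B)` pair via match ergonomics, which clippy mistakes for an identity map. As a quick illustration of the impl touched here, the sketch below unzips a Vec of borrowed pairs into two Vecs of references; the `dfdx_core::data::Collate` import path is an assumption.

use dfdx_core::data::Collate; // assumed path for the trait shown above

fn main() {
    let pairs = vec![(1, 'a'), (2, 'b'), (3, 'c')];
    let refs: Vec<&(i32, char)> = pairs.iter().collect();

    // `collated` is essentially an unzip over the borrowed tuples.
    let (nums, letters): (Vec<&i32>, Vec<&char>) = refs.collated();
    assert_eq!(nums, [&1, &2, &3]);
    assert_eq!(letters, [&'a', &'b', &'c']);
}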
38 changes: 0 additions & 38 deletions dfdx-core/src/lib.rs
@@ -9,7 +9,7 @@
//! The following sections provide some high-level core concepts & examples, and
//! there is more detailed documentation in each of dfdx's submodules.
//!
//! See [feature_flags] for details on feature flags.
//!
//! # Shapes & Tensors
//!
@@ -59,7 +59,7 @@
//! There are two options for this currently, with more planned to be added in the future:
//!
//! 1. [tensor::Cpu] - for tensors stored on the heap
//! 2. [tensor::Cuda] - for tensors stored in GPU memory
//!
//! Both devices implement [Default], you can also create them with a certain seed
//! and ordinal.
@@ -85,8 +85,8 @@
//! | Unary Operations | `a.sqrt()` | `a.sqrt()` | `a.sqrt()` |
//! | Binary Operations | `a + b` | `a + b` | `a + b` |
//! | gemm/gemv | [tensor_ops::matmul] | `a @ b` | `a @ b` |
//! | 2d Convolution | [tensor_ops::TryConv2D] | - | `torch.conv2d` |
//! | 2d Transposed Convolution | [tensor_ops::TryConvTrans2D] | - | `torch.conv_transpose2d` |
//! | Slicing | [tensor_ops::slice] | `a[...]` | `a[...]` |
//! | Select | [tensor_ops::SelectTo] | `a[...]` | `torch.select` |
//! | Gather | [tensor_ops::GatherTo] | `np.take` | `torch.gather` |
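
As a rough sketch of the first three rows of this table (a hedged example; exact device, method, and prelude paths may differ slightly between dfdx versions):

use dfdx_core::prelude::*;

fn main() {
    let dev = Cpu::default();

    // Static shapes: a is 3x4, b is 4x2.
    let a: Tensor<Rank2<3, 4>, f32, _> = dev.sample_normal();
    let b: Tensor<Rank2<4, 2>, f32, _> = dev.sample_normal();

    // Unary op, binary op, then a gemm via matmul.
    let c = (a.clone().square() + a).matmul(b);

    // Copy the result back into a plain nested array.
    let _host: [[f32; 2]; 3] = c.array();
}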
@@ -128,44 +128,6 @@
pub use crate::tensor_ops::*;
}

/// Sets a CPU `sse` flag to flush denormal floating point numbers to zero. The opposite of this is [keep_denormals()].
///
/// Some resources:
/// 1. [Effects of Flush-To-Zero mode](https://developer.arm.com/documentation/dui0473/c/neon-and-vfp-programming/the-effects-of-using-flush-to-zero-mode?lang=en)
/// 2. [When to use Flush-To-Zero mode](https://developer.arm.com/documentation/dui0473/c/neon-and-vfp-programming/when-to-use-flush-to-zero-mode?lang=en)
pub fn flush_denormals_to_zero() {
#[cfg(all(target_arch = "x86", target_feature = "sse"))]
{
use std::arch::x86::{_MM_FLUSH_ZERO_ON, _MM_SET_FLUSH_ZERO_MODE};
unsafe { _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON) }
}

#[cfg(all(target_arch = "x86_64", target_feature = "sse"))]
{
use std::arch::x86_64::{_MM_FLUSH_ZERO_ON, _MM_SET_FLUSH_ZERO_MODE};
unsafe { _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON) }
}
}

/// Sets a CPU flag to keep denormal floating point numbers. The opposite of this is [flush_denormals_to_zero()].
///
/// Some resources:
/// 1. [Effects of Flush-To-Zero mode](https://developer.arm.com/documentation/dui0473/c/neon-and-vfp-programming/the-effects-of-using-flush-to-zero-mode?lang=en)
/// 2. [When to use Flush-To-Zero mode](https://developer.arm.com/documentation/dui0473/c/neon-and-vfp-programming/when-to-use-flush-to-zero-mode?lang=en)
pub fn keep_denormals() {
#[cfg(all(target_arch = "x86", target_feature = "sse"))]
{
use std::arch::x86::{_MM_FLUSH_ZERO_OFF, _MM_SET_FLUSH_ZERO_MODE};
unsafe { _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_OFF) }
}

#[cfg(all(target_arch = "x86_64", target_feature = "sse"))]
{
use std::arch::x86_64::{_MM_FLUSH_ZERO_OFF, _MM_SET_FLUSH_ZERO_MODE};
unsafe { _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_OFF) }
}
}

#[cfg(test)]
pub(crate) mod tests {
pub use num_traits::{Float, NumCast, Zero};
2 changes: 1 addition & 1 deletion dfdx-core/src/tensor/gradients.rs
@@ -153,7 +153,7 @@ impl<E, D: Storage<E>> Gradients<E, D> {
#[inline]
pub(crate) fn many_and_ref<L: Shape, R: Shape>(
&mut self,
ls: &Vec<impl Tensorlike<L, E, D>>,
ls: &[impl Tensorlike<L, E, D>],
r: &impl Tensorlike<R, E, D>,
) -> (Vec<&mut D::Vec>, &D::Vec) {
for i in 0..ls.len() {
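
This is the standard clippy `ptr_arg` style fix: a `&[T]` parameter accepts a borrowed `Vec`, a borrowed array, or a sub-slice through deref coercion, so callers of `many_and_ref` need no changes. A standalone illustration (not dfdx code):

fn total_len<T>(items: &[T]) -> usize {
    items.len()
}

fn main() {
    let v = vec![1, 2, 3];
    assert_eq!(total_len(&v), 3);      // &Vec<i32> coerces to &[i32]
    assert_eq!(total_len(&[4, 5]), 2); // a borrowed array works too
    assert_eq!(total_len(&v[..2]), 2); // as does a sub-slice
}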
2 changes: 2 additions & 0 deletions dfdx-core/src/tensor_ops/mod.rs
@@ -188,6 +188,7 @@ pub(super) mod optim;
mod permute_to;
mod pow;
mod prelu;
mod prodigy;
mod realize_to;
mod recip;
mod relu;
@@ -255,6 +256,7 @@ pub use optim::*;
pub use permute_to::PermuteTo;
pub use pow::{powf, powi};
pub use prelu::{leakyrelu, prelu, TryPReLU};
pub use prodigy::ProdigyConfig;
pub use realize_to::RealizeTo;
pub use recip::recip;
pub use relu::relu;
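
The re-exported `ProdigyConfig` is defined in a `prodigy` module not shown in this excerpt. Judging only from the fields read by the CPU kernel below, constructing one might look roughly like this; the field names and the `Default` impl are assumptions inferred from the kernel and from the reference Prodigy hyperparameters, not taken from the actual struct definition:

use dfdx_core::tensor_ops::{ProdigyConfig, WeightDecay};

fn example_config() -> ProdigyConfig {
    ProdigyConfig {
        lr: 1.0,                                          // Prodigy is normally run with lr = 1.0
        betas: [0.9, 0.999],
        weight_decay: Some(WeightDecay::Decoupled(1e-2)),
        // beta3, eps, d0, d_coef, growth_rate, use_bias_correction and
        // safeguard_warmup are left at their (assumed) defaults.
        ..Default::default()
    }
}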
252 changes: 252 additions & 0 deletions dfdx-core/src/tensor_ops/prodigy/cpu_kernel.rs
@@ -0,0 +1,252 @@
use super::super::WeightDecay;
use super::{ProdigyConfig, ProdigyKernel};
use crate::{
dtypes::{Dtype, NotMixedPrecision},
tensor::{Cpu, Error},
};

#[cfg(feature = "f16")]
use crate::dtypes::{f16, AMP};

#[cfg(feature = "f16")]
impl ProdigyKernel<AMP<f16>> for Cpu {
fn prodigy_kernel(
&self,
k: i32,
d: &mut f64,
d_max: &mut f64,
d_numerator: &mut f64,
cfg: &ProdigyConfig,
param: &mut Self::Vec,
s: &mut Self::Vec,
p0: &mut Self::Vec,
p0b: &mut Self::Vec,
moment1: &mut Self::Vec,
moment2: &mut Self::Vec,
grad: &Self::Vec,
) -> Result<(), Error> {
let mut d_denom_: f32 = 0.;
let [beta1, beta2] = cfg.betas.map(|x| x as f32);
let beta3 = cfg.beta3.unwrap_or_else(|| cfg.betas[1].sqrt()) as f32;

let bias_correction = if cfg.use_bias_correction {
// note: here the first step uses k = 1, whereas in the reference Python code it starts at 0
(1. - beta2.powi(k)).sqrt() / (1. - beta1.powi(k))
} else {
1.
};
let mut d_ = *d as f32;
let mut d_max_ = *d_max as f32;
let mut d_numerator_ = *d_numerator as f32 * beta3;
let d0 = cfg.d0 as f32;
let lr = cfg.lr as f32;

let dlr = d_ * lr * bias_correction;

for ((((((p, g), s), p0), p0b), m), v) in param
.iter_mut()
.zip(grad.iter().cloned())
.zip(s.iter_mut())
.zip(p0.iter_mut())
.zip(p0b.iter_mut())
.zip(moment1.iter_mut())
.zip(moment2.iter_mut())
{
let p_ = p.0.to_f32();
let mut g_ = g.0.to_f32();
let mut s_ = s.0.to_f32();
let p0b_ = p0b.0.to_f32();
let mut m_ = m.0.to_f32();
let mut v_ = v.0.to_f32();

// initialize p0 if needed
if p0b_ == 0. {
p0b.0 = f16::from_f32(1.);
*p0 = *p;
}
let p0_ = p0.0.to_f32();

if let Some(WeightDecay::L2(wd)) = cfg.weight_decay {
g_ += wd as f32 * p_;
}

if lr > 0. {
d_numerator_ += (d_ / d0) * dlr * (g_ * (p0_ - p_));

m_ = m_ * beta1 + d_ * g_ * (1. - beta1);
v_ = v_ * beta2 + d_ * d_ * g_ * g_ * (1. - beta2);
m.0 = f16::from_f32(m_);
v.0 = f16::from_f32(v_);

if cfg.safeguard_warmup {
s_ = s_ * beta3 + g_ * d_.powi(2) / d0;
} else {
s_ = s_ * beta3 + g_ * d_ * dlr / d0;
}
s.0 = f16::from_f32(s_);

d_denom_ += s_.abs();
}
}

if d_denom_ == 0. {
return Ok(());
}

let global_d_numerator = d_numerator_;
let global_d_denom = d_denom_;
if lr > 0. {
let d_coef = cfg.d_coef as f32;
let d_hat_ = d_coef * global_d_numerator / global_d_denom;
if d_ == d0 {
d_ = d_.max(d_hat_);
}
d_max_ = d_max_.max(d_hat_);
let growth_rate = cfg.growth_rate as f32;
d_ = d_max_.min(d_ * growth_rate);
}

*d = d_ as f64;
*d_max = d_max_ as f64;
*d_numerator = global_d_numerator as f64;

let eps = cfg.eps as f32;

for (p, (m, v)) in param
.iter_mut()
.zip(moment1.iter_mut().zip(moment2.iter_mut()))
{
let mut p_ = p.0.to_f32();
let m_ = m.0.to_f32();
let v_ = v.0.to_f32();

let denom = v_.sqrt() + d_ * eps;

if let Some(WeightDecay::Decoupled(wd)) = cfg.weight_decay {
p_ *= 1. - wd as f32 * dlr;
}

p_ -= dlr * m_ / denom;
p.0 = f16::from_f32(p_); // p_ already holds the updated parameter; write it back
}

Ok(())
}
}

impl<E: num_traits::Float + Dtype + NotMixedPrecision> ProdigyKernel<E> for Cpu {
fn prodigy_kernel(
&self,
k: i32,
d: &mut f64,
d_max: &mut f64,
d_numerator: &mut f64,
cfg: &ProdigyConfig,
param: &mut Self::Vec,
s: &mut Self::Vec,
p0: &mut Self::Vec,
p0b: &mut Self::Vec,
moment1: &mut Self::Vec,
moment2: &mut Self::Vec,
grad: &Self::Vec,
) -> Result<(), Error> {
let mut d_denom_: E = E::zero();
let [beta1, beta2] = cfg.betas.map(E::from_f64).map(Option::unwrap);

#[allow(unused_imports)]
let beta3 = E::from_f64(cfg.beta3.unwrap_or_else(|| {
#[cfg(feature = "no-std")]
use num_traits::Float;

cfg.betas[1].sqrt()
}))
.unwrap();

let bias_correction = if cfg.use_bias_correction {
// note: here the first step uses k = 1, whereas in the reference Python code it starts at 0
(E::one() - beta2.powi(k)).sqrt() / (E::one() - beta1.powi(k))
} else {
E::one()
};
let mut d_ = E::from_f64(*d).unwrap();
let mut d_max_ = E::from_f64(*d_max).unwrap();
let mut d_numerator_ = E::from_f64(*d_numerator).unwrap() * beta3;
let d0 = E::from_f64(cfg.d0).unwrap();
let lr = E::from_f64(cfg.lr).unwrap();

let dlr = d_ * lr * bias_correction;

for ((((((p, mut g), s), p0), p0b), m), v) in param
.iter_mut()
.zip(grad.iter().cloned())
.zip(s.iter_mut())
.zip(p0.iter_mut())
.zip(p0b.iter_mut())
.zip(moment1.iter_mut())
.zip(moment2.iter_mut())
{
// initialize p0 if needed
if *p0b == E::zero() {
*p0b = E::one();
*p0 = *p;
}

if let Some(WeightDecay::L2(wd)) = cfg.weight_decay {
g += E::from_f64(wd).unwrap() * *p;
}

if lr > E::zero() {
d_numerator_ += (d_ / d0) * dlr * (g * (*p0 - *p));

*m = *m * beta1 + d_ * g * (E::one() - beta1);
*v = *v * beta2 + d_ * d_ * g * g * (E::one() - beta2);

if cfg.safeguard_warmup {
*s = *s * beta3 + g * d_.powi(2) / d0
} else {
*s = *s * beta3 + g * d_ * dlr / d0
}

d_denom_ += s.abs();
}
}

if d_denom_ == E::zero() {
return Ok(());
}

let global_d_numerator = d_numerator_;
let global_d_denom = d_denom_;
if lr > E::zero() {
let d_coef = E::from_f64(cfg.d_coef).unwrap();
let d_hat_ = d_coef * global_d_numerator / global_d_denom;
if d_ == d0 {
d_ = d_.max(d_hat_);
}
d_max_ = d_max_.max(d_hat_);
let growth_rate = E::from_f64(cfg.growth_rate).unwrap();
d_ = d_max_.min(d_ * growth_rate);
}

*d = d_.to_f64().unwrap();
*d_max = d_max_.to_f64().unwrap();
*d_numerator = global_d_numerator.to_f64().unwrap();

let eps = E::from_f64(cfg.eps).unwrap();

for (p, (m, v)) in param
.iter_mut()
.zip(moment1.iter_mut().zip(moment2.iter_mut()))
{
let denom = v.sqrt() + d_ * eps;

if let Some(WeightDecay::Decoupled(wd)) = cfg.weight_decay {
*p *= E::one() - E::from_f64(wd).unwrap() * dlr;
}

*p -= dlr * *m / denom;
}

Ok(())
}
}
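
To make the structure of both kernels easier to follow, here is a standalone sketch of the same Prodigy step on plain f32 buffers. Hyperparameter names mirror ProdigyConfig; weight decay, the safeguard_warmup branch, and the lazy p0 initialization are omitted, and this is an illustration of the algorithm rather than dfdx's actual kernel.

struct ProdigyState {
    k: i32,           // step count, starting at 1 (see the bias-correction note above)
    d: f32,           // current distance-to-solution estimate
    d_max: f32,       // largest estimate seen so far
    d_numerator: f32, // running numerator of the D-adaptation ratio
}

#[allow(clippy::too_many_arguments)]
fn prodigy_step(
    st: &mut ProdigyState,
    lr: f32,
    (beta1, beta2, beta3): (f32, f32, f32),
    (d0, d_coef, growth_rate, eps): (f32, f32, f32, f32),
    param: &mut [f32],
    p0: &[f32], // snapshot of the initial parameters
    s: &mut [f32],
    m: &mut [f32],
    v: &mut [f32],
    grad: &[f32],
) {
    let bias_correction = (1.0 - beta2.powi(st.k)).sqrt() / (1.0 - beta1.powi(st.k));
    let dlr = st.d * lr * bias_correction;
    let mut d_numerator = st.d_numerator * beta3;
    let mut d_denom = 0.0f32;

    // First pass: Adam-style moments scaled by d, plus the D-adaptation statistics.
    for i in 0..param.len() {
        let g = grad[i];
        // The numerator grows when gradients point away from the starting point p0.
        d_numerator += (st.d / d0) * dlr * g * (p0[i] - param[i]);
        m[i] = m[i] * beta1 + st.d * g * (1.0 - beta1);
        v[i] = v[i] * beta2 + (st.d * g).powi(2) * (1.0 - beta2);
        s[i] = s[i] * beta3 + g * st.d * dlr / d0;
        d_denom += s[i].abs();
    }
    if d_denom == 0.0 {
        return; // no signal yet; leave d and the parameters untouched
    }

    // Grow the step-size estimate, capped at growth_rate per step.
    let d_hat = d_coef * d_numerator / d_denom;
    if st.d == d0 {
        st.d = st.d.max(d_hat);
    }
    st.d_max = st.d_max.max(d_hat);
    st.d = st.d_max.min(st.d * growth_rate);
    st.d_numerator = d_numerator;

    // Second pass: apply the update. As in the kernels above, dlr was computed
    // from the previous d, while eps is scaled by the freshly updated d.
    for i in 0..param.len() {
        let denom = v[i].sqrt() + st.d * eps;
        param[i] -= dlr * m[i] / denom;
    }
    st.k += 1;
}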