diff --git a/dfdx-core/src/tensor_ops/mod.rs b/dfdx-core/src/tensor_ops/mod.rs index c51040ee..a229425d 100644 --- a/dfdx-core/src/tensor_ops/mod.rs +++ b/dfdx-core/src/tensor_ops/mod.rs @@ -186,6 +186,7 @@ pub(super) mod optim; mod permute_to; mod pow; mod prelu; +mod prodigy; mod realize_to; mod recip; mod relu; @@ -250,6 +251,7 @@ pub use optim::*; pub use permute_to::PermuteTo; pub use pow::{powf, powi}; pub use prelu::{leakyrelu, prelu, TryPReLU}; +pub use prodigy::ProdigyConfig; pub use realize_to::RealizeTo; pub use recip::recip; pub use relu::relu; diff --git a/dfdx-core/src/tensor_ops/prodigy/cpu_kernel.rs b/dfdx-core/src/tensor_ops/prodigy/cpu_kernel.rs new file mode 100644 index 00000000..6cc2d6be --- /dev/null +++ b/dfdx-core/src/tensor_ops/prodigy/cpu_kernel.rs @@ -0,0 +1,244 @@ +use super::super::WeightDecay; +use super::{ProdigyConfig, ProdigyKernel}; +use crate::{ + dtypes::{Dtype, NotMixedPrecision}, + tensor::{Cpu, Error}, +}; + +#[cfg(feature = "f16")] +use crate::dtypes::{f16, AMP}; + +#[cfg(feature = "f16")] +impl ProdigyKernel> for Cpu { + fn prodigy_kernel( + &self, + k: i32, + d: &mut f64, + d_max: &mut f64, + d_numerator: &mut f64, + cfg: &ProdigyConfig, + param: &mut Self::Vec, + s: &mut Self::Vec, + p0: &mut Self::Vec, + p0b: &mut Self::Vec, + moment1: &mut Self::Vec, + moment2: &mut Self::Vec, + grad: &Self::Vec, + ) -> Result<(), Error> { + let mut d_denom_: f32 = 0.; + let [beta1, beta2] = cfg.betas.map(|x| x as f32); + let beta3 = cfg.beta3.unwrap_or_else(|| cfg.betas[1].sqrt()) as f32; + + let bias_correction = if cfg.use_bias_correction { + // note: in here the first k = 1, whereas on the reference python code it's 0 + (1. - beta2.powi(k)).sqrt() / (1. - beta1.powi(k)) + } else { + 1. + }; + let mut d_ = *d as f32; + let mut d_max_ = *d_max as f32; + let mut d_numerator_ = *d_numerator as f32 * beta3; + let d0 = cfg.d0 as f32; + let lr = cfg.lr as f32; + + let dlr = d_ * lr * bias_correction; + + for ((((((p, g), s), p0), p0b), m), v) in param + .iter_mut() + .zip(grad.iter().cloned()) + .zip(s.iter_mut()) + .zip(p0.iter_mut()) + .zip(p0b.iter_mut()) + .zip(moment1.iter_mut()) + .zip(moment2.iter_mut()) + { + let p_ = p.0.to_f32(); + let mut g_ = g.0.to_f32(); + let mut s_ = s.0.to_f32(); + let p0b_ = p0b.0.to_f32(); + let mut m_ = m.0.to_f32(); + let mut v_ = v.0.to_f32(); + + // initialize p0 if needed + if p0b_ == 0. { + p0b.0 = f16::from_f32(1.); + *p0 = *p; + } + let p0_ = p0.0.to_f32(); + + if let Some(WeightDecay::L2(wd)) = cfg.weight_decay { + g_ += wd as f32 * p_; + } + + if lr > 0. { + d_numerator_ += (d_ / d0) * dlr * (g_ * (p0_ - p_)); + + m_ = m_ * beta1 + d_ * g_ * (1. - beta1); + v_ = v_ * beta2 + d_ * d_ * g_ * g_ * (1. - beta2); + m.0 = f16::from_f32(m_); + v.0 = f16::from_f32(v_); + + if cfg.safeguard_warmup { + s_ = s_ * beta3 + g_ * d_.powi(2) / d0; + } else { + s_ = s_ * beta3 + g_ * d_ * dlr / d0; + } + s.0 = f16::from_f32(s_); + + d_denom_ += s_.abs(); + } + } + + if d_denom_ == 0. { + return Ok(()); + } + + let global_d_numerator = d_numerator_; + let global_d_denom = d_denom_; + if lr > 0. 
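            // Global D-adaptation step (still f32 math on the AMP path): form the
            // candidate d_hat = d_coef * numerator / denominator, let `d` move up
            // towards it, capping its per-step growth at `growth_rate` and never
            // letting it exceed the running `d_max`.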
{ + let d_coef = cfg.d_coef as f32; + let d_hat_ = d_coef * global_d_numerator / global_d_denom; + if d_ == d0 { + d_ = d_.max(d_hat_); + } + d_max_ = d_max_.max(d_hat_); + let growth_rate = cfg.growth_rate as f32; + d_ = d_max_.min(d_ * growth_rate); + } + + *d = d_ as f64; + *d_max = d_max_ as f64; + *d_numerator = global_d_numerator as f64; + + let eps = cfg.eps as f32; + + for (p, (m, v)) in param + .iter_mut() + .zip(moment1.iter_mut().zip(moment2.iter_mut())) + { + let mut p_ = p.0.to_f32(); + let m_ = m.0.to_f32(); + let v_ = v.0.to_f32(); + + let denom = v_.sqrt() + d_ * eps; + + if let Some(WeightDecay::Decoupled(wd)) = cfg.weight_decay { + p_ *= 1. - wd as f32 * dlr; + } + + p_ -= dlr * m_ / denom; + p.0 -= f16::from_f32(p_); + } + + Ok(()) + } +} + +impl ProdigyKernel for Cpu { + fn prodigy_kernel( + &self, + k: i32, + d: &mut f64, + d_max: &mut f64, + d_numerator: &mut f64, + cfg: &ProdigyConfig, + param: &mut Self::Vec, + s: &mut Self::Vec, + p0: &mut Self::Vec, + p0b: &mut Self::Vec, + moment1: &mut Self::Vec, + moment2: &mut Self::Vec, + grad: &Self::Vec, + ) -> Result<(), Error> { + let mut d_denom_: E = E::zero(); + let [beta1, beta2] = cfg.betas.map(E::from_f64).map(Option::unwrap); + let beta3 = E::from_f64(cfg.beta3.unwrap_or_else(|| cfg.betas[1].sqrt())).unwrap(); + + let bias_correction = if cfg.use_bias_correction { + // note: in here the first k = 1, whereas on the reference python code it's 0 + (E::one() - beta2.powi(k)).sqrt() / (E::one() - beta1.powi(k)) + } else { + E::one() + }; + let mut d_ = E::from_f64(*d).unwrap(); + let mut d_max_ = E::from_f64(*d_max).unwrap(); + let mut d_numerator_ = E::from_f64(*d_numerator).unwrap() * beta3; + let d0 = E::from_f64(cfg.d0).unwrap(); + let lr = E::from_f64(cfg.lr).unwrap(); + + let dlr = d_ * lr * bias_correction; + + for ((((((p, mut g), s), p0), p0b), m), v) in param + .iter_mut() + .zip(grad.iter().cloned()) + .zip(s.iter_mut()) + .zip(p0.iter_mut()) + .zip(p0b.iter_mut()) + .zip(moment1.iter_mut()) + .zip(moment2.iter_mut()) + { + // initialize p0 if needed + if *p0b == E::zero() { + *p0b = E::one(); + *p0 = *p; + } + + if let Some(WeightDecay::L2(wd)) = cfg.weight_decay { + g += E::from_f64(wd).unwrap() * *p; + } + + if lr > E::zero() { + d_numerator_ += (d_ / d0) * dlr * (g * (*p0 - *p)); + + *m = *m * beta1 + d_ * g * (E::one() - beta1); + *v = *v * beta2 + d_ * d_ * g * g * (E::one() - beta2); + + if cfg.safeguard_warmup { + *s = *s * beta3 + g * d_.powi(2) / d0 + } else { + *s = *s * beta3 + g * d_ * dlr / d0 + } + + d_denom_ += s.abs(); + } + } + + if d_denom_ == E::zero() { + return Ok(()); + } + + let global_d_numerator = d_numerator_; + let global_d_denom = d_denom_; + if lr > E::zero() { + let d_coef = E::from_f64(cfg.d_coef).unwrap(); + let d_hat_ = d_coef * global_d_numerator / global_d_denom; + if d_ == d0 { + d_ = d_.max(d_hat_); + } + d_max_ = d_max_.max(d_hat_); + let growth_rate = E::from_f64(cfg.growth_rate).unwrap(); + d_ = d_max_.min(d_ * growth_rate); + } + + *d = d_.to_f64().unwrap(); + *d_max = d_max_.to_f64().unwrap(); + *d_numerator = global_d_numerator.to_f64().unwrap(); + + let eps = E::from_f64(cfg.eps).unwrap(); + + for (p, (m, v)) in param + .iter_mut() + .zip(moment1.iter_mut().zip(moment2.iter_mut())) + { + let denom = v.sqrt() + d_ * eps; + + if let Some(WeightDecay::Decoupled(wd)) = cfg.weight_decay { + *p *= E::one() - E::from_f64(wd).unwrap() * dlr; + } + + *p -= dlr * *m / denom; + } + + Ok(()) + } +} diff --git a/dfdx-core/src/tensor_ops/prodigy/cuda_kernel.rs 
b/dfdx-core/src/tensor_ops/prodigy/cuda_kernel.rs new file mode 100644 index 00000000..7dee42f2 --- /dev/null +++ b/dfdx-core/src/tensor_ops/prodigy/cuda_kernel.rs @@ -0,0 +1,229 @@ +use crate::{ + dtypes::*, + prelude::Storage, + tensor::{launch_cfg, Cuda, Error}, + tensor_ops::optim::*, +}; + +use cudarc::driver::{DeviceRepr, DeviceSlice, LaunchAsync}; + +#[repr(C)] +#[derive(Clone)] +struct CudaProdigyConfig1 { + numel: usize, + k: i32, + lr: f64, + beta1: f64, + beta2: f64, + beta3: f64, + weight_decay_type: WeightDecayType, + weight_decay: f64, + bias_correction: f64, + safeguard_warmup: bool, + d0: f64, +} + +#[repr(C)] +#[derive(Clone)] +struct CudaProdigyConfig2 { + numel: usize, + lr: f64, + eps: f64, + weight_decay_type: WeightDecayType, + weight_decay: f64, + bias_correction: f64, +} + +unsafe impl DeviceRepr for CudaProdigyConfig1 { + fn as_kernel_param(&self) -> *mut std::ffi::c_void { + self as *const Self as *mut _ + } +} + +unsafe impl DeviceRepr for CudaProdigyConfig2 { + fn as_kernel_param(&self) -> *mut std::ffi::c_void { + self as *const Self as *mut _ + } +} + +impl CudaProdigyConfig1 { + fn new(cfg: &super::ProdigyConfig, numel: usize, k: i32) -> Self { + let [beta1, beta2] = cfg.betas; + let beta3 = if let Some(beta3) = cfg.beta3 { + beta3 + } else { + beta2.sqrt() + }; + let (weight_decay_type, weight_decay) = weight_decay_to_cuda(cfg.weight_decay); + + let bias_correction = if cfg.use_bias_correction { + (1.0 - beta2.powi(k)).sqrt() / (1.0 - beta1.powi(k)) + } else { + 1. + }; + + CudaProdigyConfig1 { + numel, + k, + lr: cfg.lr, + beta1, + beta2, + beta3, + weight_decay_type, + weight_decay, + bias_correction, + safeguard_warmup: cfg.safeguard_warmup, + d0: cfg.d0, + } + } +} + +impl CudaProdigyConfig2 { + fn new(cfg: &super::ProdigyConfig, cfg1: &CudaProdigyConfig1) -> Self { + CudaProdigyConfig2 { + numel: cfg1.numel, + lr: cfg1.lr, + eps: cfg.eps, + weight_decay_type: cfg1.weight_decay_type, + weight_decay: cfg1.weight_decay, + bias_correction: cfg1.bias_correction, + } + } +} + +const PTX_SRC: &str = include_str!(concat!(env!("OUT_DIR"), "/prodigy.ptx")); + +trait HasCudaKernel { + const MOD: &'static str; + const FWD1: &'static str; + const FWD2: &'static str; +} + +#[cfg(feature = "f16")] +impl HasCudaKernel> for Cuda { + const MOD: &'static str = "prodigy_amp_f16"; + const FWD1: &'static str = "prodigy_update1_amp_f16"; + const FWD2: &'static str = "prodigy_update2_amp_f16"; +} + +#[cfg(feature = "f16")] +impl HasCudaKernel for Cuda { + const MOD: &'static str = "prodigy_f16"; + const FWD1: &'static str = "prodigy_update1_f16"; + const FWD2: &'static str = "prodigy_update2_f16"; +} + +impl HasCudaKernel for Cuda { + const MOD: &'static str = "prodigy_f32"; + const FWD1: &'static str = "prodigy_update1_f32"; + const FWD2: &'static str = "prodigy_update2_f32"; +} + +impl HasCudaKernel for Cuda { + const MOD: &'static str = "prodigy_f64"; + const FWD1: &'static str = "prodigy_update1_f64"; + const FWD2: &'static str = "prodigy_update2_f64"; +} + +impl super::ProdigyKernel for Cuda +where + Self: HasCudaKernel, +{ + fn prodigy_kernel( + &self, + k: i32, + d: &mut f64, + d_max: &mut f64, + d_numerator: &mut f64, + cfg: &super::ProdigyConfig, + param: &mut Self::Vec, + s: &mut Self::Vec, + p0: &mut Self::Vec, + p0b: &mut Self::Vec, + moment1: &mut Self::Vec, + moment2: &mut Self::Vec, + grad: &Self::Vec, + ) -> Result<(), Error> { + if !self.dev.has_func(Self::MOD, Self::FWD1) { + self.dev + .load_ptx(PTX_SRC.into(), Self::MOD, &[Self::FWD1, Self::FWD2])?; + } + + 
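        // Two-pass scheme: FWD1 runs the element-wise first pass (moment and `s`
        // updates) and writes one partial d-numerator / d-denominator per CUDA
        // thread; those partials are copied back and summed on the host, the
        // scalar state (`d`, `d_max`, `d_numerator`) is updated, and FWD2 then
        // applies the parameter update with `dlr` derived from the step size
        // captured before this update (`d_old`).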
let numel = param.len(); + let opt_cfg1 = CudaProdigyConfig1::new(cfg, numel, k); + let opt_cfg2 = CudaProdigyConfig2::new(cfg, &opt_cfg1); + let func1 = self.dev.get_func(Self::MOD, Self::FWD1).unwrap(); + let cu_cfg = launch_cfg::<128>(numel as u32); + + // d_numerators for thread-block sum-reduction + let mut d_numerators: Self::Vec = + self.try_alloc_len(cu_cfg.grid_dim.0 as usize * cu_cfg.block_dim.0 as usize)?; + let mut d_numerators_vec = + vec![E::zero(); cu_cfg.grid_dim.0 as usize * cu_cfg.block_dim.0 as usize]; + self.dev + .htod_sync_copy_into(d_numerators_vec.as_slice(), &mut d_numerators)?; + // d_denom for thread-block sum-reduction + let mut d_denoms: Self::Vec = + self.try_alloc_len(cu_cfg.grid_dim.0 as usize * cu_cfg.block_dim.0 as usize)?; + let mut d_denoms_vec = + vec![E::zero(); cu_cfg.grid_dim.0 as usize * cu_cfg.block_dim.0 as usize]; + self.dev + .htod_sync_copy_into(d_denoms_vec.as_slice(), &mut d_denoms)?; + + // local cache + let d_old = *d; + let beta3 = opt_cfg1.beta3; + + let params1 = ( + opt_cfg1, + *d, + // + &mut d_numerators, + &mut d_denoms, + // + &*param, + s, + p0, + p0b, + &mut *moment1, + &mut *moment2, + grad, + ); + unsafe { func1.launch(cu_cfg.clone(), params1) }?; + + // get the thread-block d_numerators and d_denoms + self.dev + .dtoh_sync_copy_into(&d_numerators.data, d_numerators_vec.as_mut_slice())?; + self.dev + .dtoh_sync_copy_into(&d_denoms.data, d_denoms_vec.as_mut_slice())?; + // sum and update d_numerators and d_denoms + let d_numerator_: E = d_numerators_vec + .into_iter() + .reduce(|acc, e| acc + e) + .unwrap() + + E::from_f64(*d_numerator).unwrap() * E::from_f64(beta3).unwrap(); + let d_denom_: E = d_denoms_vec.into_iter().reduce(|acc, e| acc + e).unwrap(); + + if d_denom_ == E::zero() { + return Ok(()); + } + + let func2 = self.dev.get_func(Self::MOD, Self::FWD2).unwrap(); + + *d_numerator = d_numerator_.to_f64().unwrap(); + let global_d_denom = d_denom_.to_f64().unwrap(); + if cfg.lr > 0. { + let d_hat = cfg.d_coef * *d_numerator / global_d_denom; + if *d == cfg.d0 { + *d = d.max(d_hat); + } + *d_max = d_max.max(d_hat); + *d = d_max.min(*d * cfg.growth_rate); + } + + let params2 = (opt_cfg2, d_old, param, &*moment1, &*moment2); + unsafe { func2.launch(cu_cfg, params2) }?; + + Ok(()) + } +} diff --git a/dfdx-core/src/tensor_ops/prodigy/mod.rs b/dfdx-core/src/tensor_ops/prodigy/mod.rs new file mode 100644 index 00000000..7b331bb7 --- /dev/null +++ b/dfdx-core/src/tensor_ops/prodigy/mod.rs @@ -0,0 +1,148 @@ +mod cpu_kernel; + +#[cfg(feature = "cuda")] +mod cuda_kernel; + +use crate::{ + shapes::{Dtype, Shape}, + tensor::{Error, Storage, Tensor}, +}; + +/// Configuration of hyperparameters for Prodigy. +/// +/// Changing some default parameters: +/// ```rust +/// # use dfdx_core::prelude::*; +/// AdamConfig { +/// d_coef: 0.5, // smaller learning rate +/// weight_decay: Some(WeightDecay::L2(1e-1)), +/// ..Default::default() +/// }; +/// ``` +#[derive(Debug, Clone, Copy)] +pub struct ProdigyConfig { + /// Learning rate adjustment parameter. + /// Increases or decreases the Prodigy learning rate. + /// + /// Defaults to `1.0`. + pub lr: f64, + + /// Betas coefficients used for computing running averages of gradient and its square. + /// + /// Defaults to `[0.9, 0.999]`. + pub betas: [f64; 2], + + /// Coefficients for computing the Prodidy stepsize using running averages. + /// If set to `None`, uses the value of square root of beta2 (ie. betas[1]). + /// + /// Defaults to `None`. 
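    ///
    /// Note that `beta3` only smooths the D-adaptation state (the running
    /// numerator and the per-element accumulator `s`); the Adam moments are
    /// still governed by `betas`.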
+ pub beta3: Option, + + /// Term added to the denominator outside of the root operation to improve numerical stability. + /// + /// Defaults to `1e-8`. + pub eps: f64, + + /// Optional weight decay. + /// + /// Defaults to `None`. + pub weight_decay: Option, + + /// Turn on Adam's bias correction. + /// + /// Defaults to `false`. + pub use_bias_correction: bool, + + /// Remove lr from the denominator of D estimate to avoid issues during warm-up stage. + /// + /// Defaults to `false`. + pub safeguard_warmup: bool, + + /// Initial D estimate for D-adaptation. Rarely needs changing. + /// + /// Defaults to `1e-6`. + pub d0: f64, + + /// Coefficient in the expression for the estimate of d. + /// Values such as `0.5` and `2.0` typically work as well for a smaller or higher learning rate, respectively. + /// Changing this parameter is the preferred way to tune the optimizer. + /// + /// Defaults to `1.0`. + pub d_coef: f64, + + /// Prevent the D estimate from growing faster than this multiplicative rate. + /// Use infinite for unrestricted. Values like 1.02 give a kind of learning + /// rate warmup effect. + /// + /// Defaults to `f64::INFINITY`. + pub growth_rate: f64, +} + +impl Default for ProdigyConfig { + fn default() -> Self { + Self { + lr: 1.0, + betas: [0.9, 0.999], + beta3: None, + eps: 1e-8, + weight_decay: None, + use_bias_correction: false, + safeguard_warmup: false, + d0: 1e-6, + d_coef: 1.0, + growth_rate: f64::INFINITY, + } + } +} + +pub trait ProdigyKernel: Storage { + #[allow(clippy::too_many_arguments)] + fn prodigy_kernel( + &self, + k: i32, + d: &mut f64, + d_max: &mut f64, + d_numerator: &mut f64, + cfg: &ProdigyConfig, + param: &mut Self::Vec, + s: &mut Self::Vec, + p0: &mut Self::Vec, + p0b: &mut Self::Vec, + moment1: &mut Self::Vec, + moment2: &mut Self::Vec, + grad: &Self::Vec, + ) -> Result<(), Error>; +} + +impl ProdigyConfig { + #[allow(clippy::too_many_arguments)] + pub fn try_update>( + &self, + k: i32, + d: &mut f64, + d_max: &mut f64, + d_numerator: &mut f64, + param: &mut Tensor, + s: &mut D::Vec, + p0: &mut D::Vec, + p0b: &mut D::Vec, + moment1: &mut D::Vec, + moment2: &mut D::Vec, + grad: &D::Vec, + ) -> Result<(), crate::tensor::Error> { + param.device.prodigy_kernel( + k, + d, + d_max, + d_numerator, + self, + std::sync::Arc::make_mut(&mut param.data), + s, + p0, + p0b, + moment1, + moment2, + grad, + ) + } +} diff --git a/dfdx-core/src/tensor_ops/prodigy/prodigy.cu b/dfdx-core/src/tensor_ops/prodigy/prodigy.cu new file mode 100644 index 00000000..c45e2ed6 --- /dev/null +++ b/dfdx-core/src/tensor_ops/prodigy/prodigy.cu @@ -0,0 +1,342 @@ +#include "cuda_utils.cuh" + +enum WeightDecayType { + None, + L2, + Decoupled +}; + +struct ProdigyConfig1 { + const size_t numel; + const int k_int; + // + double lr; + double beta1; + double beta2; + double beta3; + WeightDecayType weight_decay_type; + double weight_decay; + double bias_correction; + bool safeguard_warmup; + double d0; +}; + +template +__device__ void prodigy_update1( + const ProdigyConfig1 cfg, + const double d, + + // temporaries for sum-reduction. + // are written into and read back by host + T* d_numerators, + T* d_denoms, + + // parameter-related tensors. 
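    // `s` is Prodigy's per-element running accumulator that feeds the D
    // denominator, `p0`/`p0b` hold the initial parameter values and their
    // init flags, and `moment1`/`moment2` are the Adam EMAs (scaled by d).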
+ // some are overwritten + const T* param, + T* s, + T* p0, + T* p0b, + T* moment1, + T* moment2, + const T* grad +) { + const size_t numel = cfg.numel; + const int k_int = cfg.k_int; + T lr = cfg.lr; + T beta1 = cfg.beta1; + T beta2 = cfg.beta2; + T beta3 = cfg.beta3; + T weight_decay = cfg.weight_decay; + T bias_correction = cfg.bias_correction; + T d0 = cfg.d0; + T zero = 0.0; + T one = 1.0; + T k = k_int; + + unsigned int cu_index = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int cu_stride = blockDim.x * gridDim.x; + + if (cu_index >= numel) { + return; + } + + T d_ = d; + T dlr = d_ * lr * bias_correction; + + // thread-local d_numerator and d_denom + T d_numerator_ = zero; + T d_denom = zero; + // those values will be sum-reduced by all threads and all blocks + + for (unsigned int i = cu_index; i < numel; i += cu_stride) { + T p_ = param[i]; + T g = grad[i]; + T s_ = s[i]; + T m_ = moment1[i]; + T v_ = moment2[i]; + + // initialize p0 if needed + if (p0b[i] == zero) { + p0b[i] = one; + p0[i] = p_; + } + T p0_ = p0[i]; + + if (cfg.weight_decay_type == L2) { + g += weight_decay * p_; + } + + if (lr > zero) { + d_numerator_ += (d_ / d0) * dlr * (g * (p0_ - p_)); + + m_ = m_ * beta1 + d_ * g * (one - beta1); + v_ = v_ * beta2 + d_ * d_ * g * g * (one - beta2); + + if (cfg.safeguard_warmup) { + s_ = s_ * beta3 + g * d_ * d_ / d0; + } else { + s_ = s_ * beta3 + g * d_ * dlr / d0; + } + + d_denom += absg(s_); + } + + s[i] = s_; + moment1[i] = m_; + moment2[i] = v_; + } + + // prepares the values for sum-reduction + d_numerators[cu_index] = d_numerator_; + d_denoms[cu_index] = d_denom; + + return; +} + +#define PRODIGY1(TYPENAME, FN1) \ +extern "C" __global__ void FN1( \ + const ProdigyConfig1 cfg, \ + const double d, \ + TYPENAME* d_numerators, \ + TYPENAME* d_denoms, \ + const TYPENAME* param, \ + TYPENAME* s, \ + TYPENAME* p0, \ + TYPENAME* p0b, \ + TYPENAME* moment1, \ + TYPENAME* moment2, \ + const TYPENAME* grad \ +) { \ + prodigy_update1(cfg, d, d_numerators, d_denoms, param, s, p0, p0b, moment1, moment2, grad); \ +} + +PRODIGY1(__half, prodigy_update1_f16); +PRODIGY1(float, prodigy_update1_f32); +PRODIGY1(double, prodigy_update1_f64); + + + +extern "C" __global__ void prodigy_update1_amp_f16( + const ProdigyConfig1 cfg, + const double d, + + // temporaries for sum-reduction. + // are written into and read back by host + __half* d_numerators, + __half* d_denoms, + + // parameter-related tensors. 
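    // AMP variant: state is stored as __half but all arithmetic below is
    // performed in float, mirroring the f32 accumulation of the CPU kernel.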
+ // some are overwritten + const __half* param, + __half* s, + __half* p0, + __half* p0b, + __half* moment1, + __half* moment2, + const __half* grad +) { + const size_t numel = cfg.numel; + const int k_int = cfg.k_int; + float lr = cfg.lr; + float beta1 = cfg.beta1; + float beta2 = cfg.beta2; + float beta3 = cfg.beta3; + float weight_decay = cfg.weight_decay; + float bias_correction = cfg.bias_correction; + float d0 = cfg.d0; + float zero = 0.0; + __half zero_half = zero; + float one = 1.0; + float k = k_int; + + unsigned int cu_index = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int cu_stride = blockDim.x * gridDim.x; + + if (cu_index >= numel) { + return; + } + + float d_ = d; + float dlr = d_ * lr * bias_correction; + + // thread-local d_numerator and d_denom + float d_numerator_ = zero; + float d_denom = zero; + // those values will be sum-reduced by all threads and all blocks + + for (unsigned int i = cu_index; i < numel; i += cu_stride) { + float p_ = param[i]; + float g = grad[i]; + float s_ = s[i]; + float m_ = moment1[i]; + float v_ = moment2[i]; + + // initialize p0 if needed + if (p0b[i] == zero_half) { + p0b[i] = one; + p0[i] = p_; + } + float p0_ = p0[i]; + + if (cfg.weight_decay_type == L2) { + g += weight_decay * p_; + } + + if (lr > zero) { + d_numerator_ += (d_ / d0) * dlr * (g * (p0_ - p_)); + + m_ = m_ * beta1 + d_ * g * (one - beta1); + v_ = v_ * beta2 + d_ * d_ * g * g * (one - beta2); + + if (cfg.safeguard_warmup) { + s_ = s_ * beta3 + g * d_ * d_ / d0; + } else { + s_ = s_ * beta3 + g * d_ * dlr / d0; + } + + d_denom += absg(s_); + } + + s[i] = s_; + moment1[i] = m_; + moment2[i] = v_; + } + + // prepares the values for sum-reduction + d_numerators[cu_index] = d_numerator_; + d_denoms[cu_index] = d_denom; + + return; +} + +struct ProdigyConfig2 { + const size_t numel; + double lr; + double eps; + WeightDecayType weight_decay_type; + double weight_decay; + double bias_correction; +}; + +template +__device__ void prodigy_update2( + const ProdigyConfig2 cfg, + const double d_old, + + // parameter-related tensors. + // some are overwritten + T* param, + const T* moment1, + const T* moment2 +) { + const size_t numel = cfg.numel; + T lr = cfg.lr; + T eps = cfg.eps; + T weight_decay = cfg.weight_decay; + T bias_correction = cfg.bias_correction; + T one = 1.0; + + T d_old_ = d_old; + T dlr_old = d_old_ * lr * bias_correction; + + unsigned int cu_index = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int cu_stride = blockDim.x * gridDim.x; + + if (cu_index >= numel) { + return; + } + + for (unsigned int i = cu_index; i < numel; i += cu_stride) { + T p_ = param[i]; + T m_ = moment1[i]; + T v_ = moment2[i]; + + T denom = sqrtg(v_) + d_old_ * eps; + if (cfg.weight_decay_type == Decoupled) { + p_ *= one - weight_decay * dlr_old; + } + + p_ -= dlr_old * m_ / denom; + + param[i] = p_; + } +} + +#define PRODIGY2(TYPENAME, FN2) \ +extern "C" __global__ void FN2( \ + const ProdigyConfig2 cfg, \ + const double d, \ + TYPENAME* param, \ + const TYPENAME* moment1, \ + const TYPENAME* moment2 \ +) { \ + prodigy_update2(cfg, d, param, moment1, moment2); \ +} + +PRODIGY2(__half, prodigy_update2_f16); +PRODIGY2(float, prodigy_update2_f32); +PRODIGY2(double, prodigy_update2_f64); + +extern "C" __global__ void prodigy_update2_amp_f16( + const ProdigyConfig2 cfg, + const double d_old, + + // parameter-related tensors. 
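    // Second pass: only `param` is rewritten here; the moments were already
    // updated by prodigy_update1 and are read-only inputs to this kernel.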
+ // some are overwritten + __half* param, + const __half* moment1, + const __half* moment2 +) { + const size_t numel = cfg.numel; + float lr = cfg.lr; + float eps = cfg.eps; + float weight_decay = cfg.weight_decay; + float bias_correction = cfg.bias_correction; + float one = 1.0; + + float d_old_ = d_old; + float dlr_old = d_old_ * lr * bias_correction; + + unsigned int cu_index = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int cu_stride = blockDim.x * gridDim.x; + + if (cu_index >= numel) { + return; + } + + for (unsigned int i = cu_index; i < numel; i += cu_stride) { + float p_ = param[i]; + float m_ = moment1[i]; + float v_ = moment2[i]; + + float denom = sqrtg(v_) + d_old_ * eps; + if (cfg.weight_decay_type == Decoupled) { + p_ *= one - weight_decay * dlr_old; + } + + p_ -= dlr_old * m_ / denom; + + param[i] = p_; + } +} \ No newline at end of file diff --git a/dfdx-core/src/tensor_ops/utilities/device.rs b/dfdx-core/src/tensor_ops/utilities/device.rs index 2504185f..06998f69 100644 --- a/dfdx-core/src/tensor_ops/utilities/device.rs +++ b/dfdx-core/src/tensor_ops/utilities/device.rs @@ -23,6 +23,7 @@ pub trait Device: // optimizers + super::super::adam::AdamKernel + + super::super::prodigy::ProdigyKernel + super::super::sgd::SgdKernel + super::super::rmsprop::RMSpropKernel diff --git a/dfdx/src/nn/optim/mod.rs b/dfdx/src/nn/optim/mod.rs index e80e549e..b01b5d74 100644 --- a/dfdx/src/nn/optim/mod.rs +++ b/dfdx/src/nn/optim/mod.rs @@ -32,12 +32,16 @@ //! ``` mod adam; +mod prodigy; mod rmsprop; mod sgd; pub use adam::Adam; +pub use prodigy::Prodigy; pub use rmsprop::RMSprop; pub use sgd::Sgd; // re-exports pub use super::Optimizer; -pub use crate::tensor_ops::{AdamConfig, Momentum, RMSpropConfig, SgdConfig, WeightDecay}; +pub use crate::tensor_ops::{ + AdamConfig, Momentum, ProdigyConfig, RMSpropConfig, SgdConfig, WeightDecay, +}; diff --git a/dfdx/src/nn/optim/prodigy.rs b/dfdx/src/nn/optim/prodigy.rs new file mode 100644 index 00000000..87340801 --- /dev/null +++ b/dfdx/src/nn/optim/prodigy.rs @@ -0,0 +1,395 @@ +use std::marker::PhantomData; + +use crate::{ + shapes::{Dtype, Shape}, + tensor::{Error, Gradients, Storage, Tensor, Tensorlike, UniqueId}, + tensor_ops::{Device, ProdigyConfig}, +}; + +/// An implementation of the Prodigy optimizer from +/// [Prodigy: An Expeditiously Adaptive Parameter-Free Learner](https://arxiv.org/abs/2306.06101), +/// specifically _Algorithm 4, Adam version_, based on the researchers' [implementation](https://github.com/konstmish/prodigy). +/// +/// # Example Usage +/// ```rust +/// # use dfdx::prelude::*; +/// # type Model = Tensor; +/// # let dev: Cpu = Default::default(); +/// # let model: Model = dev.zeros(); +/// let mut opt: Prodigy = optim::Prodigy::new(&model, ProdigyConfig { +/// lr: 1.0, +/// betas: [0.5, 0.25], +/// eps: 1e-6, +/// weight_decay: Some(WeightDecay::Decoupled(1e-2)), +/// ..Default::default() +/// }); +/// ``` +/// +/// See module level documentation at [crate::nn::optim] for examples of how to actually use an optimizer. +#[derive(Debug, Clone)] +pub struct Prodigy> { + /// Hyperparameter configuration + pub cfg: ProdigyConfig, + + /// Timestep. + k: i32, + + // d-values that change across step updates + d: f64, + d_max: f64, + d_numerator: f64, + + s: Gradients, + + /// Initial value for a given parameter. + /// + /// For a given parameter, the initial value is observed at it's first optimizer update. + p0: Gradients, + + /// Helper data to identify whether `p0` has been initialized for a given parameter. 
+ /// + /// - `E::zero()` indicates the `p0` value for the parameter hasn't been initialized. + /// - `E::one()` indicates the `p0` value for the parameter has been initialized. + // + // Note: this is currently expensive since a single bool per parameter would be sufficient. + p0b: Gradients, + + moment1: Gradients, + moment2: Gradients, + + marker: PhantomData<*const M>, +} + +impl> Prodigy { + /// Constructs using hyperparameters from `cfg`. + pub fn new(_model: &M, cfg: ProdigyConfig) -> Self { + Self { + cfg, + k: 0, + d: cfg.d0, + d_max: cfg.d0, + d_numerator: 0.0, + s: Gradients::leaky(), + p0: Gradients::leaky(), + p0b: Gradients::leaky(), + moment1: Gradients::leaky(), + moment2: Gradients::leaky(), + marker: PhantomData, + } + } +} + +impl> crate::nn::Optimizer for Prodigy { + fn update_tensor( + &mut self, + t: &mut Tensor, + gradients: &Gradients, + missing_params: &mut Vec, + ) -> Result<(), crate::tensor::Error> { + let g = gradients.get_ref_checked(t); + match g { + None => missing_params.push(t.id()), + Some(g) => { + let s_t = self.s.get_or_alloc_mut(t)?; + let p0_t = self.p0.get_or_alloc_mut(t)?; + let p0b_t = self.p0b.get_or_alloc_mut(t)?; + let m_t = self.moment1.get_or_alloc_mut(t)?; + let v_t = self.moment2.get_or_alloc_mut(t)?; + self.cfg.try_update( + self.k, + &mut self.d, + &mut self.d_max, + &mut self.d_numerator, + t, + s_t, + p0_t, + p0b_t, + m_t, + v_t, + g, + )?; + } + } + Ok(()) + } + + fn update(&mut self, module: &mut M, gradients: &Gradients) -> Result<(), Error> + where + M: crate::nn::UpdateParams, + { + self.k = self.k.checked_add(1).unwrap(); + + // NOTE: the rest of this is identical to default implementation of update. + let mut missing_tensors = Vec::new(); + module.try_update_params(self, gradients, &mut missing_tensors)?; + if missing_tensors.is_empty() { + Ok(()) + } else { + Err(Error::UnusedTensors(missing_tensors)) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{prelude::*, tests::*}; + + type X = Tensor, TestDtype, TestDevice>; + type Y = [[TestDtype; 2]; 1]; + type M = MatMul, Const<2>, TestDtype, TestDevice>; + fn init() -> (TestDevice, X, Y, M) { + let dev: TestDevice = Default::default(); + let x: Tensor<_, TestDtype, _> = dev.tensor([[0.1, 0.2]]); + let y: [[TestDtype; 2]; 1] = [[7e2, 8e2]]; + let w: Tensor<_, TestDtype, _> = dev.tensor([[3., 4.], [5., 6.]]); + let mut m = dev.build_module::(MatMulConstConfig::<2, 2>::default()); + m.weight = w; + (dev, x, y, m) + } + + #[allow(clippy::too_many_arguments)] + fn check_against( + dev: &TestDevice, + x: X, + y: Y, + mut m: M, + mut opt: Prodigy, + expected_prediction: [[[f64; 2]; 1]; 10], + expected_grads: [[[f64; 2]; 2]; 10], + expected_updates: [[[f64; 2]; 2]; 10], + ) { + let mut grads = m.alloc_grads(); + for ((ey, eg), eu) in expected_prediction + .iter() + .zip(expected_grads) + .zip(expected_updates) + { + let prediction = m.forward_mut(x.trace(grads)); + assert_close_to_literal!(prediction, ey); + let loss = dfdx::losses::mse_loss(prediction, dev.tensor(y)); + grads = loss.backward(); + assert_close_to_literal!(grads.get(&m.weight), eg); + opt.update(&mut m, &grads).expect(""); + assert_close_to_literal!(m.weight, eu); + m.zero_grads(&mut grads); + } + } + + #[test] + fn test_default_prodigy_params() { + let (dev, x, y, m) = init(); + let opt = Prodigy::new(&m, Default::default()); + #[rustfmt::skip] + let expected_prediction: [[[f64; 2]; 1]; 10] = [ + [[1.100000023841858, 1.7000000476837158]], [[1.1000009775161743, 1.7000010013580322]], + 
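            // With the defaults, d starts at d0 = 1e-6, so the first steps barely
            // move the weights; the step size then grows as the D estimate
            // (numerator over denominator) accumulates.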
[[1.1000022888183594, 1.7000024318695068]], [[1.1000046730041504, 1.7000048160552979]], + [[1.1000117063522339, 1.7000117301940918]], [[1.1000306606292725, 1.7000305652618408]], + [[1.1000787019729614, 1.7000787258148193]], [[1.1002024412155151, 1.700202465057373]], + [[1.1005204916000366, 1.7005205154418945]], [[1.101338505744934, 1.701338529586792]], + ]; + + #[rustfmt::skip] + let expected_grads: [[[f64; 2]; 2]; 10] = [ + [[-69.89000701904297, -139.78001403808594], [-79.83000183105469, -159.66000366210938]], + [[-69.89000701904297, -139.78001403808594], [-79.83000183105469, -159.66000366210938]], + [[-69.89000701904297, -139.78001403808594], [-79.83000183105469, -159.66000366210938]], + [[-69.89000701904297, -139.78001403808594], [-79.83000183105469, -159.66000366210938]], + [[-69.88999938964844, -139.77999877929688], [-79.83000183105469, -159.66000366210938]], + [[-69.88999938964844, -139.77999877929688], [-79.83000183105469, -159.66000366210938]], + [[-69.8899917602539, -139.7799835205078], [-79.82999420166016, -159.6599884033203]], + [[-69.88997650146484, -139.7799530029297], [-79.8299789428711, -159.6599578857422]], + [[-69.88994598388672, -139.77989196777344], [-79.82994842529297, -159.65989685058594]], + [[-69.8898696899414, -139.7797393798828], [-79.82986450195312, -159.65972900390625]], + ]; + + #[rustfmt::skip] + let expected_updates: [[[f64; 2]; 2]; 10] = [ + [[3.0000030994415283, 4.000003337860107], [5.000003337860107, 6.000003337860107]], + [[3.000007390975952, 4.000007629394531], [5.000007629394531, 6.000007629394531]], + [[3.0000154972076416, 4.000015735626221], [5.000015735626221, 6.000015735626221]], + [[3.0000391006469727, 4.000039100646973], [5.000039100646973, 6.000039100646973]], + [[3.0001020431518555, 4.0001020431518555], [5.0001020431518555, 6.0001020431518555]], + [[3.0002622604370117, 4.000262260437012], [5.000262260437012, 6.000262260437012]], + [[3.0006744861602783, 4.000674724578857], [5.000674724578857, 6.000674724578857]], + [[3.001734733581543, 4.001734733581543], [5.001734733581543, 6.001734733581543]], + [[3.0044617652893066, 4.004461765289307], [5.004461765289307, 6.004461765289307]], + [[3.011474609375, 4.011474609375], [5.011474609375, 6.011474609375]], + ]; + + #[rustfmt::skip] + check_against(&dev, x, y, m, opt, expected_prediction, expected_grads, expected_updates); + } + + #[test] + fn test_custom_prodigy_params() { + let (dev, x, y, m) = init(); + let opt = Prodigy::new( + &m, + ProdigyConfig { + lr: 2e1, + betas: [0.5, 0.25], + beta3: Some(0.4), + eps: 1e-8, + weight_decay: None, + use_bias_correction: true, + safeguard_warmup: true, + d0: 1e-5, + d_coef: 0.5, + growth_rate: 1.02, + }, + ); + + #[rustfmt::skip] + let expected_prediction: [[[f64; 2]; 1]; 10] = [ + [[1.100000023841858, 1.7000000476837158]], [[1.100059986114502, 1.7000598907470703]], + [[1.100119948387146, 1.700119972229004]], [[1.1073875427246094, 1.7073874473571777]], + [[1.1166749000549316, 1.716675043106079]], [[1.1270885467529297, 1.727088451385498]], + [[1.1381947994232178, 1.7381949424743652]], [[1.149769902229309, 1.749769926071167]], + [[1.1617008447647095, 1.7617008686065674]], [[1.173932433128357, 1.7739324569702148]], + ]; + + #[rustfmt::skip] + let expected_grads: [[[f64; 2]; 2]; 10] = [ + [ [-69.89000701904297, -139.78001403808594], [-79.83000183105469, -159.66000366210938], ], + [ [-69.88999938964844, -139.77999877929688], [-79.82999420166016, -159.6599884033203], ], + [ [-69.8899917602539, -139.7799835205078], [-79.82998657226562, -159.65997314453125], ], + [ 
[-69.88926696777344, -139.77853393554688], [-79.82926177978516, -159.6585235595703], ], + [ [-69.8883285522461, -139.7766571044922], [-79.82833099365234, -159.6566619873047], ], + [ [-69.88729095458984, -139.7745819091797], [-79.8272933959961, -159.6545867919922], ], + [ [-69.88618469238281, -139.77236938476562], [-79.82617950439453, -159.65235900878906], ], + [ [-69.88502502441406, -139.77005004882812], [-79.82502746582031, -159.65005493164062], ], + [ [-69.88383483886719, -139.76766967773438], [-79.8238296508789, -159.6476593017578], ], + [ [-69.88260650634766, -139.7652130126953], [-79.8226089477539, -159.6452178955078], ], + ]; + + #[rustfmt::skip] + let expected_updates: [[[f64; 2]; 2]; 10] = [ + [ [3.000200033187866, 4.000199794769287], [5.000199794769287, 6.000199794769287], ], + [ [3.0004000663757324, 4.000399589538574], [5.000399589538574, 6.000399589538574], ], + [ [3.024625062942505, 4.024624824523926], [5.024624824523926, 6.024624824523926], ], + [ [3.0555830001831055, 4.0555830001831055], [5.0555830001831055, 6.0555830001831055], ], + [ [3.0902950763702393, 4.09029483795166], [5.09029483795166, 6.09029483795166], ], + [ [3.1273159980773926, 4.127315998077393], [5.127315998077393, 6.127315998077393], ], + [ [3.1658997535705566, 4.165899753570557], [5.165899753570557, 6.165899753570557], ], + [ [3.205669403076172, 4.205669403076172], [5.205669403076172, 6.205669403076172], ], + [ [3.246441602706909, 4.24644136428833], [5.24644136428833, 6.24644136428833], ], + [ [3.2881321907043457, 4.288132190704346], [5.288132190704346, 6.288132190704346], ], + ]; + + #[rustfmt::skip] + check_against(&dev, x, y, m, opt, expected_prediction, expected_grads, expected_updates); + } + + #[test] + fn test_prodigy_l2_decay() { + let (dev, x, y, m) = init(); + let opt = Prodigy::new( + &m, + ProdigyConfig { + betas: [0.5, 0.25], + beta3: Some(0.4), + weight_decay: Some(WeightDecay::L2(1.0)), + ..Default::default() + }, + ); + + #[rustfmt::skip] + let expected_prediction: [[[f64; 2]; 1]; 10] = [ + [[1.100000023841858, 1.7000000476837158]], [[1.1000001430511475, 1.700000286102295]], + [[1.1000003814697266, 1.700000524520874]], [[1.1000007390975952, 1.7000007629394531]], + [[1.1000009775161743, 1.7000010013580322]], [[1.1000014543533325, 1.7000014781951904]], + [[1.1000021696090698, 1.7000021934509277]], [[1.1000032424926758, 1.7000033855438232]], + [[1.10000479221344, 1.7000048160552979]], [[1.1000072956085205, 1.700007438659668]], + ]; + + #[rustfmt::skip] + let expected_grads: [[[f64; 2]; 2]; 10] = [ + [ [-69.89000701904297, -139.78001403808594], [-79.83000183105469, -159.66000366210938], ], + [ [-69.89000701904297, -139.78001403808594], [-79.83000183105469, -159.66000366210938], ], + [ [-69.89000701904297, -139.78001403808594], [-79.83000183105469, -159.66000366210938], ], + [ [-69.89000701904297, -139.78001403808594], [-79.83000183105469, -159.66000366210938], ], + [ [-69.89000701904297, -139.78001403808594], [-79.83000183105469, -159.66000366210938], ], + [ [-69.89000701904297, -139.78001403808594], [-79.83000183105469, -159.66000366210938], ], + [ [-69.89000701904297, -139.78001403808594], [-79.83000183105469, -159.66000366210938], ], + [ [-69.89000701904297, -139.78001403808594], [-79.83000183105469, -159.66000366210938], ], + [ [-69.89000701904297, -139.78001403808594], [-79.83000183105469, -159.66000366210938], ], + [ [-69.88999938964844, -139.77999877929688], [-79.83000183105469, -159.66000366210938], ], + ]; + #[rustfmt::skip] + let expected_updates: [[[f64; 2]; 2]; 10] = [ + [ 
[3.000000476837158, 4.000000476837158], [5.000000476837158, 6.000000476837158], ], + [ [3.0000011920928955, 4.000001430511475], [5.000001430511475, 6.000001430511475], ], + [ [3.000002145767212, 4.000002384185791], [5.000002384185791, 6.000002384185791], ], + [ [3.0000030994415283, 4.000003337860107], [5.000003337860107, 6.000003337860107], ], + [ [3.000004529953003, 4.000004768371582], [5.000004768371582, 6.000004768371582], ], + [ [3.000006914138794, 4.000007152557373], [5.000007152557373, 6.000007152557373], ], + [ [3.0000104904174805, 4.000010967254639], [5.000010967254639, 6.000010967254639], ], + [ [3.0000159740448, 4.000016212463379], [5.000016212463379, 6.000016212463379], ], + [ [3.0000243186950684, 4.000024318695068], [5.000024318695068, 6.000024318695068], ], + [ [3.0000367164611816, 4.000036716461182], [5.000036716461182, 6.000036716461182], ], + ]; + #[rustfmt::skip] + check_against(&dev, x, y, m, opt, expected_prediction, expected_grads, expected_updates); + } + + #[test] + fn test_prodigy_decoupled_decay() { + let (dev, x, y, m) = init(); + let opt = Prodigy::new( + &m, + ProdigyConfig { + betas: [0.5, 0.25], + beta3: Some(0.4), + weight_decay: Some(WeightDecay::Decoupled(1e3)), + ..Default::default() + }, + ); + + #[rustfmt::skip] + let expected_prediction: [[[f64; 2]; 1]; 10] = [ + [[1.100000023841858, 1.7000000476837158]], [[1.0989001989364624, 1.6983001232147217]], + [[1.0978014469146729, 1.69660222530365]], [[1.0967040061950684, 1.6949058771133423]], + [[1.0956075191497803, 1.693211317062378]], [[1.0945122241973877, 1.6915183067321777]], + [[1.093418002128601, 1.6898270845413208]], [[1.0923248529434204, 1.6881375312805176]], + [[1.0912327766418457, 1.686449646949768]], [[1.0901418924331665, 1.6847634315490723]], + ]; + + #[rustfmt::skip] + let expected_grads: [[[f64; 2]; 2]; 10] = [ + [ [-69.89000701904297, -139.78001403808594], [-79.83000183105469, -159.66000366210938], ], + [ [-69.8901138305664, -139.7802276611328], [-79.83016967773438, -159.66033935546875], ], + [ [-69.89022064208984, -139.7804412841797], [-79.8303451538086, -159.6606903076172], ], + [ [-69.89033508300781, -139.78067016601562], [-79.83051300048828, -159.66102600097656], ], + [ [-69.89044189453125, -139.7808837890625], [-79.83068084716797, -159.66136169433594], ], + [ [-69.89055633544922, -139.78111267089844], [-79.83084869384766, -159.6616973876953], ], + [ [-69.89065551757812, -139.78131103515625], [-79.83101654052734, -159.6620330810547], ], + [ [-69.8907699584961, -139.7815399169922], [-79.83119201660156, -159.66238403320312], ], + [ [-69.89087677001953, -139.78175354003906], [-79.83135223388672, -159.66270446777344], ], + [ [-69.89098358154297, -139.78196716308594], [-79.83152770996094, -159.66305541992188], ], + ]; + #[rustfmt::skip] + let expected_updates: [[[f64; 2]; 2]; 10] = [ + [ [2.9970004558563232, 3.9960005283355713], [4.99500036239624, 5.994000434875488], ], + [ [2.994004249572754, 3.9920053482055664], [4.990006446838379, 5.988007545471191], ], + [ [2.991011142730713, 3.9880142211914062], [4.9850172996521, 5.982020378112793], ], + [ [2.9880211353302, 3.984027147293091], [4.9800333976745605, 5.976039409637451], ], + [ [2.9850339889526367, 3.98004412651062], [4.9750542640686035, 5.970064163208008], ], + [ [2.9820499420166016, 3.976064920425415], [4.970080375671387, 5.964095115661621], ], + [ [2.9790687561035156, 3.9720897674560547], [4.965111255645752, 5.958131790161133], ], + [ [2.976090669631958, 3.968118667602539], [4.960146903991699, 5.952174663543701], ], + [ [2.9731154441833496, 
3.964151620864868], [4.955187797546387, 5.946223258972168], ], + [ [2.9701433181762695, 3.960188388824463], [4.950233459472656, 5.940278053283691], ], + ]; + #[rustfmt::skip] + check_against(&dev, x, y, m, opt, expected_prediction, expected_grads, expected_updates); + } + + #[test] + fn test_unused_tensors() { + let dev: TestDevice = Default::default(); + let mut t: Tensor, TestDtype, _> = dev.sample_normal(); + let mut opt = Prodigy::new(&t, Default::default()); + opt.update(&mut t, &Gradients::leaky()).expect_err(""); + } +}
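For completeness, a minimal end-to-end usage sketch of the new optimizer, modeled on the crate-level doc example and the tests above. The 2x2 `MatMul` model, the toy data, and the decoupled weight-decay value are illustrative only, and the imports assume the same prelude re-exports the tests rely on (exact paths may differ).

```rust
use dfdx::prelude::*;

fn main() {
    let dev: Cpu = Default::default();

    // A tiny 2 -> 2 linear map, matching the module used in the tests.
    let mut model = dev.build_module::<f32>(MatMulConstConfig::<2, 2>::default());

    // `lr` stays at its default of 1.0: the effective step size is discovered
    // by D-adaptation rather than hand-tuned.
    let mut opt: optim::Prodigy<_, f32, Cpu> = optim::Prodigy::new(
        &model,
        ProdigyConfig {
            weight_decay: Some(WeightDecay::Decoupled(1e-2)),
            ..Default::default()
        },
    );

    // Toy data; the values are illustrative.
    let x: Tensor<Rank2<1, 2>, f32, _> = dev.tensor([[0.1, 0.2]]);
    let y = [[0.7f32, 0.8]];

    let mut grads = model.alloc_grads();
    for _ in 0..10 {
        let prediction = model.forward_mut(x.trace(grads));
        let loss = dfdx::losses::mse_loss(prediction, dev.tensor(y));
        grads = loss.backward();
        opt.update(&mut model, &grads).expect("no unused params");
        model.zero_grads(&mut grads);
    }
}
```

As in the `test_default_prodigy_params` run above, the first few updates are tiny (d starts at `d0`) and grow as the estimate of d adapts.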