From d237ea64c8ed02a6ee08fdc6af01c096e366b66f Mon Sep 17 00:00:00 2001 From: Thiago Machado Date: Wed, 31 Jan 2024 01:54:19 -0500 Subject: [PATCH] add RMS normalization - Add the try_normalize_rms related functions. - Add the `LayerRMSNorm1D` module. --- dfdx-core/src/tensor_ops/mod.rs | 2 + dfdx-core/src/tensor_ops/normalize_rms.rs | 136 +++++++++++++++++ dfdx/src/nn/layers/layer_rms_norm1d.rs | 169 ++++++++++++++++++++++ dfdx/src/nn/layers/mod.rs | 2 + 4 files changed, 309 insertions(+) create mode 100644 dfdx-core/src/tensor_ops/normalize_rms.rs create mode 100644 dfdx/src/nn/layers/layer_rms_norm1d.rs diff --git a/dfdx-core/src/tensor_ops/mod.rs b/dfdx-core/src/tensor_ops/mod.rs index 453457f4..a649196c 100644 --- a/dfdx-core/src/tensor_ops/mod.rs +++ b/dfdx-core/src/tensor_ops/mod.rs @@ -184,6 +184,7 @@ mod mul; mod nans_to; mod negate; mod normalize; +mod normalize_rms; pub(super) mod optim; mod permute_to; mod pow; @@ -251,6 +252,7 @@ pub use mul::{mul, TryMul}; pub use nans_to::nans_to; pub use negate::negate; pub use normalize::normalize; +pub use normalize_rms::normalize_rms; pub use optim::*; pub use permute_to::PermuteTo; pub use pow::{powf, powi}; diff --git a/dfdx-core/src/tensor_ops/normalize_rms.rs b/dfdx-core/src/tensor_ops/normalize_rms.rs new file mode 100644 index 00000000..eb70302a --- /dev/null +++ b/dfdx-core/src/tensor_ops/normalize_rms.rs @@ -0,0 +1,136 @@ +use crate::{ + shapes::{Axes, Dtype, ReduceShape, Shape}, + tensor::{Error, Tape, Tensor}, +}; + +use super::{BroadcastTo, Device, MeanTo, TryAdd, TryMul}; + +/// Normalizes `t` to have stddev `1.0` along `Ax`. `epsilon` is used during stddev. +/// Computes `t / (t.square().mean() + epsilon).sqrt()`. +/// +/// Normalizing a single axis: +/// ```rust +/// # use dfdx_core::prelude::*; +/// # let dev: Cpu = Default::default(); +/// let t: Tensor, f32, _> = dev.zeros(); +/// let _ = t.normalize_rms::>(1e-5); +/// ``` +pub fn normalize_rms< + Ax: Axes, + S: Shape + ReduceShape, + E: Dtype, + D: Device, + T: Tape, +>( + t: Tensor, + epsilon: impl Into, +) -> Tensor { + t.normalize_rms::(epsilon) +} + +impl, T: Tape> Tensor { + /// See [normalize_rms] + pub fn normalize_rms(self, epsilon: impl Into) -> Self + where + S: ReduceShape, + { + self.try_normalize_rms::(epsilon).unwrap() + } + + /// See [normalize_rms] + pub fn try_normalize_rms(self, epsilon: impl Into) -> Result + where + S: ReduceShape, + { + let shape = self.shape; + let sq = self.retaped::().try_square()?; + let sq_mean = sq.try_mean::<_, Ax>()?; + let rsqrt = sq_mean + .try_add(epsilon)? + .try_sqrt()? + .try_recip()? + .try_broadcast_like(&shape)?; + self.try_mul(rsqrt) + } +} + +#[cfg(test)] +mod tests { + use crate::tests::*; + use crate::{shapes::*, tensor::*, tensor_ops::*}; + + #[test] + fn test_1d_normalize_rms_axis_last() { + let dev: TestDevice = Default::default(); + let a = dev.tensor([-2.0, 0.0, 5.0]).to_dtype::(); + let r = a.leaky_trace().normalize_rms(1e-5); + assert_close_to_literal!(&r, [-0.64326715, 0.0, 1.6081679]); + // NOTE: .exp() so we can make sure normalize is using result grad properly + let g = r.exp().mean().backward(); + assert_close_to_literal!(&g.get(&a), [0.23318729, 0.107211195, 0.09327549]); + } + + #[test] + fn test_2d_normalize_rms_axis_last() { + let dev: TestDevice = Default::default(); + let a = dev + .tensor([[-2.0, 0.0, 5.0], [1.0, 2.0, 3.0]]) + .to_dtype::(); + let r = a.leaky_trace().normalize_rms::>(1e-5); + assert_close_to_literal!( + r, + [ + [-0.64326715, 0.0, 1.6081679], + [0.46290955, 0.9258191, 1.3887286] + ] + ); + let g = r.exp().mean().backward(); + assert_close_to_literal!( + g.get(&a), + [ + [0.116593644, 0.053605597, 0.046637744], + [0.019706108, -0.011002079, 0.0007670224] + ] + ); + } + + #[test] + fn test_2d_normalize_rms_axis_first() { + let dev: TestDevice = Default::default(); + let a = dev + .tensor([[-2.0, 0.0], [1.0, 2.0], [4.0, 5.0]]) + .to_dtype::(); + let r = a.leaky_trace().normalize_rms::>(1e-5); + assert_close_to_literal!( + r, + [ + [-0.7559284, 0.0], + [0.3779642, 0.64326715], + [1.5118568, 1.6081679] + ] + ); + let g = r.exp().mean().backward(); + assert_close_to_literal!( + g.get(&a), + [ + [0.14153406, 0.053605597], + [0.03595103, -0.0043795705], + [0.061779693, 0.0017521679] + ] + ); + } + + #[test] + fn test_3d_normalize_rms_axis_last() { + let dev: TestDevice = Default::default(); + let a: Tensor, TestDtype, _> = dev.ones(); + let r = a.leaky_trace().normalize_rms::>(1e-5); + assert_close_to_literal!(r, [[[1.0; 3]; 2]; 4], 1e-5); + let g = r.exp().mean().backward(); + assert_close_to_literal!(g.get(&a), [[[0.0; 3]; 2]; 4], 1e-5); + } +} + +// Implementation references: +// - https://github.com/johnma2006/mamba-minimal/blob/03de542a36d873f6e6c4057ad687278cc6ae944d/model.py#L328 +// - https://github.com/kroggen/mamba.c/blob/7387f49e352f86a0c22041c0f66fd2a40b58a207/mamba.c#L222 diff --git a/dfdx/src/nn/layers/layer_rms_norm1d.rs b/dfdx/src/nn/layers/layer_rms_norm1d.rs new file mode 100644 index 00000000..17143aef --- /dev/null +++ b/dfdx/src/nn/layers/layer_rms_norm1d.rs @@ -0,0 +1,169 @@ +use crate::prelude::*; + +/// Implements RMS layer normalization as described in [Root Mean Square Layer Normalization](https://arxiv.org/abs/1910.07467). +/// +/// This calls [normalize_rms()] on the last axis of the input to normalize to unit std dev, and then does an element-wise +/// affine transform using learnable parameters. +/// +/// Epsilon is passed to [normalize_rms()] and added to the variance to ensure big enough numbers. It defaults to `1e-5`. +/// +/// Generics: +/// - `M` The size of the affine transform tensors. +/// +/// # Examples +/// ```rust +/// # use dfdx::prelude::*; +/// # use dfdx::*; +/// # let dev: Cpu = Default::default(); +/// type Model = LayerRMSNorm1DConstConfig<5>; +/// let model = dev.build_module::(Model::default()); +/// let _: Tensor, f32, _> = model.forward(dev.zeros::>()); +/// ``` +#[derive(Default, Clone, Copy, Debug)] +#[repr(transparent)] +pub struct LayerRMSNorm1DConfig(pub M); + +/// Compile time sugar alias around [LayerRMSNorm1DConfig] +pub type LayerRMSNorm1DConstConfig = LayerRMSNorm1DConfig>; + +impl> BuildOnDevice for LayerRMSNorm1DConfig { + type Built = LayerRMSNorm1D; + fn try_build_on_device(&self, device: &D) -> Result { + Ok(LayerRMSNorm1D { + gamma: device.try_ones_like(&(self.0,))?, + beta: device.try_zeros_like(&(self.0,))?, + epsilon: 1e-5, + }) + } +} + +/// See [LayerRMSNorm1DConfig] +#[derive(Clone, Debug, UpdateParams, ZeroGrads, WithGrads)] +#[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))] +pub struct LayerRMSNorm1D> { + #[param] + #[cfg_attr(feature = "safetensors", serialize)] + pub gamma: Tensor<(M,), Elem, Dev>, + #[param] + #[cfg_attr(feature = "safetensors", serialize)] + pub beta: Tensor<(M,), Elem, Dev>, + #[cfg_attr(feature = "safetensors", serialize)] + pub epsilon: f64, +} + +impl> ResetParams for LayerRMSNorm1D { + fn try_reset_params(&mut self) -> Result<(), crate::tensor::Error> { + self.gamma.try_fill_with_ones()?; + self.beta.try_fill_with_zeros()?; + Ok(()) + } +} + +impl, T: Tape> Module> + for LayerRMSNorm1D +{ + type Output = Tensor<(M,), E, D, T>; + fn try_forward(&self, x: Tensor<(M,), E, D, T>) -> Result { + let x = x.try_normalize_rms::>(self.epsilon)?; + let x = self.gamma.retaped::().try_mul(x)?; + self.beta.retaped::().try_add(x) + } +} + +impl, T: Tape> Module> + for LayerRMSNorm1D +{ + type Output = Tensor<(Batch, M), E, D, T>; + fn try_forward(&self, x: Tensor<(Batch, M), E, D, T>) -> Result { + let x = x.try_normalize_rms::>(self.epsilon)?; + let x = self.gamma.retaped::().broadcast_like(&x).try_mul(x)?; + self.beta.retaped::().broadcast_like(&x).try_add(x) + } +} + +impl, T: Tape> + Module> for LayerRMSNorm1D +{ + type Output = Tensor<(Batch, Seq, M), E, D, T>; + fn try_forward(&self, x: Tensor<(Batch, Seq, M), E, D, T>) -> Result { + let x = x.try_normalize_rms::>(self.epsilon)?; + let x = self.gamma.retaped::().broadcast_like(&x).try_mul(x)?; + self.beta.retaped::().broadcast_like(&x).try_add(x) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::tests::*; + + #[test] + fn test_layer_rms_norm_reset() { + let dev: TestDevice = Default::default(); + + let mut m = dev.build_module::(>::default()); + assert_close_to_literal!(m.gamma, [1.0; 5]); + assert_close_to_literal!(m.beta, [0.0; 5]); + + m.gamma = dev.sample_normal(); + m.beta = dev.sample_normal(); + + assert_ne!(m.gamma.array(), [TestDtype::ONE; 5]); + assert_ne!(m.beta.array(), [TestDtype::default(); 5]); + + m.reset_params(); + + assert_close_to_literal!(m.gamma, [1.0; 5]); + assert_close_to_literal!(m.beta, [0.0; 5]); + } + + #[test] + fn test_layer_rms_norm_1d_forward() { + let dev: TestDevice = Default::default(); + let mut m = dev.build_module::(>::default()); + let x = dev.sample_normal::>(); + let r = m.forward_mut(x.leaky_trace()); + assert_close_to_literal!( + r, + [0.53631353, 0.6458002, -1.8330059, 0.12289862, -0.9593052] + ); + let g = r.mean().backward(); + assert_close_to_literal!( + g.get(&m.gamma), + [0.10726271, 0.12916003, -0.3666012, 0.024579724, -0.19186105] + ); + assert_close_to_literal!(g.get(&m.beta), [0.2; 5]); + } + + #[test] + fn test_layer_rms_norm_2d_forward() { + let dev: TestDevice = Default::default(); + let m = dev.build_module::(>::default()); + let x = dev.sample_normal::>(); + let r = m.forward(x.leaky_trace()); + assert_close_to_literal!( + r, + [ + [0.53631353, 0.6458002, -1.8330059, 0.12289862, -0.9593052], + [1.0418473, -1.199064, 0.49583954, 0.5000605, 1.4074267], + [0.90727454, -1.6644237, -0.5176145, 1.0127299, -0.33612955] + ] + ); + let g = r.mean().backward(); + assert_close_to_literal!( + g.get(&m.gamma), + [ + 0.16569571, + -0.14784585, + -0.123652056, + 0.10904594, + 0.0074661337 + ] + ); + assert_close_to_literal!(g.get(&m.beta), [0.2; 5]); + } +} + +// Implementation references: +// - https://github.com/johnma2006/mamba-minimal/blob/03de542a36d873f6e6c4057ad687278cc6ae944d/model.py#L328 +// - https://github.com/kroggen/mamba.c/blob/7387f49e352f86a0c22041c0f66fd2a40b58a207/mamba.c#L222 diff --git a/dfdx/src/nn/layers/mod.rs b/dfdx/src/nn/layers/mod.rs index 828b1e97..062b9f08 100644 --- a/dfdx/src/nn/layers/mod.rs +++ b/dfdx/src/nn/layers/mod.rs @@ -20,6 +20,7 @@ mod gelu; mod generalized_add; mod generalized_mul; mod layer_norm1d; +mod layer_rms_norm1d; mod leaky_relu; mod linear; mod ln; @@ -73,6 +74,7 @@ pub use gelu::{AccurateGeLU, FastGeLU}; pub use generalized_add::GeneralizedAdd; pub use generalized_mul::GeneralizedMul; pub use layer_norm1d::{LayerNorm1D, LayerNorm1DConfig, LayerNorm1DConstConfig}; +pub use layer_rms_norm1d::{LayerRMSNorm1D, LayerRMSNorm1DConfig, LayerRMSNorm1DConstConfig}; pub use leaky_relu::LeakyReLU; pub use linear::{Linear, LinearConfig, LinearConstConfig}; pub use ln::Ln;