add RMS normalization
- Add the `try_normalize_rms`-related functions.
- Add the `LayerRMSNorm1D` module.
swfsql committed Mar 1, 2024
1 parent 1175903 commit d237ea6
Showing 4 changed files with 309 additions and 0 deletions.
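
For orientation, the sketch below shows how the two additions are meant to be used together: the `normalize_rms` tensor op from `dfdx-core` and the `LayerRMSNorm1D` module from `dfdx`. It is a minimal example assembled from the doctests in this commit, assuming the usual `dfdx` prelude re-exports; it is not part of the diff itself.

```rust
use dfdx::prelude::*;

fn main() {
    let dev: Cpu = Default::default();

    // Tensor-op form (dfdx-core): normalize the last axis to unit RMS.
    let t: Tensor<Rank2<2, 3>, f32, _> = dev.sample_normal();
    let _normed = t.normalize_rms::<Axis<1>>(1e-5);

    // Module form (dfdx): the same normalization wrapped with learnable gamma/beta.
    type Model = LayerRMSNorm1DConstConfig<8>;
    let model = dev.build_module::<f32>(Model::default());
    let _y: Tensor<Rank1<8>, f32, _> = model.forward(dev.sample_normal::<Rank1<8>>());
}
```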
2 changes: 2 additions & 0 deletions dfdx-core/src/tensor_ops/mod.rs
@@ -184,6 +184,7 @@ mod mul;
mod nans_to;
mod negate;
mod normalize;
mod normalize_rms;
pub(super) mod optim;
mod permute_to;
mod pow;
@@ -251,6 +252,7 @@ pub use mul::{mul, TryMul};
pub use nans_to::nans_to;
pub use negate::negate;
pub use normalize::normalize;
pub use normalize_rms::normalize_rms;
pub use optim::*;
pub use permute_to::PermuteTo;
pub use pow::{powf, powi};
136 changes: 136 additions & 0 deletions dfdx-core/src/tensor_ops/normalize_rms.rs
@@ -0,0 +1,136 @@
use crate::{
    shapes::{Axes, Dtype, ReduceShape, Shape},
    tensor::{Error, Tape, Tensor},
};

use super::{BroadcastTo, Device, MeanTo, TryAdd, TryMul};

/// Normalizes `t` to have a root mean square (RMS) of `1.0` along `Ax`. `epsilon` is added for numerical stability.
/// Computes `t / (t.square().mean() + epsilon).sqrt()`.
///
/// Normalizing a single axis:
/// ```rust
/// # use dfdx_core::prelude::*;
/// # let dev: Cpu = Default::default();
/// let t: Tensor<Rank2<2, 3>, f32, _> = dev.zeros();
/// let _ = t.normalize_rms::<Axis<1>>(1e-5);
/// ```
pub fn normalize_rms<
    Ax: Axes,
    S: Shape + ReduceShape<Ax>,
    E: Dtype,
    D: Device<E>,
    T: Tape<E, D>,
>(
    t: Tensor<S, E, D, T>,
    epsilon: impl Into<f64>,
) -> Tensor<S, E, D, T> {
    t.normalize_rms::<Ax>(epsilon)
}

impl<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>> Tensor<S, E, D, T> {
    /// See [normalize_rms]
    pub fn normalize_rms<Ax: Axes>(self, epsilon: impl Into<f64>) -> Self
    where
        S: ReduceShape<Ax>,
    {
        self.try_normalize_rms::<Ax>(epsilon).unwrap()
    }

    /// See [normalize_rms]
    pub fn try_normalize_rms<Ax: Axes>(self, epsilon: impl Into<f64>) -> Result<Self, Error>
    where
        S: ReduceShape<Ax>,
    {
        let shape = self.shape;
        let sq = self.retaped::<T>().try_square()?;
        let sq_mean = sq.try_mean::<_, Ax>()?;
        let rsqrt = sq_mean
            .try_add(epsilon)?
            .try_sqrt()?
            .try_recip()?
            .try_broadcast_like(&shape)?;
        self.try_mul(rsqrt)
    }
}

#[cfg(test)]
mod tests {
    use crate::tests::*;
    use crate::{shapes::*, tensor::*, tensor_ops::*};

    #[test]
    fn test_1d_normalize_rms_axis_last() {
        let dev: TestDevice = Default::default();
        let a = dev.tensor([-2.0, 0.0, 5.0]).to_dtype::<TestDtype>();
        let r = a.leaky_trace().normalize_rms(1e-5);
        assert_close_to_literal!(&r, [-0.64326715, 0.0, 1.6081679]);
        // NOTE: .exp() so we can make sure normalize is using result grad properly
        let g = r.exp().mean().backward();
        assert_close_to_literal!(&g.get(&a), [0.23318729, 0.107211195, 0.09327549]);
    }

    #[test]
    fn test_2d_normalize_rms_axis_last() {
        let dev: TestDevice = Default::default();
        let a = dev
            .tensor([[-2.0, 0.0, 5.0], [1.0, 2.0, 3.0]])
            .to_dtype::<TestDtype>();
        let r = a.leaky_trace().normalize_rms::<Axis<1>>(1e-5);
        assert_close_to_literal!(
            r,
            [
                [-0.64326715, 0.0, 1.6081679],
                [0.46290955, 0.9258191, 1.3887286]
            ]
        );
        let g = r.exp().mean().backward();
        assert_close_to_literal!(
            g.get(&a),
            [
                [0.116593644, 0.053605597, 0.046637744],
                [0.019706108, -0.011002079, 0.0007670224]
            ]
        );
    }

    #[test]
    fn test_2d_normalize_rms_axis_first() {
        let dev: TestDevice = Default::default();
        let a = dev
            .tensor([[-2.0, 0.0], [1.0, 2.0], [4.0, 5.0]])
            .to_dtype::<TestDtype>();
        let r = a.leaky_trace().normalize_rms::<Axis<0>>(1e-5);
        assert_close_to_literal!(
            r,
            [
                [-0.7559284, 0.0],
                [0.3779642, 0.64326715],
                [1.5118568, 1.6081679]
            ]
        );
        let g = r.exp().mean().backward();
        assert_close_to_literal!(
            g.get(&a),
            [
                [0.14153406, 0.053605597],
                [0.03595103, -0.0043795705],
                [0.061779693, 0.0017521679]
            ]
        );
    }

    #[test]
    fn test_3d_normalize_rms_axis_last() {
        let dev: TestDevice = Default::default();
        let a: Tensor<Rank3<4, 2, 3>, TestDtype, _> = dev.ones();
        let r = a.leaky_trace().normalize_rms::<Axis<2>>(1e-5);
        assert_close_to_literal!(r, [[[1.0; 3]; 2]; 4], 1e-5);
        let g = r.exp().mean().backward();
        assert_close_to_literal!(g.get(&a), [[[0.0; 3]; 2]; 4], 1e-5);
    }
}

// Implementation references:
// - https://github.com/johnma2006/mamba-minimal/blob/03de542a36d873f6e6c4057ad687278cc6ae944d/model.py#L328
// - https://github.com/kroggen/mamba.c/blob/7387f49e352f86a0c22041c0f66fd2a40b58a207/mamba.c#L222
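
As a quick sanity check on the literals asserted in `test_1d_normalize_rms_axis_last`, the same numbers fall out of the definition `x / sqrt(mean(x^2) + eps)` with plain floats. The `rms_normalize` helper below is written only for this check and is not part of the commit:

```rust
/// Reference RMS normalization over a slice: x / sqrt(mean(x^2) + eps).
fn rms_normalize(x: &[f32], eps: f32) -> Vec<f32> {
    let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
    let scale = 1.0 / (mean_sq + eps).sqrt();
    x.iter().map(|v| v * scale).collect()
}

fn main() {
    // mean_sq = (4 + 0 + 25) / 3 ≈ 9.6667, rms ≈ 3.1091
    let r = rms_normalize(&[-2.0, 0.0, 5.0], 1e-5);
    println!("{r:?}"); // ≈ [-0.64326715, 0.0, 1.6081679]
}
```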
169 changes: 169 additions & 0 deletions dfdx/src/nn/layers/layer_rms_norm1d.rs
@@ -0,0 +1,169 @@
use crate::prelude::*;

/// Implements RMS layer normalization as described in [Root Mean Square Layer Normalization](https://arxiv.org/abs/1910.07467).
///
/// This calls [normalize_rms()] on the last axis of the input to normalize it to unit root mean square (RMS), and then does an
/// element-wise affine transform using learnable parameters.
///
/// Epsilon is passed to [normalize_rms()] and added to the mean of squares for numerical stability. It defaults to `1e-5`.
///
/// Generics:
/// - `M` The size of the affine transform tensors.
///
/// # Examples
/// ```rust
/// # use dfdx::prelude::*;
/// # use dfdx::*;
/// # let dev: Cpu = Default::default();
/// type Model = LayerRMSNorm1DConstConfig<5>;
/// let model = dev.build_module::<f32>(Model::default());
/// let _: Tensor<Rank1<5>, f32, _> = model.forward(dev.zeros::<Rank1<5>>());
/// ```
#[derive(Default, Clone, Copy, Debug)]
#[repr(transparent)]
pub struct LayerRMSNorm1DConfig<M: Dim>(pub M);

/// Compile time sugar alias around [LayerRMSNorm1DConfig]
pub type LayerRMSNorm1DConstConfig<const M: usize> = LayerRMSNorm1DConfig<Const<M>>;

impl<M: Dim, E: Dtype, D: Device<E>> BuildOnDevice<E, D> for LayerRMSNorm1DConfig<M> {
    type Built = LayerRMSNorm1D<M, E, D>;
    fn try_build_on_device(&self, device: &D) -> Result<Self::Built, crate::tensor::Error> {
        Ok(LayerRMSNorm1D {
            gamma: device.try_ones_like(&(self.0,))?,
            beta: device.try_zeros_like(&(self.0,))?,
            epsilon: 1e-5,
        })
    }
}

/// See [LayerRMSNorm1DConfig]
#[derive(Clone, Debug, UpdateParams, ZeroGrads, WithGrads)]
#[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))]
pub struct LayerRMSNorm1D<M: Dim, Elem: Dtype, Dev: Device<Elem>> {
    #[param]
    #[cfg_attr(feature = "safetensors", serialize)]
    pub gamma: Tensor<(M,), Elem, Dev>,
    #[param]
    #[cfg_attr(feature = "safetensors", serialize)]
    pub beta: Tensor<(M,), Elem, Dev>,
    #[cfg_attr(feature = "safetensors", serialize)]
    pub epsilon: f64,
}

impl<M: Dim, E: Dtype, D: Device<E>> ResetParams<E, D> for LayerRMSNorm1D<M, E, D> {
    fn try_reset_params(&mut self) -> Result<(), crate::tensor::Error> {
        self.gamma.try_fill_with_ones()?;
        self.beta.try_fill_with_zeros()?;
        Ok(())
    }
}

impl<M: Dim, E: Dtype, D: Device<E>, T: Tape<E, D>> Module<Tensor<(M,), E, D, T>>
    for LayerRMSNorm1D<M, E, D>
{
    type Output = Tensor<(M,), E, D, T>;
    fn try_forward(&self, x: Tensor<(M,), E, D, T>) -> Result<Self::Output, Error> {
        let x = x.try_normalize_rms::<Axis<0>>(self.epsilon)?;
        let x = self.gamma.retaped::<T>().try_mul(x)?;
        self.beta.retaped::<T>().try_add(x)
    }
}

impl<Batch: Dim, M: Dim, E: Dtype, D: Device<E>, T: Tape<E, D>> Module<Tensor<(Batch, M), E, D, T>>
    for LayerRMSNorm1D<M, E, D>
{
    type Output = Tensor<(Batch, M), E, D, T>;
    fn try_forward(&self, x: Tensor<(Batch, M), E, D, T>) -> Result<Self::Output, Error> {
        let x = x.try_normalize_rms::<Axis<1>>(self.epsilon)?;
        let x = self.gamma.retaped::<T>().broadcast_like(&x).try_mul(x)?;
        self.beta.retaped::<T>().broadcast_like(&x).try_add(x)
    }
}

impl<Batch: Dim, Seq: Dim, M: Dim, E: Dtype, D: Device<E>, T: Tape<E, D>>
    Module<Tensor<(Batch, Seq, M), E, D, T>> for LayerRMSNorm1D<M, E, D>
{
    type Output = Tensor<(Batch, Seq, M), E, D, T>;
    fn try_forward(&self, x: Tensor<(Batch, Seq, M), E, D, T>) -> Result<Self::Output, Error> {
        let x = x.try_normalize_rms::<Axis<2>>(self.epsilon)?;
        let x = self.gamma.retaped::<T>().broadcast_like(&x).try_mul(x)?;
        self.beta.retaped::<T>().broadcast_like(&x).try_add(x)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tests::*;

    #[test]
    fn test_layer_rms_norm_reset() {
        let dev: TestDevice = Default::default();

        let mut m = dev.build_module::<TestDtype>(<LayerRMSNorm1DConstConfig<5>>::default());
        assert_close_to_literal!(m.gamma, [1.0; 5]);
        assert_close_to_literal!(m.beta, [0.0; 5]);

        m.gamma = dev.sample_normal();
        m.beta = dev.sample_normal();

        assert_ne!(m.gamma.array(), [TestDtype::ONE; 5]);
        assert_ne!(m.beta.array(), [TestDtype::default(); 5]);

        m.reset_params();

        assert_close_to_literal!(m.gamma, [1.0; 5]);
        assert_close_to_literal!(m.beta, [0.0; 5]);
    }

    #[test]
    fn test_layer_rms_norm_1d_forward() {
        let dev: TestDevice = Default::default();
        let mut m = dev.build_module::<TestDtype>(<LayerRMSNorm1DConstConfig<5>>::default());
        let x = dev.sample_normal::<Rank1<5>>();
        let r = m.forward_mut(x.leaky_trace());
        assert_close_to_literal!(
            r,
            [0.53631353, 0.6458002, -1.8330059, 0.12289862, -0.9593052]
        );
        let g = r.mean().backward();
        assert_close_to_literal!(
            g.get(&m.gamma),
            [0.10726271, 0.12916003, -0.3666012, 0.024579724, -0.19186105]
        );
        assert_close_to_literal!(g.get(&m.beta), [0.2; 5]);
    }

    #[test]
    fn test_layer_rms_norm_2d_forward() {
        let dev: TestDevice = Default::default();
        let m = dev.build_module::<TestDtype>(<LayerRMSNorm1DConstConfig<5>>::default());
        let x = dev.sample_normal::<Rank2<3, 5>>();
        let r = m.forward(x.leaky_trace());
        assert_close_to_literal!(
            r,
            [
                [0.53631353, 0.6458002, -1.8330059, 0.12289862, -0.9593052],
                [1.0418473, -1.199064, 0.49583954, 0.5000605, 1.4074267],
                [0.90727454, -1.6644237, -0.5176145, 1.0127299, -0.33612955]
            ]
        );
        let g = r.mean().backward();
        assert_close_to_literal!(
            g.get(&m.gamma),
            [
                0.16569571,
                -0.14784585,
                -0.123652056,
                0.10904594,
                0.0074661337
            ]
        );
        assert_close_to_literal!(g.get(&m.beta), [0.2; 5]);
    }
}

// Implementation references:
// - https://github.com/johnma2006/mamba-minimal/blob/03de542a36d873f6e6c4057ad687278cc6ae944d/model.py#L328
// - https://github.com/kroggen/mamba.c/blob/7387f49e352f86a0c22041c0f66fd2a40b58a207/mamba.c#L222
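
With the default parameters (`gamma = 1`, `beta = 0`), the layer's output equals `normalize_rms(x)`, which is why the gradients asserted in `test_layer_rms_norm_1d_forward` are simple functions of `r`: for the loss `r.mean()` over 5 elements, `d loss / d gamma[i] = r[i] / 5` and `d loss / d beta[i] = 1 / 5 = 0.2`. The snippet below is a standalone arithmetic check of that relationship, not a restatement of the backward implementation:

```rust
fn main() {
    // `r` from test_layer_rms_norm_1d_forward (gamma = 1, beta = 0, so r = normalize_rms(x)).
    let r = [0.53631353f32, 0.6458002, -1.8330059, 0.12289862, -0.9593052];

    // loss = mean(r)  =>  d loss / d gamma[i] = r[i] / 5, and d loss / d beta[i] = 0.2.
    let gamma_grad: Vec<f32> = r.iter().map(|v| v / 5.0).collect();
    println!("{gamma_grad:?}"); // ≈ [0.10726271, 0.12916003, -0.3666012, 0.024579724, -0.19186105]
    println!("{:?}", [0.2f32; 5]); // the asserted beta gradient
}
```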
2 changes: 2 additions & 0 deletions dfdx/src/nn/layers/mod.rs
@@ -20,6 +20,7 @@ mod gelu;
mod generalized_add;
mod generalized_mul;
mod layer_norm1d;
mod layer_rms_norm1d;
mod leaky_relu;
mod linear;
mod ln;
@@ -73,6 +74,7 @@ pub use gelu::{AccurateGeLU, FastGeLU};
pub use generalized_add::GeneralizedAdd;
pub use generalized_mul::GeneralizedMul;
pub use layer_norm1d::{LayerNorm1D, LayerNorm1DConfig, LayerNorm1DConstConfig};
pub use layer_rms_norm1d::{LayerRMSNorm1D, LayerRMSNorm1DConfig, LayerRMSNorm1DConstConfig};
pub use leaky_relu::LeakyReLU;
pub use linear::{Linear, LinearConfig, LinearConstConfig};
pub use ln::Ln;
