From 8fbf249e49ab1943df071c4067ffe18b779b3695 Mon Sep 17 00:00:00 2001
From: Jan Cristina
Date: Fri, 14 Jul 2023 14:11:51 +0200
Subject: [PATCH] accurate-gelu - more documentation, fix tests

---
 src/nn/activations.rs               |  2 +-
 src/tensor_ops/accurate_gelu/mod.rs | 22 ++++++++++++++++++----
 src/tensor_ops/gelu/mod.rs          |  2 +-
 3 files changed, 20 insertions(+), 6 deletions(-)

diff --git a/src/nn/activations.rs b/src/nn/activations.rs
index 17d94e012..dd5f7b369 100644
--- a/src/nn/activations.rs
+++ b/src/nn/activations.rs
@@ -26,7 +26,7 @@ macro_rules! activation_impls {
 
 activation_impls!(ReLU, try_relu, #[doc="Calls [relu()]."]);
 activation_impls!(GeLU, try_gelu, #[doc="Calls [gelu()]."]);
-activation_impls!(AccurateGeLU, try_accurate_gelu, #[doc="Calls [accurate_gelu()]."]);
+activation_impls!(AccurateGeLU, try_accurate_gelu, #[doc="Calls [accurate_gelu()]. The GeLU is defined as x * Phi(x), where Phi is the cumulative distribution function of a standard normal distribution. It is often implemented with a fast tanh approximation (see [GeLU])."]);
 activation_impls!(Sin, try_sin, #[doc="Calls [sin()]."]);
 activation_impls!(Cos, try_cos, #[doc="Calls [cos()]."]);
 activation_impls!(Ln, try_ln, #[doc="Calls [ln()]."]);

diff --git a/src/tensor_ops/accurate_gelu/mod.rs b/src/tensor_ops/accurate_gelu/mod.rs
index eb40ea5de..b1fef19f4 100644
--- a/src/tensor_ops/accurate_gelu/mod.rs
+++ b/src/tensor_ops/accurate_gelu/mod.rs
@@ -10,14 +10,28 @@ use crate::{shapes::*, tensor::*};
 #[derive(Debug, Default, Copy, Clone)]
 pub struct AccurateGeLUKernelOp;
 
-/// [Gaussian Linear Unit (GeLU)](https://paperswithcode.com/method/gelu). `x * Phi(x)`
+/// [Accurate Gaussian Error Linear Unit (GeLU)](https://paperswithcode.com/method/gelu). This is
+/// defined as `x * Phi(x)`, where `Phi(x)` is the cumulative distribution function of a standard
+/// normal distribution. It can be computed exactly via the error function `erf(x)`:
+///
+/// ```text
+/// 0.5 * x * (1.0 + erf(x / sqrt(2.0)))
+/// ```
+///
+/// Since an accurate error function is [computationally expensive](https://en.wikipedia.org/wiki/Error_function#Numerical_approximations),
+/// the GeLU is often approximated with the hyperbolic tangent function `tanh`:
+///
+/// ```text
+/// GeLU(x) ~ 0.5 * x * (1.0 + tanh(sqrt(2.0 / pi) * (x + 0.044715 * x^3)))
+/// ```
+/// See [gelu](crate::tensor_ops::gelu::gelu) to use that approximation.
 ///
 /// Examples:
 /// ```rust
 /// # use dfdx::prelude::*;
 /// # let dev: Cpu = Default::default();
 /// let t = dev.tensor([-1.0, 0.0, 1.0, 2.0]);
-/// let r = t.gelu_correct();
+/// let r = t.accurate_gelu();
 /// ```
 pub fn accurate_gelu<S: Shape, E: Dtype, D: UnaryKernel<AccurateGeLUKernelOp, E>, T: Tape<E, D>>(
     t: Tensor<S, E, D, T>,
@@ -28,11 +42,11 @@ pub fn accurate_gelu<S: Shape, E: Dtype, D: UnaryKernel<AccurateGeLUKernelOp, E
 impl<S: Shape, E: Dtype, D: UnaryKernel<AccurateGeLUKernelOp, E>, T: Tape<E, D>> Tensor<S, E, D, T> {
-    /// See [gelu]
+    /// See [accurate_gelu]
     pub fn accurate_gelu(self) -> Self {
         self.try_accurate_gelu().unwrap()
     }
-    /// See [gelu]
+    /// See [accurate_gelu]
     pub fn try_accurate_gelu(self) -> Result<Self, D::Err> {
         try_unary_op(AccurateGeLUKernelOp, self)
     }
 }
 

diff --git a/src/tensor_ops/gelu/mod.rs b/src/tensor_ops/gelu/mod.rs
index 0abd95e8b..f3b84b41d 100644
--- a/src/tensor_ops/gelu/mod.rs
+++ b/src/tensor_ops/gelu/mod.rs
@@ -10,7 +10,7 @@ use crate::{shapes::*, tensor::*};
 #[derive(Debug, Default, Copy, Clone)]
 pub struct GeLUKernelOp;
 
-/// [Gaussian Linear Unit (GeLU)](https://paperswithcode.com/method/gelu). `0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))`
+/// [Fast Gaussian Error Linear Unit (GeLU)](https://paperswithcode.com/method/gelu). `0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))`
 ///
 /// Examples:
 /// ```rust
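A note for reviewers, not part of the patch: below is a minimal standalone sketch (plain Rust, no dfdx dependency; the free functions `accurate_gelu`, `fast_gelu`, and `erf` are named for clarity and are not dfdx's API) that checks the two formulas documented above against each other. Rust's standard library has no `erf`, so the sketch substitutes the Abramowitz & Stegun 7.1.26 polynomial approximation, whose maximum absolute error of about 1.5e-7 is small enough for this comparison.

```rust
use std::f64::consts::PI;

/// Abramowitz & Stegun 7.1.26 polynomial approximation of erf(x);
/// a stand-in for a real erf implementation, max absolute error ~1.5e-7.
fn erf(x: f64) -> f64 {
    let sign = if x < 0.0 { -1.0 } else { 1.0 };
    let x = x.abs();
    let t = 1.0 / (1.0 + 0.3275911 * x);
    // Horner evaluation of the degree-5 polynomial in t.
    let poly = ((((1.061405429 * t - 1.453152027) * t + 1.421413741) * t - 0.284496736) * t
        + 0.254829592)
        * t;
    sign * (1.0 - poly * (-x * x).exp())
}

/// The erf-based form documented for accurate_gelu: x * Phi(x).
fn accurate_gelu(x: f64) -> f64 {
    0.5 * x * (1.0 + erf(x / 2.0_f64.sqrt()))
}

/// The tanh approximation documented for gelu.
fn fast_gelu(x: f64) -> f64 {
    0.5 * x * (1.0 + ((2.0 / PI).sqrt() * (x + 0.044715 * x.powi(3))).tanh())
}

fn main() {
    // Same sample inputs as the doc example in the patch.
    for x in [-1.0, 0.0, 1.0, 2.0] {
        println!(
            "x = {x:>4}: accurate = {:.6}, fast = {:.6}, diff = {:+.3e}",
            accurate_gelu(x),
            fast_gelu(x),
            accurate_gelu(x) - fast_gelu(x)
        );
    }
}
```

On these inputs the two forms agree to roughly 1e-4, which illustrates the trade-off the new docs describe: the patch keeps the cheap tanh form as `gelu` and exposes the erf-based form separately as `accurate_gelu`.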