accurate-gelu - more documentation, fix tests
jcrist1 committed Jul 14, 2023
1 parent b42b5e3 commit 8fbf249
Showing 3 changed files with 20 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/nn/activations.rs
@@ -26,7 +26,7 @@ macro_rules! activation_impls {

activation_impls!(ReLU, try_relu, #[doc="Calls [relu()]."]);
activation_impls!(GeLU, try_gelu, #[doc="Calls [gelu()]."]);
- activation_impls!(AccurateGeLU, try_accurate_gelu, #[doc="Calls [accurate_gelu()]."]);
+ activation_impls!(AccurateGeLU, try_accurate_gelu, #[doc="Calls [accurate_gelu()]. The GeLU is defined as x * Phi(x), where Phi is the cumulative distribution function of a standard normal distribution. It is often implemented with a fast tanh approximation (see [GeLU])."]);
activation_impls!(Sin, try_sin, #[doc="Calls [sin()]."]);
activation_impls!(Cos, try_cos, #[doc="Calls [cos()]."]);
activation_impls!(Ln, try_ln, #[doc="Calls [ln()]."]);
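The hunk above registers `AccurateGeLU` as an activation module alongside `ReLU`, `GeLU`, and the rest, so it can be dropped into a network definition like any other layer. Below is a minimal sketch of that usage, assuming dfdx's builder-style API around this commit; the `Mlp` alias, layer sizes, and random input are illustrative, not part of the commit.

```rust
use dfdx::prelude::*;

fn main() {
    let dev: Cpu = Default::default();

    // Hypothetical two-layer MLP using the new AccurateGeLU activation
    // module in place of the tanh-approximated GeLU.
    type Mlp = (Linear<4, 8>, AccurateGeLU, Linear<8, 2>);

    let model = dev.build_module::<Mlp, f32>();
    let x: Tensor<Rank1<4>, f32, _> = dev.sample_normal();
    let y = model.forward(x);
    println!("{:?}", y.array());
}
```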
22 changes: 18 additions & 4 deletions src/tensor_ops/accurate_gelu/mod.rs
@@ -10,14 +10,28 @@ use crate::{shapes::*, tensor::*};
#[derive(Debug, Default, Copy, Clone)]
pub struct AccurateGeLUKernelOp;

- /// [Gaussian Linear Unit (GeLU)](https://paperswithcode.com/method/gelu). `x * Phi(x)`
+ /// [Accurate Gaussian Linear Unit (GeLU)](https://paperswithcode.com/method/gelu). This is defined as `x * Phi(x)`, where `Phi(x)` is the cumulative
+ /// distribution function of a standard normal distribution. It can be computed via the error
+ /// function `erf(x)` as
+ /// ```text
+ /// 0.5 * x * (1.0 + erf(x / sqrt(2.0)))
+ /// ```
+ /// Because an accurate error function is [computationally expensive](https://en.wikipedia.org/wiki/Error_function#Numerical_approximations), it is
+ /// common to approximate the Gaussian Linear Unit with a hyperbolic tangent function `tanh`:
+ ///
+ /// ```text
+ /// GeLU(x) ≈ 0.5 * x * (1.0 + tanh(sqrt(2.0 / π) * (x + 0.044715 * x^3)))
+ /// ```
+ ///
+ /// See [gelu](crate::tensor_ops::gelu::gelu) to use this approximation.
+ ///
+ ///
/// Examples:
/// ```rust
/// # use dfdx::prelude::*;
/// # let dev: Cpu = Default::default();
/// let t = dev.tensor([-1.0, 0.0, 1.0, 2.0]);
- /// let r = t.gelu_correct();
+ /// let r = t.accurate_gelu();
/// ```
pub fn accurate_gelu<S: Shape, E: Dtype, D: UnaryKernel<AccurateGeLUKernelOp, E>, T: Tape<E, D>>(
t: Tensor<S, E, D, T>,
@@ -28,11 +28,11 @@ pub fn accurate_gelu<S: Shape, E: Dtype, D: UnaryKernel<AccurateGeLUKernelOp, E>
impl<S: Shape, E: Dtype, D: UnaryKernel<AccurateGeLUKernelOp, E>, T: Tape<E, D>>
Tensor<S, E, D, T>
{
- /// See [gelu]
+ /// See [accurate_gelu]
pub fn accurate_gelu(self) -> Self {
self.try_accurate_gelu().unwrap()
}
- /// See [gelu]
+ /// See [accurate_gelu]
pub fn try_accurate_gelu(self) -> Result<Self, D::Err> {
try_unary_op(AccurateGeLUKernelOp, self)
}
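The doc comment added above contrasts the exact erf-based definition with the tanh approximation. The standalone sketch below (plain Rust, independent of dfdx) makes that contrast concrete. The `erf` polynomial is the Abramowitz and Stegun 7.1.26 approximation, used here only because Rust's standard library has no `erf`; dfdx's kernels may compute the error function differently.

```rust
use std::f64::consts::PI;

// Abramowitz & Stegun 7.1.26 approximation of erf (max error ~1.5e-7),
// extended to negative inputs via erf(-x) = -erf(x).
fn erf(x: f64) -> f64 {
    let sign = if x < 0.0 { -1.0 } else { 1.0 };
    let x = x.abs();
    let t = 1.0 / (1.0 + 0.3275911 * x);
    let poly = t
        * (0.254829592
            + t * (-0.284496736
                + t * (1.421413741 + t * (-1.453152027 + t * 1.061405429))));
    sign * (1.0 - poly * (-x * x).exp())
}

/// Exact definition: x * Phi(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
fn accurate_gelu(x: f64) -> f64 {
    0.5 * x * (1.0 + erf(x / 2.0_f64.sqrt()))
}

/// Fast tanh approximation, matching the formula in the docs above.
fn fast_gelu(x: f64) -> f64 {
    let c = (2.0 / PI).sqrt();
    0.5 * x * (1.0 + (c * (x + 0.044715 * x.powi(3))).tanh())
}

fn main() {
    for x in [-2.0, -1.0, 0.0, 1.0, 2.0] {
        println!("x = {x:+.1}  exact = {:+.6}  fast = {:+.6}", accurate_gelu(x), fast_gelu(x));
    }
}
```

The two agree closely over typical activation ranges, which is why the tanh version is the common default.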
2 changes: 1 addition & 1 deletion src/tensor_ops/gelu/mod.rs
@@ -10,7 +10,7 @@ use crate::{shapes::*, tensor::*};
#[derive(Debug, Default, Copy, Clone)]
pub struct GeLUKernelOp;

- /// [Gaussian Linear Unit (GeLU)](https://paperswithcode.com/method/gelu). `0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))`
+ /// [Fast Gaussian Linear Unit (GeLU)](https://paperswithcode.com/method/gelu). `0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))`
///
/// Examples:
/// ```rust
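For a quick side-by-side of the two tensor ops, here is a minimal sketch assuming `gelu()`, `accurate_gelu()`, and `array()` behave as in the crate's doc examples; the input values are illustrative.

```rust
use dfdx::prelude::*;

fn main() {
    let dev: Cpu = Default::default();
    let t = dev.tensor([-1.0f32, 0.0, 1.0, 2.0]);

    // Fast tanh-based approximation defined in this file...
    let fast = t.clone().gelu();
    // ...versus the erf-based op from src/tensor_ops/accurate_gelu.
    let exact = t.accurate_gelu();

    println!("fast:  {:?}", fast.array());
    println!("exact: {:?}", exact.array());
}
```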
