From d201f4ab5fdcc79ee41f93f39e51995b070f3085 Mon Sep 17 00:00:00 2001
From: Jan Cristina
Date: Fri, 14 Jul 2023 14:15:54 +0200
Subject: [PATCH] accurate-gelu - describe corresponding pytorch algos for gelus

---
 src/nn/activations.rs | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/src/nn/activations.rs b/src/nn/activations.rs
index dd5f7b369..b034da9ee 100644
--- a/src/nn/activations.rs
+++ b/src/nn/activations.rs
@@ -25,8 +25,12 @@ macro_rules! activation_impls {
 }
 
 activation_impls!(ReLU, try_relu, #[doc="Calls [relu()]."]);
-activation_impls!(GeLU, try_gelu, #[doc="Calls [gelu()]."]);
-activation_impls!(AccurateGeLU, try_accurate_gelu, #[doc="Calls [accurate_gelu()]. The GeLU is defined as x * Phi(x) where Phi is the cumulative distribution function of a standard Normal Distribution. It is often implemented with a fast approximation using tanh (see [GeLU])"]);
+activation_impls!(GeLU, try_gelu, #[doc="Calls [gelu()]. This corresponds to `torch.nn.GELU(approximate='tanh')` in pytorch."]);
+activation_impls!(
+    AccurateGeLU,
+    try_accurate_gelu,
+    #[doc=r#"Calls [accurate_gelu()]. The GeLU is defined as x * Phi(x), where Phi is the cumulative distribution function of a standard Normal Distribution.
+It is often implemented with a fast approximation using tanh (see [GeLU]). This corresponds to `torch.nn.GELU(approximate='none')` in pytorch."#]);
 activation_impls!(Sin, try_sin, #[doc="Calls [sin()]."]);
 activation_impls!(Cos, try_cos, #[doc="Calls [cos()]."]);
 activation_impls!(Ln, try_ln, #[doc="Calls [ln()]."]);
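
For reference, below is a minimal standalone sketch (not part of the patch, and not dfdx's actual kernels) of the two formulas the doc comments describe: the exact GeLU x * Phi(x) that AccurateGeLU / torch.nn.GELU(approximate='none') compute, and the tanh approximation used by GeLU / torch.nn.GELU(approximate='tanh'). The function names and the erf approximation are illustrative assumptions; std Rust has no built-in erf, so an Abramowitz-Stegun approximation is used here.

// Exact GeLU: x * Phi(x), with Phi the standard normal CDF = 0.5 * (1 + erf(x / sqrt(2))).
fn gelu_exact(x: f64) -> f64 {
    0.5 * x * (1.0 + erf(x / std::f64::consts::SQRT_2))
}

// Tanh-approximated GeLU: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
fn gelu_tanh(x: f64) -> f64 {
    let c = (2.0 / std::f64::consts::PI).sqrt();
    0.5 * x * (1.0 + (c * (x + 0.044715 * x.powi(3))).tanh())
}

// Abramowitz-Stegun 7.1.26 approximation of erf, accurate to roughly 1e-7.
fn erf(x: f64) -> f64 {
    let sign = if x < 0.0 { -1.0 } else { 1.0 };
    let x = x.abs();
    let t = 1.0 / (1.0 + 0.3275911 * x);
    let poly = ((((1.061405429 * t - 1.453152027) * t + 1.421413741) * t - 0.284496736) * t
        + 0.254829592)
        * t;
    sign * (1.0 - poly * (-x * x).exp())
}

fn main() {
    // The two variants agree closely but not exactly; the tanh form is a fast approximation.
    for &x in &[-2.0, -0.5, 0.0, 0.5, 2.0] {
        println!(
            "x = {x:>4}: exact = {:.6}, tanh approx = {:.6}",
            gelu_exact(x),
            gelu_tanh(x)
        );
    }
}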