From a14b40b6e3a214c716c83a5fbc244c208a767def Mon Sep 17 00:00:00 2001
From: Thiago Machado
Date: Wed, 31 Jan 2024 01:10:46 -0500
Subject: [PATCH] add SiLU activation function

---
 dfdx-core/src/tensor_ops/mod.rs              |  2 +
 dfdx-core/src/tensor_ops/silu/cpu_kernel.rs  | 20 ++++++
 dfdx-core/src/tensor_ops/silu/cuda_kernel.rs | 15 +++++
 dfdx-core/src/tensor_ops/silu/mod.rs         | 62 +++++++++++++++++++
 dfdx-core/src/tensor_ops/silu/silu.cu        | 32 ++++++++++
 .../src/tensor_ops/silu/webgpu_kernel.rs     | 28 +++++++++
 dfdx-core/src/tensor_ops/utilities/device.rs |  1 +
 7 files changed, 160 insertions(+)
 create mode 100644 dfdx-core/src/tensor_ops/silu/cpu_kernel.rs
 create mode 100644 dfdx-core/src/tensor_ops/silu/cuda_kernel.rs
 create mode 100644 dfdx-core/src/tensor_ops/silu/mod.rs
 create mode 100644 dfdx-core/src/tensor_ops/silu/silu.cu
 create mode 100644 dfdx-core/src/tensor_ops/silu/webgpu_kernel.rs

diff --git a/dfdx-core/src/tensor_ops/mod.rs b/dfdx-core/src/tensor_ops/mod.rs
index 453457f4..1cb0f38c 100644
--- a/dfdx-core/src/tensor_ops/mod.rs
+++ b/dfdx-core/src/tensor_ops/mod.rs
@@ -197,6 +197,7 @@ mod roll;
 mod select_and_gather;
 mod sgd;
 mod sigmoid;
+mod silu;
 mod sin;
 mod slice;
 mod softmax;
@@ -264,6 +265,7 @@ pub use roll::Roll;
 pub use select_and_gather::{GatherTo, SelectTo};
 pub use sgd::SgdConfig;
 pub use sigmoid::sigmoid;
+pub use silu::silu;
 pub use sin::sin;
 pub use slice::slice;
 pub use softmax::softmax;
diff --git a/dfdx-core/src/tensor_ops/silu/cpu_kernel.rs b/dfdx-core/src/tensor_ops/silu/cpu_kernel.rs
new file mode 100644
index 00000000..f6f05752
--- /dev/null
+++ b/dfdx-core/src/tensor_ops/silu/cpu_kernel.rs
@@ -0,0 +1,20 @@
+use crate::tensor_ops::cpu_kernels::UnaryDerivative;
+
+impl<F: num_traits::Float> UnaryDerivative<F> for super::SiLUKernelOp {
+    const DF_USES_FX: bool = false;
+    const HAS_CONST_DF: bool = false;
+
+    // x / (1 + e^-x)
+    #[inline(always)]
+    fn f(&self, x: &F) -> F {
+        *x / (F::one() + x.neg().exp())
+    }
+
+    // (1 + e^-x + x * e^-x) / (1 + e^-x)^2
+    // alternative: (e^x (1 + e^x + x)) / (1 + e^x)^2
+    #[inline(always)]
+    fn df(&self, x: &F) -> F {
+        let exp_nx = x.neg().exp();
+        (F::one() + exp_nx + *x * exp_nx) / (F::one() + exp_nx).powi(2)
+    }
+}
diff --git a/dfdx-core/src/tensor_ops/silu/cuda_kernel.rs b/dfdx-core/src/tensor_ops/silu/cuda_kernel.rs
new file mode 100644
index 00000000..45bf1385
--- /dev/null
+++ b/dfdx-core/src/tensor_ops/silu/cuda_kernel.rs
@@ -0,0 +1,15 @@
+use super::SiLUKernelOp;
+#[allow(unused_imports)]
+use crate::dtypes::*;
+use crate::tensor_ops::cuda_kernels::cuda_unary;
+
+unsafe impl cudarc::driver::DeviceRepr for SiLUKernelOp {}
+
+const PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/silu.ptx"));
+
+#[cfg(feature = "f16")]
+cuda_unary!(SiLUKernelOp, f16, PTX, "silu_fwd_f16", "silu_bwd_f16");
+#[cfg(feature = "f16")]
+cuda_unary!(SiLUKernelOp, AMP<f16>, PTX, "silu_fwd_f16", "silu_bwd_f16");
+cuda_unary!(SiLUKernelOp, f32, PTX, "silu_fwd_f32", "silu_bwd_f32");
+cuda_unary!(SiLUKernelOp, f64, PTX, "silu_fwd_f64", "silu_bwd_f64");
diff --git a/dfdx-core/src/tensor_ops/silu/mod.rs b/dfdx-core/src/tensor_ops/silu/mod.rs
new file mode 100644
index 00000000..97bcce10
--- /dev/null
+++ b/dfdx-core/src/tensor_ops/silu/mod.rs
@@ -0,0 +1,62 @@
+mod cpu_kernel;
+
+#[cfg(feature = "cuda")]
+mod cuda_kernel;
+
+#[cfg(feature = "webgpu")]
+mod webgpu_kernel;
+
+use super::ops::{try_unary_op, UnaryKernel};
+use crate::{shapes::*, tensor::*};
+
+#[repr(C)]
+#[derive(Debug, Default, Copy, Clone)]
+pub struct SiLUKernelOp;
+
+/// [Sigmoid-Weighted Linear Unit (SiLU)](https://en.wikipedia.org/wiki/Rectifier_(neural_networks)). `x * x.sigmoid()`
+///
+/// The derivative is `x * sigmoid'(x) + sigmoid(x)`.
+///
+/// Examples:
+/// ```rust
+/// # use dfdx_core::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// let t = dev.tensor([-1.0, 0.0, 1.0, 2.0]);
+/// let r = t.silu();
+/// ```
+pub fn silu<S: Shape, E: Dtype, D: UnaryKernel<SiLUKernelOp, E>, T: Tape<E, D>>(
+    t: Tensor<S, E, D, T>,
+) -> Tensor<S, E, D, T> {
+    t.silu()
+}
+
+impl<S: Shape, E: Dtype, D: UnaryKernel<SiLUKernelOp, E>, T: Tape<E, D>> Tensor<S, E, D, T> {
+    /// See [silu]
+    pub fn silu(self) -> Self {
+        self.try_silu().unwrap()
+    }
+    /// See [silu]
+    pub fn try_silu(self) -> Result<Self, crate::tensor::Error> {
+        try_unary_op(SiLUKernelOp, self)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::{tensor::*, tensor_ops::*, tests::*};
+
+    #[test]
+    fn test_silu() {
+        let dev: TestDevice = Default::default();
+        let x = dev
+            .tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
+            .to_dtype::<TestDtype>();
+        let r = x.leaky_trace().silu();
+        assert_close_to_literal!(r, [-0.23840584, -0.26894143, 0.0, 0.7310586, 1.761594]);
+        let g = r.mean().backward();
+        assert_close_to_literal!(
+            g.get(&x),
+            [-0.01815685, 0.0144659, 0.1, 0.1855341, 0.2181568]
+        );
+    }
+}
diff --git a/dfdx-core/src/tensor_ops/silu/silu.cu b/dfdx-core/src/tensor_ops/silu/silu.cu
new file mode 100644
index 00000000..d3b01a7e
--- /dev/null
+++ b/dfdx-core/src/tensor_ops/silu/silu.cu
@@ -0,0 +1,32 @@
+#include "unary_op_macros.cuh"
+
+struct SiLUKernelOp {};
+
+// x / (1 + e^-x)
+template <typename T>
+__device__ __forceinline__ T silu_fwd(T x) {
+    T one = 1.0;
+    return x / (one + expg(-x));
+}
+
+// (1 + e^-x + x * e^-x) / (1 + e^-x)^2
+// alternative: (e^x (1 + e^x + x)) / (1 + e^x)^2
+template <typename T>
+__device__ __forceinline__ T silu_bwd(T x) {
+    T one = 1.0;
+    T exp_nx = expg(-x);
+    T denom_sqrt = (one + exp_nx);
+    return (one + exp_nx + x * exp_nx) / (denom_sqrt * denom_sqrt);
+}
+
+UNARY_OP(__half, silu_fwd_f16, silu_bwd_f16, SiLUKernelOp,
+        silu_fwd(x),
+        silu_bwd(x))
+
+UNARY_OP(float, silu_fwd_f32, silu_bwd_f32, SiLUKernelOp,
+        silu_fwd(x),
+        silu_bwd(x))
+
+UNARY_OP(double, silu_fwd_f64, silu_bwd_f64, SiLUKernelOp,
+        silu_fwd(x),
+        silu_bwd(x))
diff --git a/dfdx-core/src/tensor_ops/silu/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/silu/webgpu_kernel.rs
new file mode 100644
index 00000000..438850e1
--- /dev/null
+++ b/dfdx-core/src/tensor_ops/silu/webgpu_kernel.rs
@@ -0,0 +1,28 @@
+use std::borrow::Cow;
+
+use crate::prelude::{ops::UnaryKernel, Dtype, Webgpu};
+
+impl<E: Dtype> UnaryKernel<super::SiLUKernelOp, E> for Webgpu {
+    const BACKWARD_WITHOUT_INP: bool = false;
+
+    const BACKWARD_WITHOUT_DATA: bool = false;
+
+    fn forward<S: crate::prelude::Shape>(
+        &self,
+        op: super::SiLUKernelOp,
+        inp: Cow<crate::prelude::Tensor<S, E, Self>>,
+    ) -> Result<crate::prelude::Tensor<S, E, Self>, crate::prelude::Error> {
+        todo!()
+    }
+
+    fn backward<S: crate::prelude::Shape>(
+        &self,
+        op: super::SiLUKernelOp,
+        inp: &impl crate::prelude::Tensorlike<S, E, Self>,
+        grad_inp: &mut Self::Vec,
+        out: &impl crate::prelude::Tensorlike<S, E, Self>,
+        grad_out: &Self::Vec,
+    ) -> Result<(), crate::prelude::Error> {
+        todo!()
+    }
+}
diff --git a/dfdx-core/src/tensor_ops/utilities/device.rs b/dfdx-core/src/tensor_ops/utilities/device.rs
index 8cbc2137..6e9b6ec4 100644
--- a/dfdx-core/src/tensor_ops/utilities/device.rs
+++ b/dfdx-core/src/tensor_ops/utilities/device.rs
@@ -92,6 +92,7 @@ pub trait Device<E: Dtype>:
     + UnaryKernel<super::super::recip::RecipKernelOp, E>
     + UnaryKernel<super::super::relu::ReLUKernelOp, E>
     + UnaryKernel<super::super::sigmoid::SigmoidKernelOp, E>
+    + UnaryKernel<super::super::silu::SiLUKernelOp, E>
     + UnaryKernel<super::super::sin::SinKernelOp, E>
     + UnaryKernel<super::super::sqrt::SqrtKernelOp, E>
     + UnaryKernel<super::super::square::SquareKernelOp, E>
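Note on the backward formula: the closed-form derivative `(1 + e^-x + x * e^-x) / (1 + e^-x)^2` shared by the CPU and CUDA kernels is easy to mis-parenthesize, so it can be sanity-checked against a central finite difference. The snippet below is a minimal standalone sketch, not part of the patch: plain Rust with no dfdx dependency, where `silu` and `silu_df` are local helper names mirroring the `f`/`df` implementations above.

```rust
// Standalone check that the analytic SiLU derivative matches a
// numerical derivative at the same points used in `test_silu`.

/// Forward: x / (1 + e^-x)
fn silu(x: f64) -> f64 {
    x / (1.0 + (-x).exp())
}

/// Backward: (1 + e^-x + x * e^-x) / (1 + e^-x)^2
fn silu_df(x: f64) -> f64 {
    let exp_nx = (-x).exp();
    (1.0 + exp_nx + x * exp_nx) / (1.0 + exp_nx).powi(2)
}

fn main() {
    let h = 1e-6;
    for x in [-2.0, -1.0, 0.0, 1.0, 2.0] {
        // Central finite difference as the reference derivative.
        let fd = (silu(x + h) - silu(x - h)) / (2.0 * h);
        assert!(
            (silu_df(x) - fd).abs() < 1e-6,
            "derivative mismatch at x = {x}: analytic {} vs numeric {fd}",
            silu_df(x)
        );
    }
    println!("analytic SiLU derivative matches finite differences");
}
```

At `x = 0` the derivative is exactly `(1 + 1 + 0) / 2^2 = 0.5`; after `mean()` over the 5-element tensor this becomes the `0.1` gradient entry asserted in `test_silu`.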