
Commit
Merge branch 'silu' into new-base
swfsql committed Mar 1, 2024
2 parents ae56771 + bc569c7 commit cd33ed7
Showing 7 changed files with 160 additions and 0 deletions.
2 changes: 2 additions & 0 deletions dfdx-core/src/tensor_ops/mod.rs
@@ -198,6 +198,7 @@ mod roll;
mod select_and_gather;
mod sgd;
mod sigmoid;
mod silu;
mod sin;
mod slice;
mod softmax;
@@ -268,6 +269,7 @@ pub use roll::Roll;
pub use select_and_gather::{GatherTo, SelectTo};
pub use sgd::SgdConfig;
pub use sigmoid::sigmoid;
pub use silu::silu;
pub use sin::sin;
pub use slice::slice;
pub use softmax::softmax;
20 changes: 20 additions & 0 deletions dfdx-core/src/tensor_ops/silu/cpu_kernel.rs
@@ -0,0 +1,20 @@
use crate::tensor_ops::cpu_kernels::UnaryDerivative;

impl<F: num_traits::Float> UnaryDerivative<F> for super::SiLUKernelOp {
    const DF_USES_FX: bool = false;
    const HAS_CONST_DF: bool = false;

    // x / (1 + e^-x)
    #[inline(always)]
    fn f(&self, x: &F) -> F {
        *x / (F::one() + x.neg().exp())
    }

    // (1 + e^-x + x * e^-x) / (1 + e^-x)^2
    // alternative: (e^x (1 + e^x + x)) / (1 + e^x)^2
    #[inline(always)]
    fn df(&self, x: &F) -> F {
        let exp_nx = x.neg().exp();
        (F::one() + exp_nx + *x * exp_nx) / (F::one() + exp_nx).powi(2)
    }
}
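
Note: as a sanity check on the closed-form `df` above, the following standalone sketch (plain Rust, not part of this commit) compares it against a central finite difference; the sample points, step size, and tolerance are arbitrary choices.

// Standalone sanity check (hypothetical, not in this commit): compare the
// closed-form derivative with a central finite difference.
fn silu(x: f64) -> f64 {
    x / (1.0 + (-x).exp())
}

fn silu_df(x: f64) -> f64 {
    let exp_nx = (-x).exp();
    (1.0 + exp_nx + x * exp_nx) / (1.0 + exp_nx).powi(2)
}

fn main() {
    let h = 1e-5;
    for &x in &[-2.0, -0.5, 0.0, 0.5, 2.0] {
        let numeric = (silu(x + h) - silu(x - h)) / (2.0 * h);
        assert!((numeric - silu_df(x)).abs() < 1e-8);
    }
}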
15 changes: 15 additions & 0 deletions dfdx-core/src/tensor_ops/silu/cuda_kernel.rs
@@ -0,0 +1,15 @@
use super::SiLUKernelOp;
#[allow(unused_imports)]
use crate::dtypes::*;
use crate::tensor_ops::cuda_kernels::cuda_unary;

unsafe impl cudarc::driver::DeviceRepr for SiLUKernelOp {}

const PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/silu.ptx"));

#[cfg(feature = "f16")]
cuda_unary!(SiLUKernelOp, f16, PTX, "silu_fwd_f16", "silu_bwd_f16");
#[cfg(feature = "f16")]
cuda_unary!(SiLUKernelOp, AMP<f16>, PTX, "silu_fwd_f16", "silu_bwd_f16");
cuda_unary!(SiLUKernelOp, f32, PTX, "silu_fwd_f32", "silu_bwd_f32");
cuda_unary!(SiLUKernelOp, f64, PTX, "silu_fwd_f64", "silu_bwd_f64");
62 changes: 62 additions & 0 deletions dfdx-core/src/tensor_ops/silu/mod.rs
@@ -0,0 +1,62 @@
mod cpu_kernel;

#[cfg(feature = "cuda")]
mod cuda_kernel;

#[cfg(feature = "webgpu")]
mod webgpu_kernel;

use super::ops::{try_unary_op, UnaryKernel};
use crate::{shapes::*, tensor::*};

#[repr(C)]
#[derive(Debug, Default, Copy, Clone)]
pub struct SiLUKernelOp;

/// [Sigmoid-Weighted Linear Unit (SiLU)](https://en.wikipedia.org/wiki/Rectifier_(neural_networks)). `x * x.sigmoid()`
///
/// The derivative is `x * sigmoid'(x) + sigmoid(x)`.
///
/// Examples:
/// ```rust
/// # use dfdx_core::prelude::*;
/// # let dev: Cpu = Default::default();
/// let t = dev.tensor([-1.0, 0.0, 1.0, 2.0]);
/// let r = t.silu();
/// ```
pub fn silu<S: Shape, E: Dtype, D: UnaryKernel<SiLUKernelOp, E>, T: Tape<E, D>>(
    t: Tensor<S, E, D, T>,
) -> Tensor<S, E, D, T> {
    t.silu()
}

impl<S: Shape, E: Dtype, D: UnaryKernel<SiLUKernelOp, E>, T: Tape<E, D>> Tensor<S, E, D, T> {
    /// See [silu]
    pub fn silu(self) -> Self {
        self.try_silu().unwrap()
    }

    /// See [silu]
    pub fn try_silu(self) -> Result<Self, crate::tensor::Error> {
        try_unary_op(SiLUKernelOp, self)
    }
}

#[cfg(test)]
mod tests {
    use crate::{tensor::*, tensor_ops::*, tests::*};

    #[test]
    fn test_silu() {
        let dev: TestDevice = Default::default();
        let x = dev
            .tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
            .to_dtype::<TestDtype>();
        let r = x.leaky_trace().silu();
        assert_close_to_literal!(r, [-0.23840584, -0.26894143, 0.0, 0.7310586, 1.761594]);
        let g = r.mean().backward();
        assert_close_to_literal!(
            g.get(&x),
            [-0.018156849, 0.014465898, 0.1, 0.1855341, 0.21815684]
        );
    }
}
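
Note: the doc comment defines SiLU as `x * x.sigmoid()` while the kernels compute `x / (1 + e^-x)`; the two agree because `sigmoid(x) = 1 / (1 + e^-x)`. A quick standalone check (plain Rust, illustrative only):

fn sigmoid(x: f64) -> f64 {
    1.0 / (1.0 + (-x).exp())
}

fn main() {
    for &x in &[-2.0, -1.0, 0.0, 1.0, 2.0] {
        let silu = x / (1.0 + (-x).exp());
        // x * sigmoid(x) and x / (1 + e^-x) are the same function.
        assert!((silu - x * sigmoid(x)).abs() < 1e-12);
    }
}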
32 changes: 32 additions & 0 deletions dfdx-core/src/tensor_ops/silu/silu.cu
@@ -0,0 +1,32 @@
#include "unary_op_macros.cuh"

struct SiLUKernelOp {};

// x / (1 + e^-x)
template<typename T>
__device__ __forceinline__ T silu_fwd(T x) {
    T one = 1.0;
    return x / (one + expg(-x));
}

// (1 + e^-x + x * e^-x) / (1 + e^-x)^2
// alternative: (e^x (1 + e^x + x)) / (1 + e^x)^2
template<typename T>
__device__ __forceinline__ T silu_bwd(T x) {
    T one = 1.0;
    T exp_nx = expg(-x);
    T denom_sqrt = (one + exp_nx);
    return (one + exp_nx + x * exp_nx) / (denom_sqrt * denom_sqrt);
}

UNARY_OP(__half, silu_fwd_f16, silu_bwd_f16, SiLUKernelOp,
    silu_fwd(x),
    silu_bwd(x))

UNARY_OP(float, silu_fwd_f32, silu_bwd_f32, SiLUKernelOp,
    silu_fwd(x),
    silu_bwd(x))

UNARY_OP(double, silu_fwd_f64, silu_bwd_f64, SiLUKernelOp,
    silu_fwd(x),
    silu_bwd(x))
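
For reference, the backward formula used by both the CPU and CUDA kernels follows from the quotient rule applied to silu(x) = x / (1 + e^-x):

    d/dx [ x / (1 + e^-x) ] = [ (1 + e^-x) - x * (-e^-x) ] / (1 + e^-x)^2
                            = (1 + e^-x + x * e^-x) / (1 + e^-x)^2

Multiplying the numerator and denominator by e^(2x) yields the alternative form noted in the comments, (e^x * (1 + e^x + x)) / (1 + e^x)^2.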
28 changes: 28 additions & 0 deletions dfdx-core/src/tensor_ops/silu/webgpu_kernel.rs
@@ -0,0 +1,28 @@
use std::borrow::Cow;

use crate::prelude::{ops::UnaryKernel, Dtype, Webgpu};

impl<E: Dtype> UnaryKernel<super::SiLUKernelOp, E> for Webgpu {
    const BACKWARD_WITHOUT_INP: bool = false;

    const BACKWARD_WITHOUT_DATA: bool = false;

    fn forward<S: crate::prelude::Shape>(
        &self,
        op: super::SiLUKernelOp,
        inp: Cow<crate::prelude::Tensor<S, E, Self>>,
    ) -> Result<crate::prelude::Tensor<S, E, Self>, crate::prelude::Error> {
        todo!()
    }

    fn backward<S: crate::prelude::Shape>(
        &self,
        op: super::SiLUKernelOp,
        inp: &impl crate::prelude::Tensorlike<S, E, Self>,
        grad_inp: &mut Self::Vec,
        out: &impl crate::prelude::Tensorlike<S, E, Self>,
        grad_out: &Self::Vec,
    ) -> Result<(), crate::prelude::Error> {
        todo!()
    }
}
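
Both WebGPU methods are stubbed with `todo!()`. A plain-Rust host-side reference such as the following (hypothetical, not part of this commit; the name and signature are illustrative) could serve as the expected output when a WebGPU kernel is eventually implemented:

// Hypothetical host-side reference for validating a future WebGPU kernel.
fn silu_reference(inp: &[f32]) -> Vec<f32> {
    inp.iter().map(|&x| x / (1.0 + (-x).exp())).collect()
}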
1 change: 1 addition & 0 deletions dfdx-core/src/tensor_ops/utilities/device.rs
@@ -95,6 +95,7 @@ pub trait Device<E: Dtype>:
    + UnaryKernel<super::super::fast_gelu::FastGeLUKernelOp, E>
    + UnaryKernel<super::super::accurate_gelu::AccurateGeLUKernelOp, E>
    + UnaryKernel<super::super::sigmoid::SigmoidKernelOp, E>
    + UnaryKernel<super::super::silu::SiLUKernelOp, E>
    + UnaryKernel<super::super::sin::SinKernelOp, E>
    + UnaryKernel<super::super::sqrt::SqrtKernelOp, E>
    + UnaryKernel<super::super::square::SquareKernelOp, E>