Sigmoid with residue #869

Closed

45 commits
f1ce9b5
PReLU forward works-ish
opfromthestart Mar 19, 2023
068183c
Doctest works now
opfromthestart Mar 19, 2023
ef62baa
Merge branch 'coreylowman:main' into main
opfromthestart Mar 19, 2023
f2fe0c4
Better-ish implementation
opfromthestart Mar 20, 2023
94f7a20
Merge branch 'main' of github.com:opfromthestart/dfdx into main
opfromthestart Mar 20, 2023
f3ef318
Now has actual layers to use
opfromthestart Mar 20, 2023
b132fcb
fmt, clippy, tests
opfromthestart Mar 20, 2023
a42e0d2
Cleaning
opfromthestart Mar 20, 2023
c248c7e
Remove one unneeded generic
opfromthestart Mar 20, 2023
4caa8d8
Cuda maybe working (idk)
opfromthestart Mar 20, 2023
d9c1346
LeakyReLU now should work
opfromthestart Mar 21, 2023
641f429
Fix test
opfromthestart Mar 21, 2023
feeedff
Conforms to standards
opfromthestart Mar 26, 2023
5378860
Some requested changes
opfromthestart Mar 28, 2023
530c5ff
Formatting
opfromthestart Mar 28, 2023
0604b57
Clippy
opfromthestart Mar 28, 2023
252ae49
another fmt
opfromthestart Mar 28, 2023
fbdb13b
No prelu kernel
opfromthestart Mar 31, 2023
1a8fc60
Fmt
opfromthestart Mar 31, 2023
56dca8b
Explicit zero
opfromthestart Mar 31, 2023
a885559
fmt
opfromthestart Mar 31, 2023
066bcea
Merge branch 'coreylowman:main' into main
opfromthestart Mar 31, 2023
f136a82
Actually exports needed things
opfromthestart Apr 1, 2023
c7caf43
Fixes
opfromthestart Apr 1, 2023
f4660eb
fmt
opfromthestart Apr 1, 2023
b51a12a
Fix nightly feature
opfromthestart Apr 1, 2023
6487470
Separate into own file
opfromthestart Apr 1, 2023
b4faf1c
Better error message
opfromthestart Apr 3, 2023
04b215d
Merge branch 'main' into main
opfromthestart Apr 3, 2023
8fde4e4
Actual better merge
opfromthestart Apr 3, 2023
c73115b
fmt
opfromthestart Apr 3, 2023
2b51b6e
Merge branch 'main' into main
coreylowman Apr 4, 2023
4a4354a
Merge branch 'main' of github.com:opfromthestart/dfdx into main
opfromthestart Apr 4, 2023
7d58f15
Merge branch 'coreylowman:main' into main
opfromthestart Apr 5, 2023
675a380
Reshape
opfromthestart Apr 5, 2023
198c355
Merge branch 'main' of github.com:opfromthestart/dfdx into main
opfromthestart Apr 5, 2023
171fe3b
Remove unnecessary nightly
opfromthestart Apr 5, 2023
b74f3e3
Single tuple impls
opfromthestart Apr 21, 2023
5edff73
Merge branch 'main' of github.com:opfromthestart/dfdx into main
opfromthestart Apr 21, 2023
601b0f5
Better impl
opfromthestart Apr 21, 2023
886ac4f
fmt
opfromthestart Apr 21, 2023
d38df63
Merge branch 'main' of github.com:opfromthestart/dfdx into main
opfromthestart Sep 21, 2023
2c5e602
Added sigmoid with residue
opfromthestart Sep 21, 2023
b5dccc0
Added actual layer
opfromthestart Sep 21, 2023
30bd707
Make gradient on edges smaller
opfromthestart Sep 28, 2023
1 change: 1 addition & 0 deletions src/nn/activations.rs
@@ -36,6 +36,7 @@
activation_impls!(Cos, try_cos, #[doc="Calls [cos()]."]);
activation_impls!(Ln, try_ln, #[doc="Calls [ln()]."]);
activation_impls!(Exp, try_exp, #[doc="Calls [exp()]."]);
activation_impls!(Sigmoid, try_sigmoid, #[doc="Calls [sigmoid()]."]);
activation_impls!(Sigmoidr, try_sigmoidr, #[doc="Calls [sigmoidr()]."]);
activation_impls!(Tanh, try_tanh, #[doc="Calls [tanh()]."]);
activation_impls!(Square, try_square, #[doc="Calls [square()]."]);
activation_impls!(Sqrt, try_sqrt, #[doc="Calls [sqrt()]."]);
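For readers unfamiliar with the macro, `activation_impls!` turns a unary tensor op into a parameter-free layer. A hand-simplified sketch of roughly what the `Sigmoidr` line generates (the real macro in dfdx differs in detail; this expansion is an illustrative assumption, not taken from the PR):

```rust
// Hypothetical simplification of activation_impls!(Sigmoidr, try_sigmoidr, ...):
// a zero-sized struct usable anywhere a Module is expected.
#[derive(Default, Debug, Clone, Copy)]
pub struct Sigmoidr;

impl<S: Shape, E: Dtype, D: UnaryKernel<SigmoidrKernelOp, E>, T: Tape<E, D>>
    Module<Tensor<S, E, D, T>> for Sigmoidr
{
    type Output = Tensor<S, E, D, T>;
    type Error = D::Err;

    // Forward just delegates to the fallible tensor op.
    fn try_forward(&self, input: Tensor<S, E, D, T>) -> Result<Self::Output, D::Err> {
        input.try_sigmoidr()
    }
}
```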
2 changes: 2 additions & 0 deletions src/tensor_ops/mod.rs
@@ -195,6 +195,7 @@
mod roll;
mod select_and_gather;
mod sgd;
mod sigmoid;
mod sigmoidr;
mod sin;
mod slice;
mod softmax;
@@ -259,6 +260,7 @@
pub use roll::Roll;
pub use select_and_gather::{GatherTo, SelectTo};
pub use sgd::SgdConfig;
pub use sigmoid::sigmoid;
pub use sigmoidr::sigmoidr;
pub use sin::sin;
pub use slice::slice;
pub use softmax::softmax;
15 changes: 15 additions & 0 deletions src/tensor_ops/sigmoidr/cpu_kernel.rs
@@ -0,0 +1,15 @@
use crate::tensor_ops::cpu_kernels::UnaryDerivative;

impl<F: num_traits::Float> UnaryDerivative<F> for super::SigmoidrKernelOp {
    const DF_USES_FX: bool = true;
    const HAS_CONST_DF: bool = false;
    #[inline(always)]
    fn f(&self, x: &F) -> F {
        F::one() / (F::one() + x.neg().exp())
    }
    #[inline(always)]
    fn df(&self, &fx: &F) -> F {
        let d = fx * (F::one() - fx);
        F::max(d, F::from(0.0000001).unwrap())
    }
}
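Two details of this CPU kernel are worth spelling out: `DF_USES_FX: bool = true` tells the framework to hand `df` the forward output `fx` rather than the input, so the backward pass reuses `sigmoid(x)` instead of recomputing the exponential; and the `F::from(0.0000001)` floor is the only thing distinguishing this op from plain sigmoid. A minimal standalone sketch of the same math in plain `f32` (the function names here are illustrative, not dfdx API):

```rust
/// Floor applied to the derivative, mirroring the kernel's 0.0000001 constant.
const EPS: f32 = 0.0000001;

/// Forward pass: the standard logistic sigmoid.
fn sigmoidr_fwd(x: f32) -> f32 {
    1.0 / (1.0 + (-x).exp())
}

/// Backward pass: the usual derivative fx * (1 - fx), computed from the
/// forward output and clamped away from zero.
fn sigmoidr_bwd(fx: f32) -> f32 {
    (fx * (1.0 - fx)).max(EPS)
}

fn main() {
    let x = -100.0_f32; // deeply saturated input
    let fx = sigmoidr_fwd(x); // underflows to 0.0 in f32
    // Plain sigmoid would give a derivative of exactly 0.0 here;
    // the clamped version still lets a small gradient through.
    assert_eq!(sigmoidr_bwd(fx), EPS);
}
```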
15 changes: 15 additions & 0 deletions src/tensor_ops/sigmoidr/cuda_kernel.rs
@@ -0,0 +1,15 @@
use super::SigmoidrKernelOp as Sigmoidr;
#[allow(unused_imports)]
use crate::dtypes::*;
use crate::tensor_ops::cuda_kernels::cuda_unary;

unsafe impl cudarc::driver::DeviceRepr for Sigmoidr {}

const PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/sigmoidr.ptx"));

#[cfg(feature = "f16")]
cuda_unary!(df(f(x)) Sigmoidr, f16, PTX, "sigmoidr_fwd_f16", "sigmoidr_bwd_f16");
#[cfg(feature = "f16")]
cuda_unary!(df(f(x)) Sigmoidr, AMP<f16>, PTX, "sigmoidr_fwd_f16", "sigmoidr_bwd_f16");
cuda_unary!(df(f(x)) Sigmoidr, f32, PTX, "sigmoidr_fwd_f32", "sigmoidr_bwd_f32");
cuda_unary!(df(f(x)) Sigmoidr, f64, PTX, "sigmoidr_fwd_f64", "sigmoidr_bwd_f64");
59 changes: 59 additions & 0 deletions src/tensor_ops/sigmoidr/mod.rs
@@ -0,0 +1,59 @@
mod cpu_kernel;

#[cfg(feature = "cuda")]
mod cuda_kernel;

use super::ops::{try_unary_op, UnaryKernel};
use crate::{shapes::*, tensor::*};

#[repr(C)]
#[derive(Debug, Default, Copy, Clone)]
pub struct SigmoidrKernelOp;

/// [Sigmoid](https://en.wikipedia.org/wiki/Sigmoid_function). `1 / (1 + exp(-t))`.
/// Identical to [sigmoid()] in the forward pass, but the backward pass always
/// returns a non-zero gradient: the derivative `sigmoid(t) * (1.0 - sigmoid(t))`
/// is clamped to a minimum of `1e-7`.
///
/// Examples:
/// ```rust
/// # use dfdx::prelude::*;
/// # let dev: Cpu = Default::default();
/// let t = dev.tensor([-1.0, 0.0, 1.0, 2.0]);
/// let r = t.sigmoidr();
/// ```
pub fn sigmoidr<S: Shape, E: Dtype, D: UnaryKernel<SigmoidrKernelOp, E>, T: Tape<E, D>>(
    t: Tensor<S, E, D, T>,
) -> Tensor<S, E, D, T> {
    t.sigmoidr()
}

impl<S: Shape, E: Dtype, D: UnaryKernel<SigmoidrKernelOp, E>, T: Tape<E, D>> Tensor<S, E, D, T> {
    /// See [sigmoidr]
    pub fn sigmoidr(self) -> Self {
        self.try_sigmoidr().unwrap()
    }
    /// See [sigmoidr]
    pub fn try_sigmoidr(self) -> Result<Self, D::Err> {
        try_unary_op(SigmoidrKernelOp, self)
    }
}

#[cfg(test)]
mod tests {
    use crate::{tensor::*, tensor_ops::*, tests::*};

    #[test]
    fn test_sigmoidr() {
        let dev: TestDevice = Default::default();
        let x = dev
            .tensor([-2.0, -1.0, 0.0, 1.0, -TestDtype::INFINITY])
            .to_dtype::<TestDtype>();
        let r = x.leaky_trace().sigmoidr();
        assert_close_to_literal!(r, [0.11920292, 0.26894143, 0.5, 0.7310586, 0.0]);
        let g = r.mean().backward();
        assert_close_to_literal!(
            g.get(&x),
            [0.020998716, 0.039322387, 0.05, 0.039322387, 0.00000002]
        );
    }
}
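To make the motivation concrete, here is a sketch (assuming this PR's branch) of how the op differs from plain `sigmoid` at saturation, using only calls that already appear in this diff and in the test above:

```rust
use dfdx::prelude::*;

fn main() {
    let dev: Cpu = Default::default();
    // sigmoid(-100) underflows to 0.0 in f32, so its derivative is exactly 0.0.
    let x = dev.tensor([-100.0f32]);

    // Plain sigmoid: the gradient vanishes entirely.
    let g = x.leaky_trace().sigmoid().mean().backward();
    assert_eq!(g.get(&x).array(), [0.0]);

    // sigmoidr: the derivative is floored at 1e-7, so a gradient survives.
    let g = x.leaky_trace().sigmoidr().mean().backward();
    assert!(g.get(&x).array()[0] > 0.0);
}
```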
29 changes: 29 additions & 0 deletions src/tensor_ops/sigmoidr/sigmoidr.cu
@@ -0,0 +1,29 @@
#include "unary_op_macros.cuh"

struct SigmoidrKernelOp {};

template<typename T>
__device__ __forceinline__ T sigmoidr_fwd(T x) {
    T one = 1.0;
    return one / (one + expg(-x));
}

template<typename T>
__device__ __forceinline__ T sigmoidr_bwd(T y) {
    T one = 1.0;
    T d = y * (one - y);
    // Clamp with a T-typed epsilon: max(d, 0.0000001) would mix T with double,
    // which has no device overload when T is __half. The floor keeps the
    // gradient from vanishing in saturated regions.
    T eps = 0.0000001;
    return d > eps ? d : eps;
}

UNARY_OP(__half, sigmoidr_fwd_f16, sigmoidr_bwd_f16, SigmoidrKernelOp,
    sigmoidr_fwd(x),
    sigmoidr_bwd(y))

UNARY_OP(float, sigmoidr_fwd_f32, sigmoidr_bwd_f32, SigmoidrKernelOp,
    sigmoidr_fwd(x),
    sigmoidr_bwd(y))

UNARY_OP(double, sigmoidr_fwd_f64, sigmoidr_bwd_f64, SigmoidrKernelOp,
    sigmoidr_fwd(x),
    sigmoidr_bwd(y))

1 change: 1 addition & 0 deletions src/tensor_ops/utilities/device.rs
@@ -89,6 +89,7 @@
pub trait Device<E: Dtype>:
+ UnaryKernel<super::super::fast_gelu::FastGeLUKernelOp, E>
+ UnaryKernel<super::super::accurate_gelu::AccurateGeLUKernelOp, E>
+ UnaryKernel<super::super::sigmoid::SigmoidKernelOp, E>
+ UnaryKernel<super::super::sigmoidr::SigmoidrKernelOp, E>
+ UnaryKernel<super::super::sin::SinKernelOp, E>
+ UnaryKernel<super::super::sqrt::SqrtKernelOp, E>
+ UnaryKernel<super::super::square::SquareKernelOp, E>
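Adding the `UnaryKernel<SigmoidrKernelOp, E>` supertrait bound is what lets device-generic code call the new op on any backend. A small illustrative sketch (the helper function is hypothetical, not part of the PR):

```rust
use dfdx::prelude::*;

// Compiles for any backend (Cpu, Cuda, ...) because Device<E> now
// carries the UnaryKernel<SigmoidrKernelOp, E> bound.
fn apply_sigmoidr<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>>(
    t: Tensor<S, E, D, T>,
) -> Tensor<S, E, D, T> {
    t.sigmoidr()
}
```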