From 5c532ec5dc51cd17cd4bb9ae940ecf2c9baf89f6 Mon Sep 17 00:00:00 2001
From: rainiwu
Date: Fri, 26 Jan 2024 00:29:35 -0800
Subject: [PATCH 1/7] remove deprecated ftz intrinsics

---
 dfdx-core/src/lib.rs      | 38 --------------------------------------
 dfdx/examples/12-mnist.rs |  3 ---
 2 files changed, 41 deletions(-)

diff --git a/dfdx-core/src/lib.rs b/dfdx-core/src/lib.rs
index 31e61643..c126db2c 100644
--- a/dfdx-core/src/lib.rs
+++ b/dfdx-core/src/lib.rs
@@ -128,44 +128,6 @@ pub mod prelude {
     pub use crate::tensor_ops::*;
 }
 
-/// Sets a CPU `sse` flag to flush denormal floating point numbers to zero. The opposite of this is [keep_denormals()].
-///
-/// Some resources:
-/// 1. [Effects of Flush-To-Zero mode](https://developer.arm.com/documentation/dui0473/c/neon-and-vfp-programming/the-effects-of-using-flush-to-zero-mode?lang=en)
-/// 2. [When to use Flush-To-Zero mode](https://developer.arm.com/documentation/dui0473/c/neon-and-vfp-programming/when-to-use-flush-to-zero-mode?lang=en)
-pub fn flush_denormals_to_zero() {
-    #[cfg(all(target_arch = "x86", target_feature = "sse"))]
-    {
-        use std::arch::x86::{_MM_FLUSH_ZERO_ON, _MM_SET_FLUSH_ZERO_MODE};
-        unsafe { _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON) }
-    }
-
-    #[cfg(all(target_arch = "x86_64", target_feature = "sse"))]
-    {
-        use std::arch::x86_64::{_MM_FLUSH_ZERO_ON, _MM_SET_FLUSH_ZERO_MODE};
-        unsafe { _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON) }
-    }
-}
-
-/// Sets a CPU flag to keep denormal floating point numbers. The opposite of this is [flush_denormals_to_zero()].
-///
-/// Some resources:
-/// 1. [Effects of Flush-To-Zero mode](https://developer.arm.com/documentation/dui0473/c/neon-and-vfp-programming/the-effects-of-using-flush-to-zero-mode?lang=en)
-/// 2. [When to use Flush-To-Zero mode](https://developer.arm.com/documentation/dui0473/c/neon-and-vfp-programming/when-to-use-flush-to-zero-mode?lang=en)
-pub fn keep_denormals() {
-    #[cfg(all(target_arch = "x86", target_feature = "sse"))]
-    {
-        use std::arch::x86::{_MM_FLUSH_ZERO_OFF, _MM_SET_FLUSH_ZERO_MODE};
-        unsafe { _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_OFF) }
-    }
-
-    #[cfg(all(target_arch = "x86_64", target_feature = "sse"))]
-    {
-        use std::arch::x86_64::{_MM_FLUSH_ZERO_OFF, _MM_SET_FLUSH_ZERO_MODE};
-        unsafe { _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_OFF) }
-    }
-}
-
 #[cfg(test)]
 pub(crate) mod tests {
     pub use num_traits::{Float, NumCast, Zero};
diff --git a/dfdx/examples/12-mnist.rs b/dfdx/examples/12-mnist.rs
index 705d14c8..00d43452 100644
--- a/dfdx/examples/12-mnist.rs
+++ b/dfdx/examples/12-mnist.rs
@@ -62,9 +62,6 @@ type Mlp = (
 const BATCH_SIZE: usize = 32;
 
 fn main() {
-    // ftz substantially improves performance
-    dfdx::flush_denormals_to_zero();
-
     let mnist_path = std::env::args()
         .nth(1)
         .unwrap_or_else(|| "./datasets/MNIST/raw".to_string());

From fb91f13314fb24a67c2d8e14ad40345d2d334805 Mon Sep 17 00:00:00 2001
From: rainiwu
Date: Fri, 26 Jan 2024 00:55:48 -0800
Subject: [PATCH 2/7] suppress spurious cargo clippy warning

---
 dfdx-core/src/data/collate.rs | 1 +
 1 file changed, 1 insertion(+)

diff --git a/dfdx-core/src/data/collate.rs b/dfdx-core/src/data/collate.rs
index d38a2a67..5f52d636 100644
--- a/dfdx-core/src/data/collate.rs
+++ b/dfdx-core/src/data/collate.rs
@@ -55,6 +55,7 @@ impl<A, B> Collate for Vec<(A, B)> {
 impl<'a, A, B> Collate for Vec<&'a (A, B)> {
     type Collated = (Vec<&'a A>, Vec<&'a B>);
     fn collated(self) -> Self::Collated {
+        #[allow(clippy::map_identity)]
         self.into_iter().map(|(a, b)| (a, b)).unzip()
     }
 }
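Review note (PATCH 1/7): the MXCSR-manipulation intrinsics (`_MM_SET_FLUSH_ZERO_MODE` and friends) are deprecated in `core::arch`, which is why the helpers are removed outright rather than patched. A minimal sketch, assuming x86_64 with SSE, of how downstream code that relied on `dfdx::flush_denormals_to_zero()` can opt back in locally (the function name here is hypothetical):

    /// Re-enables flush-to-zero; mirrors the helper removed above.
    #[allow(deprecated)] // the MXCSR intrinsics are deprecated upstream
    pub fn enable_ftz() {
        #[cfg(all(target_arch = "x86_64", target_feature = "sse"))]
        unsafe {
            use std::arch::x86_64::{_MM_FLUSH_ZERO_ON, _MM_SET_FLUSH_ZERO_MODE};
            _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON);
        }
    }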
From 4e3f7c7a24728668f72cf3617a66f4476280f6fb Mon Sep 17 00:00:00 2001
From: Thiago Machado
Date: Tue, 6 Feb 2024 18:27:46 -0500
Subject: [PATCH 3/7] avoid conv1d bound for cudnn

---
 dfdx-core/src/tensor_ops/utilities/device.rs | 50 +++++++++++++++-----
 1 file changed, 39 insertions(+), 11 deletions(-)

diff --git a/dfdx-core/src/tensor_ops/utilities/device.rs b/dfdx-core/src/tensor_ops/utilities/device.rs
index 8cbc2137..91f87cf6 100644
--- a/dfdx-core/src/tensor_ops/utilities/device.rs
+++ b/dfdx-core/src/tensor_ops/utilities/device.rs
@@ -114,25 +114,49 @@ pub trait Device<E: Dtype>:
     + crate::tensor_ops::axpy::AxpyKernel<E>
 
     // conv1d
-    + super::super::conv1d::Conv1DKernel<E>
+    + NonCudnnCuda<E>
+{
+}
+
+#[cfg(feature = "cudnn")]
+pub trait NonCudnnCuda<E: Dtype> {}
+
+#[cfg(not(feature = "cudnn"))]
+pub trait NonCudnnCuda<E: Dtype>:
+    // conv1d
+    super::super::conv1d::Conv1DKernel<E>
 {
 }
 
 #[cfg(feature = "f16")]
-impl Device<f16> for crate::tensor::Cpu {}
-#[cfg(feature = "f16")]
-impl Device<AMP<f16>> for crate::tensor::Cpu {}
+mod f16_ {
+    use super::*;
+    impl Device<f16> for crate::tensor::Cpu {}
+    impl NonCudnnCuda<f16> for crate::tensor::Cpu {}
+    impl Device<AMP<f16>> for crate::tensor::Cpu {}
+    impl NonCudnnCuda<AMP<f16>> for crate::tensor::Cpu {}
+}
 impl Device<f32> for crate::tensor::Cpu {}
+impl NonCudnnCuda<f32> for crate::tensor::Cpu {}
 impl Device<f64> for crate::tensor::Cpu {}
+impl NonCudnnCuda<f64> for crate::tensor::Cpu {}
 
 #[cfg(all(feature = "cuda", feature = "f16"))]
-impl Device<f16> for crate::tensor::Cuda {}
-#[cfg(all(feature = "cuda", feature = "f16"))]
-impl Device<AMP<f16>> for crate::tensor::Cuda {}
-#[cfg(feature = "cuda")]
-impl Device<f32> for crate::tensor::Cuda {}
+mod cuda_f16 {
+    use super::*;
+    impl Device<f16> for crate::tensor::Cuda {}
+    impl NonCudnnCuda<f16> for crate::tensor::Cuda {}
+    impl Device<AMP<f16>> for crate::tensor::Cuda {}
+    impl NonCudnnCuda<AMP<f16>> for crate::tensor::Cuda {}
+}
 #[cfg(feature = "cuda")]
-impl Device<f64> for crate::tensor::Cuda {}
+mod cuda {
+    use super::*;
+    impl Device<f32> for crate::tensor::Cuda {}
+    impl NonCudnnCuda<f32> for crate::tensor::Cuda {}
+    impl Device<f64> for crate::tensor::Cuda {}
+    impl NonCudnnCuda<f64> for crate::tensor::Cuda {}
+}
 
 // TODO: How can we implement this for f16 when WGSL doesn't support f16 yet?
 // #[cfg(all(feature = "webgpu", feature = "f16"))]
 // impl Device<AMP<f16>> for crate::tensor::Webgpu {}
 #[cfg(feature = "webgpu")]
-impl Device<f32> for crate::tensor::Webgpu {}
+mod webgpu {
+    use super::*;
+    impl Device<f32> for crate::tensor::Webgpu {}
+    impl NonCudnnCuda<f32> for crate::tensor::Webgpu {}
+}
 
 // TODO: How can we implement this for f64 when WGSL doesn't support f64 yet?
 // #[cfg(feature = "webgpu")]
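Review note (PATCH 3/7): the mechanism here is a cfg-gated supertrait. `Device<E>` now requires `NonCudnnCuda<E>`; with the `cudnn` feature on, that trait is empty, and with it off, it carries the `Conv1DKernel<E>` bound — so the conv1d requirement vanishes exactly when the cudnn backend (which presumably lacks a conv1d kernel) is selected. A toy reduction of the pattern, with illustrative names:

    trait Conv1d {}

    #[cfg(feature = "cudnn")]
    trait MaybeConv1d {} // feature on: no extra requirement

    #[cfg(not(feature = "cudnn"))]
    trait MaybeConv1d: Conv1d {} // feature off: Conv1d comes along

    trait Device: MaybeConv1d {}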
From a8bc54c5c8e02c68fe09e72fc94ba0a8b3273b9a Mon Sep 17 00:00:00 2001
From: Thiago Machado
Date: Fri, 9 Feb 2024 11:53:40 -0500
Subject: [PATCH 4/7] bump gemm

---
 dfdx-core/Cargo.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/dfdx-core/Cargo.toml b/dfdx-core/Cargo.toml
index 5309ef7c..0f6cd5c6 100644
--- a/dfdx-core/Cargo.toml
+++ b/dfdx-core/Cargo.toml
@@ -35,7 +35,7 @@ num-traits = { workspace = true }
 safetensors = { workspace = true, optional = true }
 memmap2 = { workspace = true, optional = true }
 half = { version = "2.3.1", optional = true, features = ["num-traits", "rand_distr"] }
-gemm = { version = "0.16.14", default-features = false, optional = true, features = ["rayon"] }
+gemm = { version = "0.17.1", default-features = false, optional = true, features = ["rayon"] }
 rayon = { version = "1.7.0", optional = true }
 libm = { workspace = true }
 wgpu = { version = "0.18.0", features = ["glsl", "spirv"], optional = true }

From 557687c0a9e29dfba2311fe67414863c6c5137bf Mon Sep 17 00:00:00 2001
From: Thiago Machado
Date: Fri, 9 Feb 2024 12:52:05 -0500
Subject: [PATCH 5/7] clippy fix

---
 dfdx-core/src/tensor/gradients.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/dfdx-core/src/tensor/gradients.rs b/dfdx-core/src/tensor/gradients.rs
index 86974ec6..d24e2e32 100644
--- a/dfdx-core/src/tensor/gradients.rs
+++ b/dfdx-core/src/tensor/gradients.rs
@@ -153,7 +153,7 @@ impl<E, D: Storage<E>> Gradients<E, D> {
     #[inline]
     pub(crate) fn many_and_ref<L: Shape, R: Shape>(
         &mut self,
-        ls: &Vec<impl Tensorlike<L, E, D>>,
+        ls: &[impl Tensorlike<L, E, D>],
         r: &impl Tensorlike<R, E, D>,
     ) -> (Vec<&mut D::Vec>, &D::Vec) {
         for i in 0..ls.len() {
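Review note (PATCH 5/7): this is clippy's `ptr_arg` advice — accept `&[T]` instead of `&Vec<T>`, since the slice form is strictly more general and `&vec` coerces to `&[T]` automatically, so no call site changes. A toy sketch:

    fn sum_old(xs: &Vec<f32>) -> f32 { xs.iter().sum() } // clippy: ptr_arg
    fn sum_new(xs: &[f32]) -> f32 { xs.iter().sum() }

    fn main() {
        let v = vec![1.0, 2.0, 3.0];
        assert_eq!(sum_old(&v), sum_new(&v)); // same call syntax either way
    }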
From ff9cc12edd107527f996b69d1446f1bb5037a76f Mon Sep 17 00:00:00 2001
From: Thiago Machado
Date: Wed, 31 Jan 2024 01:10:46 -0500
Subject: [PATCH 6/7] add SiLU activation function

---
 dfdx-core/src/tensor_ops/mod.rs              |  2 +
 dfdx-core/src/tensor_ops/silu/cpu_kernel.rs  | 20 ++++++
 dfdx-core/src/tensor_ops/silu/cuda_kernel.rs | 15 +++++
 dfdx-core/src/tensor_ops/silu/mod.rs         | 62 +++++++++++++++++++
 dfdx-core/src/tensor_ops/silu/silu.cu        | 32 ++++++++++
 .../src/tensor_ops/silu/webgpu_kernel.rs     | 28 +++++++++
 dfdx-core/src/tensor_ops/utilities/device.rs |  1 +
 7 files changed, 160 insertions(+)
 create mode 100644 dfdx-core/src/tensor_ops/silu/cpu_kernel.rs
 create mode 100644 dfdx-core/src/tensor_ops/silu/cuda_kernel.rs
 create mode 100644 dfdx-core/src/tensor_ops/silu/mod.rs
 create mode 100644 dfdx-core/src/tensor_ops/silu/silu.cu
 create mode 100644 dfdx-core/src/tensor_ops/silu/webgpu_kernel.rs

diff --git a/dfdx-core/src/tensor_ops/mod.rs b/dfdx-core/src/tensor_ops/mod.rs
index 453457f4..1cb0f38c 100644
--- a/dfdx-core/src/tensor_ops/mod.rs
+++ b/dfdx-core/src/tensor_ops/mod.rs
@@ -197,6 +197,7 @@ mod roll;
 mod select_and_gather;
 mod sgd;
 mod sigmoid;
+mod silu;
 mod sin;
 mod slice;
 mod softmax;
@@ -264,6 +265,7 @@ pub use roll::Roll;
 pub use select_and_gather::{GatherTo, SelectTo};
 pub use sgd::SgdConfig;
 pub use sigmoid::sigmoid;
+pub use silu::silu;
 pub use sin::sin;
 pub use slice::slice;
 pub use softmax::softmax;
diff --git a/dfdx-core/src/tensor_ops/silu/cpu_kernel.rs b/dfdx-core/src/tensor_ops/silu/cpu_kernel.rs
new file mode 100644
index 00000000..f6f05752
--- /dev/null
+++ b/dfdx-core/src/tensor_ops/silu/cpu_kernel.rs
@@ -0,0 +1,20 @@
+use crate::tensor_ops::cpu_kernels::UnaryDerivative;
+
+impl<F: num_traits::Float> UnaryDerivative<F> for super::SiLUKernelOp {
+    const DF_USES_FX: bool = false;
+    const HAS_CONST_DF: bool = false;
+
+    // x / (1 + e^-x)
+    #[inline(always)]
+    fn f(&self, x: &F) -> F {
+        *x / (F::one() + x.neg().exp())
+    }
+
+    // (1 + e^-x + x * e^-x) / (1 + e^-x)^2
+    // alternative: (e^x (1 + e^x + x)) / (1 + e^x)^2
+    #[inline(always)]
+    fn df(&self, x: &F) -> F {
+        let exp_nx = x.neg().exp();
+        F::one() + exp_nx + *x * exp_nx / (F::one() + exp_nx).powi(2)
+    }
+}
diff --git a/dfdx-core/src/tensor_ops/silu/cuda_kernel.rs b/dfdx-core/src/tensor_ops/silu/cuda_kernel.rs
new file mode 100644
index 00000000..45bf1385
--- /dev/null
+++ b/dfdx-core/src/tensor_ops/silu/cuda_kernel.rs
@@ -0,0 +1,15 @@
+use super::SiLUKernelOp;
+#[allow(unused_imports)]
+use crate::dtypes::*;
+use crate::tensor_ops::cuda_kernels::cuda_unary;
+
+unsafe impl cudarc::driver::DeviceRepr for SiLUKernelOp {}
+
+const PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/silu.ptx"));
+
+#[cfg(feature = "f16")]
+cuda_unary!(SiLUKernelOp, f16, PTX, "silu_fwd_f16", "silu_bwd_f16");
+#[cfg(feature = "f16")]
+cuda_unary!(SiLUKernelOp, AMP<f16>, PTX, "silu_fwd_f16", "silu_bwd_f16");
+cuda_unary!(SiLUKernelOp, f32, PTX, "silu_fwd_f32", "silu_bwd_f32");
+cuda_unary!(SiLUKernelOp, f64, PTX, "silu_fwd_f64", "silu_bwd_f64");
diff --git a/dfdx-core/src/tensor_ops/silu/mod.rs b/dfdx-core/src/tensor_ops/silu/mod.rs
new file mode 100644
index 00000000..97bcce10
--- /dev/null
+++ b/dfdx-core/src/tensor_ops/silu/mod.rs
@@ -0,0 +1,62 @@
+mod cpu_kernel;
+
+#[cfg(feature = "cuda")]
+mod cuda_kernel;
+
+#[cfg(feature = "webgpu")]
+mod webgpu_kernel;
+
+use super::ops::{try_unary_op, UnaryKernel};
+use crate::{shapes::*, tensor::*};
+
+#[repr(C)]
+#[derive(Debug, Default, Copy, Clone)]
+pub struct SiLUKernelOp;
+
+/// [Sigmoid-Weighted Linear Unit (SiLU)](https://en.wikipedia.org/wiki/Rectifier_(neural_networks)). `x * x.sigmoid()`
+///
+/// The derivative is `x * sigmoid'(x) + sigmoid(x)`.
+///
+/// Examples:
+/// ```rust
+/// # use dfdx_core::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// let t = dev.tensor([-1.0, 0.0, 1.0, 2.0]);
+/// let r = t.silu();
+/// ```
+pub fn silu<S: Shape, E: Dtype, D: UnaryKernel<SiLUKernelOp, E>, T: Tape<E, D>>(
+    t: Tensor<S, E, D, T>,
+) -> Tensor<S, E, D, T> {
+    t.silu()
+}
+
+impl<S: Shape, E: Dtype, D: UnaryKernel<SiLUKernelOp, E>, T: Tape<E, D>> Tensor<S, E, D, T> {
+    /// See [silu]
+    pub fn silu(self) -> Self {
+        self.try_silu().unwrap()
+    }
+    /// See [silu]
+    pub fn try_silu(self) -> Result<Self, crate::tensor::Error> {
+        try_unary_op(SiLUKernelOp, self)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::{tensor::*, tensor_ops::*, tests::*};
+
+    #[test]
+    fn test_silu() {
+        let dev: TestDevice = Default::default();
+        let x = dev
+            .tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
+            .to_dtype::<TestDtype>();
+        let r = x.leaky_trace().silu();
+        assert_close_to_literal!(r, [-0.23840584, -0.26894143, 0.0, 0.7310586, 1.761594]);
+        let g = r.mean().backward();
+        assert_close_to_literal!(
+            g.get(&x),
+            [1.635814, 0.70433396, 0.4, 0.31289828, 0.26906452]
+        );
+    }
+}
diff --git a/dfdx-core/src/tensor_ops/silu/silu.cu b/dfdx-core/src/tensor_ops/silu/silu.cu
new file mode 100644
index 00000000..d3b01a7e
--- /dev/null
+++ b/dfdx-core/src/tensor_ops/silu/silu.cu
@@ -0,0 +1,32 @@
+#include "unary_op_macros.cuh"
+
+struct SiLUKernelOp {};
+
+// x / (1 + e^-x)
+template <typename T>
+__device__ __forceinline__ T silu_fwd(T x) {
+  T one = 1.0;
+  return x / (one + expg(-x));
+}
+
+// (1 + e^-x + x * e^-x) / (1 + e^-x)^2
+// alternative: (e^x (1 + e^x + x)) / (1 + e^x)^2
+template <typename T>
+__device__ __forceinline__ T silu_bwd(T x) {
+  T one = 1.0;
+  T exp_nx = expg(-x);
+  T denom_sqrt = (one + exp_nx);
+  return (one + exp_nx + x * exp_nx) / (denom_sqrt * denom_sqrt);
+}
+
+UNARY_OP(__half, silu_fwd_f16, silu_bwd_f16, SiLUKernelOp,
+         silu_fwd(x),
+         silu_bwd(x))
+
+UNARY_OP(float, silu_fwd_f32, silu_bwd_f32, SiLUKernelOp,
+         silu_fwd(x),
+         silu_bwd(x))
+
+UNARY_OP(double, silu_fwd_f64, silu_bwd_f64, SiLUKernelOp,
+         silu_fwd(x),
+         silu_bwd(x))
diff --git a/dfdx-core/src/tensor_ops/silu/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/silu/webgpu_kernel.rs
new file mode 100644
index 00000000..438850e1
--- /dev/null
+++ b/dfdx-core/src/tensor_ops/silu/webgpu_kernel.rs
@@ -0,0 +1,28 @@
+use std::borrow::Cow;
+
+use crate::prelude::{ops::UnaryKernel, Dtype, Webgpu};
+
+impl<E: Dtype> UnaryKernel<super::SiLUKernelOp, E> for Webgpu {
+    const BACKWARD_WITHOUT_INP: bool = false;
+
+    const BACKWARD_WITHOUT_DATA: bool = false;
+
+    fn forward<S: crate::prelude::Shape>(
+        &self,
+        op: super::SiLUKernelOp,
+        inp: Cow<crate::prelude::Tensor<S, E, Self>>,
+    ) -> Result<crate::prelude::Tensor<S, E, Self>, crate::prelude::Error> {
+        todo!()
+    }
+
+    fn backward<S: crate::prelude::Shape>(
+        &self,
+        op: super::SiLUKernelOp,
+        inp: &impl crate::prelude::Tensorlike<S, E, Self>,
+        grad_inp: &mut Self::Vec,
+        out: &impl crate::prelude::Tensorlike<S, E, Self>,
+        grad_out: &Self::Vec,
+    ) -> Result<(), crate::prelude::Error> {
+        todo!()
+    }
+}
diff --git a/dfdx-core/src/tensor_ops/utilities/device.rs b/dfdx-core/src/tensor_ops/utilities/device.rs
index 91f87cf6..495b68fa 100644
--- a/dfdx-core/src/tensor_ops/utilities/device.rs
+++ b/dfdx-core/src/tensor_ops/utilities/device.rs
@@ -92,6 +92,7 @@ pub trait Device<E: Dtype>:
     + UnaryKernel<super::super::recip::RecipKernelOp, E>
     + UnaryKernel<super::super::relu::ReLUKernelOp, E>
     + UnaryKernel<super::super::sigmoid::SigmoidKernelOp, E>
+    + UnaryKernel<super::super::silu::SiLUKernelOp, E>
     + UnaryKernel<super::super::sin::SinKernelOp, E>
     + UnaryKernel<super::super::sqrt::SqrtKernelOp, E>
     + UnaryKernel<super::super::square::SquareKernelOp, E>
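Review note (PATCH 6/7): `df` in the CPU kernel is missing parentheses around its numerator. Since `/` binds tighter than `+`, it computes `1 + e^-x + (x * e^-x / (1 + e^-x)^2)` rather than the intended `(1 + e^-x + x * e^-x) / (1 + e^-x)^2`, and the wrong values were baked into the test expectations (e.g. 0.4 at x = 0 instead of 0.1). A standalone check of the two forms at x = 0, where silu'(0) = 0.5:

    fn main() {
        let x: f64 = 0.0;
        let e = (-x).exp();
        let buggy = 1.0 + e + x * e / (1.0 + e).powi(2);
        let fixed = (1.0 + e + x * e) / (1.0 + e).powi(2);
        // buggy = 2.0 (2.0 / 5 = the 0.4 asserted above);
        // fixed = 0.5 (0.5 / 5 = the 0.1 asserted after the fix below)
        println!("buggy = {buggy}, fixed = {fixed}");
    }

PATCH 7/7 below corrects both the kernel and the test.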
From bc569c7515ce5c8ddaff23eaf19cad085f314c85 Mon Sep 17 00:00:00 2001
From: Thiago Machado
Date: Mon, 19 Feb 2024 21:47:22 -0500
Subject: [PATCH 7/7] silu: fix cpu df

---
 dfdx-core/src/tensor_ops/silu/cpu_kernel.rs | 2 +-
 dfdx-core/src/tensor_ops/silu/mod.rs        | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/dfdx-core/src/tensor_ops/silu/cpu_kernel.rs b/dfdx-core/src/tensor_ops/silu/cpu_kernel.rs
index f6f05752..2fcba2fb 100644
--- a/dfdx-core/src/tensor_ops/silu/cpu_kernel.rs
+++ b/dfdx-core/src/tensor_ops/silu/cpu_kernel.rs
@@ -15,6 +15,6 @@ impl<F: num_traits::Float> UnaryDerivative<F> for super::SiLUKernelOp {
     #[inline(always)]
     fn df(&self, x: &F) -> F {
         let exp_nx = x.neg().exp();
-        F::one() + exp_nx + *x * exp_nx / (F::one() + exp_nx).powi(2)
+        (F::one() + exp_nx + *x * exp_nx) / (F::one() + exp_nx).powi(2)
     }
 }
diff --git a/dfdx-core/src/tensor_ops/silu/mod.rs b/dfdx-core/src/tensor_ops/silu/mod.rs
index 97bcce10..53881079 100644
--- a/dfdx-core/src/tensor_ops/silu/mod.rs
+++ b/dfdx-core/src/tensor_ops/silu/mod.rs
@@ -56,7 +56,7 @@ mod tests {
         let g = r.mean().backward();
         assert_close_to_literal!(
             g.get(&x),
-            [1.635814, 0.70433396, 0.4, 0.31289828, 0.26906452]
+            [-0.018156849, 0.014465898, 0.1, 0.1855341, 0.21815684]
         );
     }
 }
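Review note (PATCH 7/7): a quick sanity check of the corrected expectations. For this 5-element tensor, the gradient of `mean(silu(x))` with respect to each element is `silu'(x) / 5`:

    fn main() {
        for x in [-2.0f64, -1.0, 0.0, 1.0, 2.0] {
            let e = (-x).exp();
            let dsilu = (1.0 + e + x * e) / (1.0 + e).powi(2);
            println!("{x}: {}", dsilu / 5.0);
        }
        // -0.0181568..., 0.0144658..., 0.1, 0.1855341..., 0.2181568...
    }

which reproduces the new `[-0.018156849, 0.014465898, 0.1, 0.1855341, 0.21815684]`.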