From 0b49672306213863571554c46c41fb2f3a0c3438 Mon Sep 17 00:00:00 2001
From: Corey Lowman
Date: Wed, 26 Jul 2023 08:48:25 -0400
Subject: [PATCH] [Breaking] Adds `AMP` dtype (#811)

* Adds AMP dtype

* impl sum for amp cpu

* impl amp kernels for cpu optimizers

* Moving NotMixedPrecision to dtypes

* Adding AMP implementations for cuda kernels

* Fixing cuda errors & warnings

* Adds Gemm impl for AMP for CudaBlas

* Adding chunk_sum for amp f16

* bump cudarc version

* Update src/dtypes/amp.rs

Co-authored-by: nkoppel

* More generic AMP

* Fixing unused imports

---------

Co-authored-by: nkoppel
---
 Cargo.toml                                      |   3 +-
 src/dtypes/amp.rs                               | 588 ++++++++++++++++++
 src/dtypes/mod.rs                               | 119 ++++
 src/lib.rs                                      |  26 +-
 src/shapes/mod.rs                               |   3 +-
 src/shapes/shape.rs                             |  98 ---
 src/tensor_ops/abs/cuda_kernel.rs               |   6 +-
 src/tensor_ops/accurate_gelu/cpu_kernel.rs      |   9 +-
 src/tensor_ops/accurate_gelu/cuda_kernel.rs     |  12 +-
 src/tensor_ops/adam/adam.cu                     |  42 ++
 src/tensor_ops/adam/cpu_kernel.rs               |  54 +-
 src/tensor_ops/adam/cuda_kernel.rs              |  10 +-
 src/tensor_ops/add/cuda_kernel.rs               |  21 +-
 src/tensor_ops/add/mod.rs                       |  17 +
 src/tensor_ops/attention_reshape/cuda_kernel.rs |   9 +-
 src/tensor_ops/axpy/cuda_kernel.rs              |   8 +-
 src/tensor_ops/bce/cuda_kernel.rs               |  13 +-
 src/tensor_ops/choose/cuda_kernel.rs            |   9 +-
 src/tensor_ops/clamp/cuda_kernel.rs             |  12 +-
 src/tensor_ops/cmp/cuda_kernels.rs              |  28 +-
 src/tensor_ops/cmp/mod.rs                       |  14 +
 src/tensor_ops/conv2d/cuda_kernel.rs            |  14 +-
 src/tensor_ops/conv2d/cudnn_kernel.rs           |   5 +-
 src/tensor_ops/convtrans2d/cuda_kernel.rs       |  13 +-
 src/tensor_ops/cos/cuda_kernel.rs               |   6 +-
 src/tensor_ops/div/cuda_kernel.rs               |  21 +-
 src/tensor_ops/div/mod.rs                       |  17 +
 src/tensor_ops/dropout/cuda_kernel.rs           |   9 +-
 src/tensor_ops/exp/cuda_kernel.rs               |   4 +
 src/tensor_ops/fast_gelu/cuda_kernel.rs         |  12 +-
 src/tensor_ops/huber_error/cuda_kernel.rs       |  19 +-
 src/tensor_ops/ln/cuda_kernel.rs                |   6 +-
 src/tensor_ops/matmul/cpu_kernel.rs             |  47 ++
 src/tensor_ops/matmul/cuda_kernel.rs            |  76 +++
 src/tensor_ops/max_to/cuda_kernel.rs            |  12 +-
 src/tensor_ops/maximum/cuda_kernel.rs           |  13 +-
 src/tensor_ops/min_to/cuda_kernel.rs            |  12 +-
 src/tensor_ops/minimum/cuda_kernel.rs           |  13 +-
 src/tensor_ops/mul/cuda_kernel.rs               |  21 +-
 src/tensor_ops/mul/mod.rs                       |  16 +
 src/tensor_ops/nans_to/cuda_kernel.rs           |  12 +-
 src/tensor_ops/negate/cuda_kernel.rs            |   6 +-
 src/tensor_ops/pool2d/cuda_kernel.rs            |   9 +-
 src/tensor_ops/pow/cuda_kernel.rs               |  11 +-
 src/tensor_ops/recip/cuda_kernel.rs             |   6 +-
 src/tensor_ops/relu/cuda_kernel.rs              |   6 +-
 src/tensor_ops/rmsprop/cpu_kernel.rs            |  73 ++-
 src/tensor_ops/rmsprop/cuda_kernel.rs           |  10 +-
 src/tensor_ops/rmsprop/rmsprop.cu               |  60 ++
 src/tensor_ops/roll/cuda_kernel.rs              |  11 +-
 src/tensor_ops/select_and_gather/cuda_kernel.rs |  15 +-
 src/tensor_ops/sgd/cpu_kernel.rs                |  57 +-
 src/tensor_ops/sgd/cuda_kernel.rs               |  10 +-
 src/tensor_ops/sgd/sgd.cu                       |  40 ++
 src/tensor_ops/sigmoid/cuda_kernel.rs           |   6 +-
 src/tensor_ops/sin/cuda_kernel.rs               |   6 +-
 src/tensor_ops/slice/cuda_kernel.rs             |   9 +-
 src/tensor_ops/sqrt/cuda_kernel.rs              |   6 +-
 src/tensor_ops/square/cuda_kernel.rs            |   6 +-
 src/tensor_ops/sub/cuda_kernel.rs               |  21 +-
 src/tensor_ops/sub/mod.rs                       |  16 +
 src/tensor_ops/sum_to/cpu_kernel.rs             |  70 ++-
 src/tensor_ops/sum_to/cuda_kernel.rs            |   9 +-
 src/tensor_ops/sum_to/sum_to.cu                 |  71 +++
 src/tensor_ops/tanh/cuda_kernel.rs              |   6 +-
 src/tensor_ops/upscale2d/cuda_kernel.rs         |  15 +-
 src/tensor_ops/utilities/device.rs              |  12 +-
 67 files changed, 1825 insertions(+), 191 deletions(-)
 create mode 100644 src/dtypes/amp.rs
 create mode 100644 src/dtypes/mod.rs
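[Context, not part of the patch: a minimal sketch of how the new dtype is meant to be used, assuming dfdx's existing tensor API and the `f16` feature. `AMP<f16>` stores values in half precision while kernels accumulate in f32, in the spirit of automatic mixed precision.]

    use dfdx::dtypes::{f16, AMP};
    use dfdx::prelude::*;

    fn main() {
        let dev: Cpu = Default::default();
        // Tensors are parameterized by the wrapper dtype:
        let a: Tensor<Rank2<2, 3>, AMP<f16>, _> = dev.sample_normal();
        let b: Tensor<Rank2<3, 4>, AMP<f16>, _> = dev.sample_normal();
        // Storage stays f16; matmul/optimizer kernels widen to f32 internally.
        let _c: Tensor<Rank2<2, 4>, AMP<f16>, _> = a.matmul(b);
    }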
diff --git a/Cargo.toml b/Cargo.toml
index 26dc2475c..c9c96e7f2 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -30,7 +30,7 @@ spin = { version = "0.9.8", default-features = false, features = ["spin_mutex",
 rand = { version = "0.8.5", default-features = false, features = ["std_rng"] }
 rand_distr = { version = "0.4.3", default-features = false }
 zip = { version = "0.6.6", default-features = false, optional = true }
-cudarc = { version = "0.9.11", default-features = false, optional = true, features = ["driver", "cublas", "nvrtc"] }
+cudarc = { version = "0.9.13", default-features = false, optional = true, features = ["driver", "cublas", "nvrtc"] }
 num-traits = { version = "0.2.15", default-features = false }
 safetensors = { version = "0.3", default-features = false, optional = true }
 memmap2 = { version = "0.5", default-features = false, optional = true }
@@ -65,6 +65,7 @@ numpy = ["dep:zip", "std"]
 safetensors = ["dep:safetensors", "std", "dep:memmap2"]
 
 test-f16 = ["f16"]
+test-amp-f16 = ["f16"]
 test-f64 = []
 test-integrations = []
 ci-check = ["cudarc?/ci-check"]
diff --git a/src/dtypes/amp.rs b/src/dtypes/amp.rs
new file mode 100644
index 000000000..f5fd0cc5d
--- /dev/null
+++ b/src/dtypes/amp.rs
@@ -0,0 +1,588 @@
+use rand::{distributions::Distribution, Rng};
+
+#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
+#[repr(transparent)]
+pub struct AMP<F>(pub F);
+
+#[cfg(feature = "f16")]
+impl AMP<half::f16> {
+    pub const INFINITY: Self = AMP(half::f16::INFINITY);
+    pub const NEG_INFINITY: Self = AMP(half::f16::NEG_INFINITY);
+}
+
+#[cfg(feature = "std")]
+impl<F: std::fmt::Display> std::fmt::Display for AMP<F> {
+    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
+        self.0.fmt(f)
+    }
+}
+
+impl<F: super::SafeZeros> super::SafeZeros for AMP<F> {}
+
+#[cfg(feature = "cuda")]
+unsafe impl<F: cudarc::driver::ValidAsZeroBits> cudarc::driver::ValidAsZeroBits for AMP<F> {}
+
+#[cfg(feature = "cuda")]
+unsafe impl<F: cudarc::driver::DeviceRepr> cudarc::driver::DeviceRepr for AMP<F> {}
+
+#[cfg(feature = "cuda")]
+impl<F: cudarc::types::CudaTypeName> cudarc::types::CudaTypeName for AMP<F> {
+    const NAME: &'static str = F::NAME;
+}
+
+#[cfg(feature = "cudnn")]
+impl<F: cudarc::cudnn::CudnnDataType> cudarc::cudnn::CudnnDataType for AMP<F> {
+    type Scalar = F::Scalar;
+    const DATA_TYPE: cudarc::cudnn::sys::cudnnDataType_t = F::DATA_TYPE;
+    fn into_scaling_parameter(self) -> Self::Scalar {
+        self.0.into_scaling_parameter()
+    }
+}
+
+impl<F: std::ops::Add<F, Output = F>> std::ops::Add<AMP<F>> for AMP<F> {
+    type Output = Self;
+    fn add(self, rhs: Self) -> Self::Output { AMP(self.0 + rhs.0) }
+}
+
+impl<F: std::ops::Sub<F, Output = F>> std::ops::Sub<AMP<F>> for AMP<F> {
+    type Output = Self;
+    fn sub(self, rhs: Self) -> Self::Output { AMP(self.0 - rhs.0) }
+}
+
+impl<F: std::ops::Mul<F, Output = F>> std::ops::Mul<AMP<F>> for AMP<F> {
+    type Output = Self;
+    fn mul(self, rhs: Self) -> Self::Output { AMP(self.0 * rhs.0) }
+}
+
+impl<F: std::ops::Div<F, Output = F>> std::ops::Div<AMP<F>> for AMP<F> {
+    type Output = Self;
+    fn div(self, rhs: Self) -> Self::Output { AMP(self.0 / rhs.0) }
+}
+
+impl<F: std::ops::Rem<F, Output = F>> std::ops::Rem<AMP<F>> for AMP<F> {
+    type Output = Self;
+    fn rem(self, rhs: Self) -> Self::Output { AMP(self.0 % rhs.0) }
+}
+
+impl<F: std::ops::Neg<Output = F>> std::ops::Neg for AMP<F> {
+    type Output = Self;
+    fn neg(self) -> Self::Output { AMP(-self.0) }
+}
+
+impl<'l, 'r, F> std::ops::Add<&'r AMP<F>> for &'l AMP<F>
+where
+    &'l F: std::ops::Add<&'r F, Output = F>,
+{
+    type Output = AMP<F>;
+    fn add(self, rhs: &'r AMP<F>) -> Self::Output { AMP(&self.0 + &rhs.0) }
+}
+
+impl<'l, 'r, F> std::ops::Sub<&'r AMP<F>> for &'l AMP<F>
+where
+    &'l F: std::ops::Sub<&'r F, Output = F>,
+{
+    type Output = AMP<F>;
+    fn sub(self, rhs: &'r AMP<F>) -> Self::Output { AMP(&self.0 - &rhs.0) }
+}
+
+impl<'l, 'r, F> std::ops::Mul<&'r AMP<F>> for &'l AMP<F>
+where
+    &'l F: std::ops::Mul<&'r F, Output = F>,
+{
+    type Output = AMP<F>;
+    fn mul(self, rhs: &'r AMP<F>) -> Self::Output { AMP(&self.0 * &rhs.0) }
+}
+
+impl<'l, 'r, F> std::ops::Div<&'r AMP<F>> for &'l AMP<F>
+where
+    &'l F: std::ops::Div<&'r F, Output = F>,
+{
+    type Output = AMP<F>;
+    fn div(self, rhs: &'r AMP<F>) -> Self::Output { AMP(&self.0 / &rhs.0) }
+}
+
+impl<'l, 'r, F> std::ops::Rem<&'r AMP<F>> for &'l AMP<F>
+where
+    &'l F: std::ops::Rem<&'r F, Output = F>,
+{
+    type Output = AMP<F>;
+    fn rem(self, rhs: &'r AMP<F>) -> Self::Output { AMP(&self.0 % &rhs.0) }
+}
+
+impl<'l, F> std::ops::Neg for &'l AMP<F>
+where
+    &'l F: std::ops::Neg<Output = F>,
+{
+    type Output = AMP<F>;
+    fn neg(self) -> Self::Output { AMP(-&self.0) }
+}
+
+impl<F: std::ops::AddAssign<F>> std::ops::AddAssign<AMP<F>> for AMP<F> {
+    fn add_assign(&mut self, rhs: Self) { self.0 += rhs.0; }
+}
+
+impl<F: std::ops::SubAssign<F>> std::ops::SubAssign<AMP<F>> for AMP<F> {
+    fn sub_assign(&mut self, rhs: Self) { self.0 -= rhs.0; }
+}
+
+impl<F: std::ops::MulAssign<F>> std::ops::MulAssign<AMP<F>> for AMP<F> {
+    fn mul_assign(&mut self, rhs: Self) { self.0 *= rhs.0; }
+}
+
+impl<F: std::ops::DivAssign<F>> std::ops::DivAssign<AMP<F>> for AMP<F> {
+    fn div_assign(&mut self, rhs: Self) { self.0 /= rhs.0; }
+}
+
+impl<F: num_traits::FromPrimitive> num_traits::FromPrimitive for AMP<F> {
+    fn from_f32(n: f32) -> Option<Self> { F::from_f32(n).map(AMP) }
+    fn from_f64(n: f64) -> Option<Self> { F::from_f64(n).map(AMP) }
+    fn from_i64(n: i64) -> Option<Self> { F::from_i64(n).map(AMP) }
+    fn from_u64(n: u64) -> Option<Self> { F::from_u64(n).map(AMP) }
+}
+
+impl<T: Copy + 'static, F: num_traits::AsPrimitive<T>> num_traits::AsPrimitive<T> for AMP<F> {
+    fn as_(self) -> T {
+        self.0.as_()
+    }
+}
+
+#[cfg(feature = "f16")]
+impl num_traits::AsPrimitive<AMP<half::f16>> for half::f16 {
+    fn as_(self) -> AMP<half::f16> {
+        AMP(self)
+    }
+}
+
+#[cfg(feature = "f16")]
+impl num_traits::AsPrimitive<AMP<half::f16>> for f32 {
+    fn as_(self) -> AMP<half::f16> {
+        AMP(half::f16::from_f32(self))
+    }
+}
+
+#[cfg(feature = "f16")]
+impl num_traits::AsPrimitive<AMP<half::f16>> for f64 {
+    fn as_(self) -> AMP<half::f16> {
+        AMP(half::f16::from_f64(self))
+    }
+}
+
+impl<F: num_traits::ToPrimitive> num_traits::ToPrimitive for AMP<F> {
+    fn to_i64(&self) -> Option<i64> { self.0.to_i64() }
+    fn to_u64(&self) -> Option<u64> { self.0.to_u64() }
+    fn to_f32(&self) -> Option<f32> { self.0.to_f32() }
+    fn to_f64(&self) -> Option<f64> { self.0.to_f64() }
+}
+
+impl<F: crate::shapes::Unit> crate::shapes::Unit for AMP<F> {
+    const ONE: Self = AMP(F::ONE);
+}
+
+impl<F: crate::shapes::Dtype> crate::shapes::Dtype for AMP<F> {}
+
+impl<F: num_traits::Zero> num_traits::Zero for AMP<F> {
+    fn zero() -> Self { AMP(F::zero()) }
+    fn is_zero(&self) -> bool { self.0.is_zero() }
+}
+impl<F: num_traits::One> num_traits::One for AMP<F> {
+    fn one() -> Self { AMP(F::one()) }
+}
+impl<F: num_traits::Num> num_traits::Num for AMP<F> {
+    type FromStrRadixErr = F::FromStrRadixErr;
+    fn from_str_radix(str: &str, radix: u32) -> Result<Self, Self::FromStrRadixErr> {
+        F::from_str_radix(str, radix).map(AMP)
+    }
+}
+impl<F: num_traits::NumCast> num_traits::NumCast for AMP<F> {
+    fn from<T: num_traits::ToPrimitive>(n: T) -> Option<Self> {
+        F::from(n).map(AMP)
+    }
+}
+impl<F: num_traits::FloatConst> num_traits::FloatConst for AMP<F> {
+    fn E() -> Self { AMP(F::E()) }
+    fn FRAC_1_PI() -> Self { AMP(F::FRAC_1_PI()) }
+    fn FRAC_1_SQRT_2() -> Self { AMP(F::FRAC_1_SQRT_2()) }
+    fn FRAC_2_PI() -> Self { AMP(F::FRAC_2_PI()) }
+    fn FRAC_2_SQRT_PI() -> Self { AMP(F::FRAC_2_SQRT_PI()) }
+    fn FRAC_PI_2() -> Self { AMP(F::FRAC_PI_2()) }
+    fn FRAC_PI_3() -> Self { AMP(F::FRAC_PI_3()) }
+    fn FRAC_PI_4() -> Self { AMP(F::FRAC_PI_4()) }
+    fn FRAC_PI_6() -> Self { AMP(F::FRAC_PI_6()) }
+    fn FRAC_PI_8() -> Self { AMP(F::FRAC_PI_8()) }
+    fn LN_10() -> Self { AMP(F::LN_10()) }
+    fn LN_2() -> Self { AMP(F::LN_2()) }
+    fn LOG10_E() -> Self { AMP(F::LOG10_E()) }
+    fn LOG2_E() -> Self { AMP(F::LOG2_E()) }
+    fn PI() -> Self { AMP(F::PI()) }
+    fn SQRT_2() -> Self { AMP(F::SQRT_2()) }
+}
+impl<F: num_traits::Float> num_traits::Float for AMP<F> {
+    fn nan() -> Self { AMP(F::nan()) }
+    fn infinity() -> Self { AMP(F::infinity()) }
+    fn neg_infinity() -> Self { AMP(F::neg_infinity()) }
+    fn neg_zero() -> Self { AMP(F::neg_zero()) }
+    fn min_value() -> Self { AMP(F::min_value()) }
+    fn min_positive_value() -> Self { AMP(F::min_positive_value()) }
+    fn max_value() -> Self { AMP(F::max_value()) }
+    fn is_nan(self) -> bool { self.0.is_nan() }
+    fn is_infinite(self) -> bool { self.0.is_infinite() }
+    fn is_finite(self) -> bool { self.0.is_finite() }
+    fn is_normal(self) -> bool { self.0.is_normal() }
+    fn classify(self) -> core::num::FpCategory { self.0.classify() }
+    fn floor(self) -> Self { AMP(self.0.floor()) }
+    fn ceil(self) -> Self { AMP(self.0.ceil()) }
+    fn round(self) -> Self { AMP(self.0.round()) }
+    fn trunc(self) -> Self { AMP(self.0.trunc()) }
+    fn fract(self) -> Self { AMP(self.0.fract()) }
+    fn abs(self) -> Self { AMP(self.0.abs()) }
+    fn signum(self) -> Self { AMP(self.0.signum()) }
+    fn is_sign_positive(self) -> bool { self.0.is_sign_positive() }
+    fn is_sign_negative(self) -> bool { self.0.is_sign_negative() }
+    fn mul_add(self, a: Self, b: Self) -> Self { AMP(self.0.mul_add(a.0, b.0)) }
+    fn recip(self) -> Self { AMP(self.0.recip()) }
+    fn powi(self, n: i32) -> Self { AMP(self.0.powi(n)) }
+    fn powf(self, n: Self) -> Self { AMP(self.0.powf(n.0)) }
+    fn sqrt(self) -> Self { AMP(self.0.sqrt()) }
+    fn exp(self) -> Self { AMP(self.0.exp()) }
+    fn exp2(self) -> Self { AMP(self.0.exp2()) }
+    fn ln(self) -> Self { AMP(self.0.ln()) }
+    fn log(self, base: Self) -> Self { AMP(self.0.log(base.0)) }
+    fn log2(self) -> Self { AMP(self.0.log2()) }
+    fn log10(self) -> Self { AMP(self.0.log10()) }
+    fn max(self, other: Self) -> Self { AMP(self.0.max(other.0)) }
+    fn min(self, other: Self) -> Self { AMP(self.0.min(other.0)) }
+    fn abs_sub(self, other: Self) -> Self { AMP(self.0.abs_sub(other.0)) }
+    fn cbrt(self) -> Self { AMP(self.0.cbrt()) }
+    fn hypot(self, other: Self) -> Self { AMP(self.0.hypot(other.0)) }
+    fn sin(self) -> Self { AMP(self.0.sin()) }
+    fn cos(self) -> Self { AMP(self.0.cos()) }
+    fn tan(self) -> Self { AMP(self.0.tan()) }
+    fn asin(self) -> Self { AMP(self.0.asin()) }
+    fn acos(self) -> Self { AMP(self.0.acos()) }
+    fn atan(self) -> Self { AMP(self.0.atan()) }
+    fn atan2(self, other: Self) -> Self { AMP(self.0.atan2(other.0)) }
+    fn sin_cos(self) -> (Self, Self) {
+        let (a, b) = self.0.sin_cos();
+        (AMP(a), AMP(b))
+    }
+    fn exp_m1(self) -> Self { AMP(self.0.exp_m1()) }
+    fn ln_1p(self) -> Self { AMP(self.0.ln_1p()) }
+    fn sinh(self) -> Self { AMP(self.0.sinh()) }
+    fn cosh(self) -> Self { AMP(self.0.cosh()) }
+    fn tanh(self) -> Self { AMP(self.0.tanh()) }
+    fn asinh(self) -> Self { AMP(self.0.asinh()) }
+    fn acosh(self) -> Self { AMP(self.0.acosh()) }
+    fn atanh(self) -> Self { AMP(self.0.atanh()) }
+    fn integer_decode(self) -> (u64, i16, i8) { self.0.integer_decode() }
+}
+
+macro_rules! impl_distribution {
+    ($Distr:ty) => {
+        impl<F> Distribution<AMP<F>> for $Distr
+        where
+            Self: Distribution<F>,
+        {
+            fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> AMP<F> {
+                AMP(<Self as Distribution<F>>::sample(self, rng))
+            }
+        }
+    };
+}
+
+impl_distribution!(rand_distr::Standard);
+impl_distribution!(rand_distr::StandardNormal);
+impl_distribution!(rand_distr::Exp1);
+impl_distribution!(rand_distr::Open01);
+impl_distribution!(rand_distr::OpenClosed01);
+
+#[derive(Debug, Clone, Copy)]
+pub struct AMPSampler<F: rand_distr::uniform::SampleUniform>(F::Sampler);
+
+impl<F: rand_distr::uniform::SampleUniform> rand_distr::uniform::SampleUniform for AMP<F> {
+    type Sampler = AMPSampler<F>;
+}
+
+impl<F: rand_distr::uniform::SampleUniform> rand_distr::uniform::UniformSampler for AMPSampler<F> {
+    type X = AMP<F>;
+    fn new<B1, B2>(low: B1, high: B2) -> Self
+    where
+        B1: rand_distr::uniform::SampleBorrow<Self::X> + Sized,
+        B2: rand_distr::uniform::SampleBorrow<Self::X> + Sized,
+    {
+        let l = low.borrow();
+        let h = high.borrow();
+        Self(F::Sampler::new(&l.0, &h.0))
+    }
+    fn new_inclusive<B1, B2>(low: B1, high: B2) -> Self
+    where
+        B1: rand_distr::uniform::SampleBorrow<Self::X> + Sized,
+        B2: rand_distr::uniform::SampleBorrow<Self::X> + Sized,
+    {
+        let l = low.borrow();
+        let h = high.borrow();
+        Self(F::Sampler::new_inclusive(&l.0, &h.0))
+    }
+    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X {
+        AMP(self.0.sample(rng))
+    }
+}
diff --git a/src/dtypes/mod.rs b/src/dtypes/mod.rs
new file mode 100644
index 000000000..3a17d6032
--- /dev/null
+++ b/src/dtypes/mod.rs
@@ -0,0 +1,119 @@
+mod amp;
+
+pub use amp::AMP;
+
+#[cfg(feature = "f16")]
+pub use half::f16;
+
+#[cfg(not(feature = "cuda"))]
+pub trait SafeZeros {}
+
+#[cfg(feature = "cuda")]
+pub trait SafeZeros: cudarc::driver::ValidAsZeroBits + cudarc::driver::DeviceRepr {}
+
+/// Represents a unit type, but no arithmetic.
+pub trait Unit:
+    'static
+    + Copy
+    + Clone
+    + Default
+    + std::fmt::Debug
+    + PartialEq
+    + PartialOrd
+    + Send
+    + Sync
+    + std::marker::Unpin
+    + SafeZeros
+{
+    const ONE: Self;
+}
+
+macro_rules! unit {
+    ($type:ty, $one:expr) => {
+        impl SafeZeros for $type {}
+        impl Unit for $type {
+            const ONE: Self = $one;
+        }
+    };
+}
+
+unit!(f32, 1.0);
+unit!(f64, 1.0);
+unit!(usize, 1);
+unit!(isize, 1);
+unit!(u8, 1);
+unit!(i8, 1);
+unit!(u16, 1);
+unit!(i16, 1);
+unit!(u32, 1);
+unit!(i32, 1);
+unit!(u64, 1);
+unit!(i64, 1);
+unit!(u128, 1);
+unit!(i128, 1);
+unit!(bool, true);
+#[cfg(feature = "f16")]
+unit!(f16, f16::ONE);
+
+/// Represents something that has a [Unit].
+pub trait HasUnitType {
+    type Unit: Unit;
+}
+
+/// Represents a data type or element of an array that can have
+/// arithmatic operations applied to it. The main difference
+/// between [Dtype] and [Unit] is that [`bool`] is [Unit], but
+/// not [Dtype].
+pub trait Dtype:
+    Unit
+    + std::ops::Add<Self, Output = Self>
+    + std::ops::Sub<Self, Output = Self>
+    + std::ops::Mul<Self, Output = Self>
+    + std::ops::Div<Self, Output = Self>
+    + std::ops::AddAssign
+    + std::ops::SubAssign
+    + std::ops::MulAssign
+    + std::ops::DivAssign
+    + num_traits::FromPrimitive
+    + num_traits::ToPrimitive
+{
+}
+impl Dtype for f32 {}
+impl Dtype for f64 {}
+impl Dtype for i8 {}
+impl Dtype for i16 {}
+impl Dtype for i32 {}
+impl Dtype for i64 {}
+impl Dtype for i128 {}
+impl Dtype for isize {}
+impl Dtype for u8 {}
+impl Dtype for u16 {}
+impl Dtype for u32 {}
+impl Dtype for u64 {}
+impl Dtype for u128 {}
+impl Dtype for usize {}
+#[cfg(feature = "f16")]
+impl Dtype for f16 {}
+
+/// Represents something that has a [Dtype].
+pub trait HasDtype {
+    type Dtype: Dtype;
+}
+
+pub trait NotMixedPrecision {}
+impl NotMixedPrecision for f32 {}
+impl NotMixedPrecision for f64 {}
+impl NotMixedPrecision for i8 {}
+impl NotMixedPrecision for i16 {}
+impl NotMixedPrecision for i32 {}
+impl NotMixedPrecision for i64 {}
+impl NotMixedPrecision for i128 {}
+impl NotMixedPrecision for isize {}
+impl NotMixedPrecision for u8 {}
+impl NotMixedPrecision for u16 {}
+impl NotMixedPrecision for u32 {}
+impl NotMixedPrecision for u64 {}
+impl NotMixedPrecision for u128 {}
+impl NotMixedPrecision for usize {}
+#[cfg(feature = "f16")]
+impl NotMixedPrecision for f16 {}
diff --git a/src/lib.rs b/src/lib.rs
index 60915a9d6..026502a13 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -183,6 +183,7 @@ extern crate alloc;
 extern crate no_std_compat as std;
 
 pub mod data;
+pub mod dtypes;
 pub mod feature_flags;
 pub mod losses;
 pub mod nn;
@@ -252,7 +253,11 @@ pub(crate) mod tests {
     #[cfg(all(feature = "test-f64", feature = "test-f16"))]
     compile_error!("f64 and f16 cannot be tested at the same time");
 
-    #[cfg(all(not(feature = "test-f16"), not(feature = "test-f64")))]
+    #[cfg(all(
+        not(feature = "test-amp-f16"),
+        not(feature = "test-f16"),
+        not(feature = "test-f64")
+    ))]
     pub type TestDtype = f32;
 
     #[cfg(feature = "test-f16")]
@@ -261,6 +266,9 @@ pub(crate) mod tests {
     #[cfg(feature = "test-f64")]
     pub type TestDtype = f64;
 
+    #[cfg(feature = "test-amp-f16")]
+    pub type TestDtype = crate::dtypes::AMP<crate::dtypes::f16>;
+
     pub trait AssertClose {
         type Elem: std::fmt::Display + std::fmt::Debug + Copy;
         const DEFAULT_TOLERANCE: Self::Elem;
@@ -282,6 +290,22 @@
         }
     }
 
+    impl<F: Copy + std::fmt::Debug + std::fmt::Display + AssertClose<Elem = F>> AssertClose
+        for crate::dtypes::AMP<F>
+    {
+        type Elem = crate::dtypes::AMP<F>;
+        const DEFAULT_TOLERANCE: Self::Elem = crate::dtypes::AMP(F::DEFAULT_TOLERANCE);
+        fn get_far_pair(
+            &self,
+            rhs: &Self,
+            tolerance: Self::Elem,
+        ) -> Option<(Self::Elem, Self::Elem)> {
+            self.0
+                .get_far_pair(&rhs.0, tolerance.0)
+                .map(|(l, r)| (crate::dtypes::AMP(l), crate::dtypes::AMP(r)))
+        }
+    }
+
     #[cfg(feature = "f16")]
     impl AssertClose for half::f16 {
         type Elem = Self;
diff --git a/src/shapes/mod.rs b/src/shapes/mod.rs
index 4c0cf0009..03a2d91a4 100644
--- a/src/shapes/mod.rs
+++ b/src/shapes/mod.rs
@@ -31,5 +31,6 @@ pub use slice::SliceShape;
 pub use axes::{Axes, Axes2, Axes3, Axes4, Axes5, Axes6, Axis, HasAxes};
 pub use shape::{Array, Const, ConstDim, Dim};
 pub use shape::{ConstShape, HasShape, Shape};
-pub use shape::{Dtype, HasDtype, HasUnitType, Unit};
 pub use shape::{Rank0, Rank1, Rank2, Rank3, Rank4, Rank5, Rank6};
+
+pub use crate::dtypes::{Dtype, HasDtype, HasUnitType, SafeZeros, Unit};
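[Context, not part of the patch: the `NotMixedPrecision` marker added above is a coherence device. It lets the crate keep one blanket kernel impl for every plain dtype while also adding a dedicated `AMP<f16>` impl, without the two overlapping. A minimal, self-contained sketch of the pattern with hypothetical trait and struct names:]

    // Sketch of the coherence trick used by the kernel impls below.
    pub struct AMP<F>(pub F);
    pub trait NotMixedPrecision {}
    impl NotMixedPrecision for f32 {}
    impl NotMixedPrecision for f64 {}

    trait OptimKernel<E> {
        fn update(&self, param: &mut [E], grad: &[E]);
    }

    struct Cpu;

    // One generic impl for every plain dtype...
    impl<E: Copy + NotMixedPrecision> OptimKernel<E> for Cpu {
        fn update(&self, _param: &mut [E], _grad: &[E]) { /* native-precision math */ }
    }

    // ...plus a dedicated impl for the wrapper. Since AMP<f32> never
    // implements NotMixedPrecision, the compiler sees no overlap.
    impl OptimKernel<AMP<f32>> for Cpu {
        fn update(&self, _param: &mut [AMP<f32>], _grad: &[AMP<f32>]) { /* widen, step, narrow */ }
    }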
unit { - ($type:ty, $one:expr) => { - impl SafeZeros for $type {} - impl Unit for $type { - const ONE: Self = $one; - } - }; -} - -unit!(f32, 1.0); -unit!(f64, 1.0); -unit!(usize, 1); -unit!(isize, 1); -unit!(u8, 1); -unit!(i8, 1); -unit!(u16, 1); -unit!(i16, 1); -unit!(u32, 1); -unit!(i32, 1); -unit!(u64, 1); -unit!(i64, 1); -unit!(u128, 1); -unit!(i128, 1); -unit!(bool, true); -#[cfg(feature = "f16")] -unit!(f16, f16::ONE); - -/// Represents something that has a [Unit]. -pub trait HasUnitType { - type Unit: Unit; -} - -/// Represents a data type or element of an array that can have -/// arithmatic operations applied to it. The main difference -/// between [Dtype] and [Unit] is that [`bool`] is [Unit], but -/// not [Dtype]. -pub trait Dtype: - Unit - + std::ops::Add - + std::ops::Sub - + std::ops::Mul - + std::ops::Div - + std::ops::AddAssign - + std::ops::SubAssign - + std::ops::MulAssign - + std::ops::DivAssign - + num_traits::FromPrimitive - + num_traits::ToPrimitive -{ -} -impl Dtype for f32 {} -impl Dtype for f64 {} -impl Dtype for i8 {} -impl Dtype for i16 {} -impl Dtype for i32 {} -impl Dtype for i64 {} -impl Dtype for i128 {} -impl Dtype for isize {} -impl Dtype for u8 {} -impl Dtype for u16 {} -impl Dtype for u32 {} -impl Dtype for u64 {} -impl Dtype for u128 {} -impl Dtype for usize {} -#[cfg(feature = "f16")] -impl Dtype for f16 {} - -/// Represents something that has a [Dtype]. -pub trait HasDtype { - type Dtype: Dtype; -} - /// Represents a single dimension of a multi dimensional [Shape] pub trait Dim: 'static + Copy + Clone + std::fmt::Debug + Send + Sync + Eq + PartialEq { fn size(&self) -> usize; diff --git a/src/tensor_ops/abs/cuda_kernel.rs b/src/tensor_ops/abs/cuda_kernel.rs index 773da7c54..6b5fc9f12 100644 --- a/src/tensor_ops/abs/cuda_kernel.rs +++ b/src/tensor_ops/abs/cuda_kernel.rs @@ -1,4 +1,6 @@ use super::AbsKernelOp; +#[allow(unused_imports)] +use crate::dtypes::*; use crate::tensor_ops::cuda_kernels::cuda_unary; unsafe impl cudarc::driver::DeviceRepr for AbsKernelOp {} @@ -6,6 +8,8 @@ unsafe impl cudarc::driver::DeviceRepr for AbsKernelOp {} const PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/abs.ptx")); #[cfg(feature = "f16")] -cuda_unary!(AbsKernelOp, half::f16, PTX, "abs_fwd_f16", "abs_bwd_f16"); +cuda_unary!(AbsKernelOp, AMP, PTX, "abs_fwd_f16", "abs_bwd_f16"); +#[cfg(feature = "f16")] +cuda_unary!(AbsKernelOp, f16, PTX, "abs_fwd_f16", "abs_bwd_f16"); cuda_unary!(AbsKernelOp, f32, PTX, "abs_fwd_f32", "abs_bwd_f32"); cuda_unary!(AbsKernelOp, f64, PTX, "abs_fwd_f64", "abs_bwd_f64"); diff --git a/src/tensor_ops/accurate_gelu/cpu_kernel.rs b/src/tensor_ops/accurate_gelu/cpu_kernel.rs index 334a3a102..0531c9081 100644 --- a/src/tensor_ops/accurate_gelu/cpu_kernel.rs +++ b/src/tensor_ops/accurate_gelu/cpu_kernel.rs @@ -8,10 +8,17 @@ trait Erf { fn erf(self) -> Self; } +#[cfg(feature = "f16")] +impl Erf for crate::dtypes::AMP { + fn erf(self) -> Self { + crate::dtypes::AMP(f16::from_f32(erff(self.0.to_f32()))) + } +} + #[cfg(feature = "f16")] impl Erf for f16 { fn erf(self) -> Self { - f16::from_f32(erff(f16::to_f32(self))) + f16::from_f32(erff(self.to_f32())) } } diff --git a/src/tensor_ops/accurate_gelu/cuda_kernel.rs b/src/tensor_ops/accurate_gelu/cuda_kernel.rs index 460c21990..d651da3a9 100644 --- a/src/tensor_ops/accurate_gelu/cuda_kernel.rs +++ b/src/tensor_ops/accurate_gelu/cuda_kernel.rs @@ -1,4 +1,6 @@ use super::AccurateGeLUKernelOp; +#[allow(unused_imports)] +use crate::dtypes::*; use crate::tensor_ops::cuda_kernels::cuda_unary; unsafe impl 
cudarc::driver::DeviceRepr for super::AccurateGeLUKernelOp {} @@ -8,7 +10,15 @@ const PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/accurate_gelu.ptx")); #[cfg(feature = "f16")] cuda_unary!( AccurateGeLUKernelOp, - half::f16, + AMP, + PTX, + "accurate_gelu_fwd_f16", + "accurate_gelu_bwd_f16" +); +#[cfg(feature = "f16")] +cuda_unary!( + AccurateGeLUKernelOp, + f16, PTX, "accurate_gelu_fwd_f16", "accurate_gelu_bwd_f16" diff --git a/src/tensor_ops/adam/adam.cu b/src/tensor_ops/adam/adam.cu index b5ee7268a..53d996be1 100644 --- a/src/tensor_ops/adam/adam.cu +++ b/src/tensor_ops/adam/adam.cu @@ -74,3 +74,45 @@ extern "C" __global__ void FN( \ ADAM(__half, adam_update_f16); ADAM(float, adam_update_f32); ADAM(double, adam_update_f64); + +extern "C" __global__ void adam_update_amp_f16( + const AdamConfig cfg, + const size_t numel, + const int t_int, + __half* param, + __half* moment1, + __half* moment2, + const __half* grad +) { + float beta1 = cfg.beta1; + float beta2 = cfg.beta2; + float lr = cfg.lr; + float weight_decay = cfg.weight_decay; + float eps = cfg.eps; + float one = 1.0; + float t = t_int; + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { + float p = param[i]; + float g = grad[i]; + float m = moment1[i]; + float v = moment2[i]; + + if (cfg.weight_decay_type == L2) { + g += weight_decay * p; + } + + m = m * beta1 + g * (one - beta1); + v = v * beta2 + g * g * (one - beta2); + float m_hat = m * one / (one - powg(beta1, t)); + float v_hat = v * one / (one - powg(beta2, t)); + g = lr * m_hat / (sqrtg(v_hat) + eps); + + if (cfg.weight_decay_type == Decoupled) { + g += (weight_decay * lr) * p; + } + + moment1[i] = m; + moment2[i] = v; + param[i] -= g; + } +} \ No newline at end of file diff --git a/src/tensor_ops/adam/cpu_kernel.rs b/src/tensor_ops/adam/cpu_kernel.rs index 4abed5354..b89c0f4d2 100644 --- a/src/tensor_ops/adam/cpu_kernel.rs +++ b/src/tensor_ops/adam/cpu_kernel.rs @@ -1,7 +1,57 @@ use super::{AdamConfig, AdamKernel, WeightDecay}; -use crate::{shapes::Dtype, tensor::Cpu}; +use crate::{ + dtypes::{Dtype, NotMixedPrecision}, + tensor::Cpu, +}; -impl AdamKernel for Cpu { +#[cfg(feature = "f16")] +impl AdamKernel> for Cpu { + fn adam_kernel( + &self, + t: i32, + cfg: &AdamConfig, + param: &mut Self::Vec, + moment1: &mut Self::Vec, + moment2: &mut Self::Vec, + grad: &Self::Vec, + ) -> Result<(), Self::Err> { + let betas = cfg.betas.map(|x| x as f32); + let eps = cfg.eps as f32; + let lr = cfg.lr as f32; + + for ((p, g), (m, v)) in param + .iter_mut() + .zip(grad.iter().cloned()) + .zip(moment1.iter_mut().zip(moment2.iter_mut())) + { + let p_f32 = p.0.to_f32(); + let mut g_f32 = g.0.to_f32(); + let mut m_f32 = m.0.to_f32(); + let mut v_f32 = v.0.to_f32(); + + if let Some(WeightDecay::L2(wd)) = cfg.weight_decay { + g_f32 += (wd as f32) * p_f32; + } + + m_f32 = m_f32 * betas[0] + g_f32 * (1.0 - betas[0]); + v_f32 = v_f32 * betas[1] + g_f32.powi(2) * (1.0 - betas[1]); + let m_hat = m_f32 * (1.0 - betas[0].powi(t)).recip(); + let v_hat = v_f32 * (1.0 - betas[1].powi(t)).recip(); + g_f32 = lr * m_hat / (v_hat.sqrt() + eps); + + if let Some(WeightDecay::Decoupled(wd)) = cfg.weight_decay { + g_f32 += (wd * cfg.lr) as f32 * p_f32; + } + + p.0 = crate::dtypes::f16::from_f32(p_f32 - g_f32); + m.0 = crate::dtypes::f16::from_f32(m_f32); + v.0 = crate::dtypes::f16::from_f32(v_f32); + } + Ok(()) + } +} + +impl AdamKernel for Cpu { fn adam_kernel( &self, t: i32, diff --git a/src/tensor_ops/adam/cuda_kernel.rs 
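[Context, not part of the patch: the AMP CPU optimizer kernels in this PR (adam above, and rmsprop/sgd later) all follow the same widen/compute/narrow pattern. A standalone sketch with a hypothetical SGD-style update, using the `half` crate:]

    use half::f16;

    /// One AMP-style SGD step: parameters and gradients are stored as f16,
    /// but every intermediate is computed in f32 to limit rounding error.
    fn sgd_step_amp(params: &mut [f16], grads: &[f16], lr: f32) {
        for (p, g) in params.iter_mut().zip(grads.iter()) {
            let p_f32 = p.to_f32(); // widen storage dtype to f32
            let g_f32 = g.to_f32();
            let new_p = p_f32 - lr * g_f32; // compute entirely in f32
            *p = f16::from_f32(new_p); // narrow only the final result
        }
    }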
diff --git a/src/tensor_ops/adam/cuda_kernel.rs b/src/tensor_ops/adam/cuda_kernel.rs
index 919333271..3d1bfb9c7 100644
--- a/src/tensor_ops/adam/cuda_kernel.rs
+++ b/src/tensor_ops/adam/cuda_kernel.rs
@@ -1,5 +1,5 @@
 use crate::{
-    shapes::*,
+    dtypes::*,
     tensor::{launch_cfg, Cuda},
     tensor_ops::optim::*,
 };
@@ -39,7 +39,13 @@ trait HasCudaKernel<E> {
 }
 
 #[cfg(feature = "f16")]
-impl HasCudaKernel<half::f16> for Cuda {
+impl HasCudaKernel<AMP<f16>> for Cuda {
+    const MOD: &'static str = "adam_amp_f16";
+    const FWD: &'static str = "adam_update_amp_f16";
+}
+
+#[cfg(feature = "f16")]
+impl HasCudaKernel<f16> for Cuda {
     const MOD: &'static str = "adam_f16";
     const FWD: &'static str = "adam_update_f16";
 }
diff --git a/src/tensor_ops/add/cuda_kernel.rs b/src/tensor_ops/add/cuda_kernel.rs
index a2b399643..d5bfa45f0 100644
--- a/src/tensor_ops/add/cuda_kernel.rs
+++ b/src/tensor_ops/add/cuda_kernel.rs
@@ -1,8 +1,12 @@
 use super::{BinaryAddKernelOp as Binary, ScalarAddKernelOp as Scalar};
+#[allow(unused_imports)]
+use crate::dtypes::*;
 use crate::tensor_ops::cuda_kernels::{cuda_binary, cuda_unary};
 
 #[cfg(feature = "f16")]
-unsafe impl cudarc::driver::DeviceRepr for Scalar<half::f16> {}
+unsafe impl cudarc::driver::DeviceRepr for Scalar<f16> {}
+#[cfg(feature = "f16")]
+unsafe impl cudarc::driver::DeviceRepr for Scalar<AMP<f16>> {}
 unsafe impl cudarc::driver::DeviceRepr for Scalar<f32> {}
 unsafe impl cudarc::driver::DeviceRepr for Scalar<f64> {}
 unsafe impl cudarc::driver::DeviceRepr for Binary {}
@@ -11,13 +15,24 @@ const SCALAR_PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/scalar_add.ptx"
 const BINARY_PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/binary_add.ptx"));
 
 #[cfg(feature = "f16")]
-cuda_unary!(const_df() Scalar<half::f16>, half::f16, SCALAR_PTX, "sadd_fwd_f16", "sadd_bwd_f16");
+cuda_unary!(const_df() Scalar<AMP<f16>>, AMP<f16>, SCALAR_PTX, "sadd_fwd_f16", "sadd_bwd_f16");
+#[cfg(feature = "f16")]
+cuda_unary!(const_df() Scalar<f16>, f16, SCALAR_PTX, "sadd_fwd_f16", "sadd_bwd_f16");
 cuda_unary!(const_df() Scalar<f32>, f32, SCALAR_PTX, "sadd_fwd_f32", "sadd_bwd_f32");
 cuda_unary!(const_df() Scalar<f64>, f64, SCALAR_PTX, "sadd_fwd_f64", "sadd_bwd_f64");
 #[cfg(feature = "f16")]
 cuda_binary!(
     const_df() Binary,
-    half::f16,
+    AMP<f16>,
+    BINARY_PTX,
+    "badd_fwd_f16",
+    "badd_bwd_lhs_f16",
+    "badd_bwd_rhs_f16"
+);
+#[cfg(feature = "f16")]
+cuda_binary!(
+    const_df() Binary,
+    f16,
     BINARY_PTX,
     "badd_fwd_f16",
     "badd_bwd_lhs_f16",
diff --git a/src/tensor_ops/add/mod.rs b/src/tensor_ops/add/mod.rs
index 28b886a37..7779a5200 100644
--- a/src/tensor_ops/add/mod.rs
+++ b/src/tensor_ops/add/mod.rs
@@ -85,6 +85,23 @@ impl<S: Shape, D: UnaryKernel<ScalarAddKernelOp<half::f16>, half::f16>, T: Tape<
     }
 }
 
+#[cfg(feature = "f16")]
+impl<
+        S: Shape,
+        D: UnaryKernel<
+            ScalarAddKernelOp<crate::dtypes::AMP<half::f16>>,
+            crate::dtypes::AMP<half::f16>,
+        >,
+        T: Tape<crate::dtypes::AMP<half::f16>, D>,
+    > TryAdd<f32> for Tensor<S, crate::dtypes::AMP<half::f16>, D, T>
+{
+    /// See [add]
+    fn try_add(self, rhs: f32) -> Result<Self, Self::Err> {
+        let scalar = crate::dtypes::AMP(half::f16::from_f32(rhs));
+        try_unary_op(ScalarAddKernelOp { scalar }, self)
+    }
+}
+
 impl<S: Shape, E: Dtype, D: Storage<E>, LhsTape: Tape<E, D>, Rhs> std::ops::Add<Rhs>
     for Tensor<S, E, D, LhsTape>
 where
diff --git a/src/tensor_ops/attention_reshape/cuda_kernel.rs b/src/tensor_ops/attention_reshape/cuda_kernel.rs
index b4a7b37b1..9f34490ed 100644
--- a/src/tensor_ops/attention_reshape/cuda_kernel.rs
+++ b/src/tensor_ops/attention_reshape/cuda_kernel.rs
@@ -1,4 +1,6 @@
 use super::*;
+#[allow(unused_imports)]
+use crate::dtypes::*;
 use crate::tensor::cuda::Cuda;
 use cudarc::driver::{DeviceRepr, LaunchAsync};
 
@@ -20,7 +22,12 @@ trait HasCudaKernel<E> {
 }
 
 #[cfg(feature = "f16")]
-impl HasCudaKernel<half::f16> for Cuda {
+impl HasCudaKernel<AMP<f16>> for Cuda {
+    const FN: &'static str = "attention_reshape_f16";
+}
+
+#[cfg(feature = "f16")]
+impl HasCudaKernel<f16> for Cuda {
     const FN: &'static str = "attention_reshape_f16";
 }
diff --git a/src/tensor_ops/axpy/cuda_kernel.rs b/src/tensor_ops/axpy/cuda_kernel.rs
index 5484f4441..5a2abf0c9 100644
--- a/src/tensor_ops/axpy/cuda_kernel.rs
+++ b/src/tensor_ops/axpy/cuda_kernel.rs
@@ -1,5 +1,5 @@
 use crate::{
-    shapes::*,
+    dtypes::*,
     tensor::{launch_cfg, Cuda},
 };
 
@@ -11,7 +11,11 @@ trait HasCudaKernel<E> {
     const FN: &'static str;
 }
 #[cfg(feature = "f16")]
-impl HasCudaKernel<half::f16> for Cuda {
+impl HasCudaKernel<AMP<f16>> for Cuda {
+    const FN: &'static str = "axpy_f16";
+}
+#[cfg(feature = "f16")]
+impl HasCudaKernel<f16> for Cuda {
     const FN: &'static str = "axpy_f16";
 }
 impl HasCudaKernel<f32> for Cuda {
diff --git a/src/tensor_ops/bce/cuda_kernel.rs b/src/tensor_ops/bce/cuda_kernel.rs
index f55d2ac43..1f705d17c 100644
--- a/src/tensor_ops/bce/cuda_kernel.rs
+++ b/src/tensor_ops/bce/cuda_kernel.rs
@@ -1,4 +1,6 @@
 use super::BCEKernelOp;
+#[allow(unused_imports)]
+use crate::dtypes::*;
 use crate::tensor_ops::cuda_kernels::cuda_binary;
 
 unsafe impl cudarc::driver::DeviceRepr for BCEKernelOp {}
@@ -8,7 +10,16 @@ const PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/bce.ptx"));
 #[cfg(feature = "f16")]
 cuda_binary!(
     BCEKernelOp,
-    half::f16,
+    AMP<f16>,
+    PTX,
+    "bce_fwd_f16",
+    "bce_bwd_lhs_f16",
+    "bce_bwd_rhs_f16"
+);
+#[cfg(feature = "f16")]
+cuda_binary!(
+    BCEKernelOp,
+    f16,
     PTX,
     "bce_fwd_f16",
     "bce_bwd_lhs_f16",
diff --git a/src/tensor_ops/choose/cuda_kernel.rs b/src/tensor_ops/choose/cuda_kernel.rs
index 1ef7dc11d..ed29149bb 100644
--- a/src/tensor_ops/choose/cuda_kernel.rs
+++ b/src/tensor_ops/choose/cuda_kernel.rs
@@ -1,4 +1,5 @@
 use crate::{
+    dtypes::*,
     shapes::*,
     tensor::{launch_cfg, Cuda, Storage, Tensor},
 };
@@ -12,7 +13,13 @@ pub(crate) trait HasCudaKernel<E> {
 }
 
 #[cfg(feature = "f16")]
-impl HasCudaKernel<half::f16> for Cuda {
+impl HasCudaKernel<AMP<f16>> for Cuda {
+    const MOD: &'static str = "choose_f16";
+    const FNS: &'static [&'static str] = &["choose_fwd_f16", "choose_bwd_f16"];
+}
+
+#[cfg(feature = "f16")]
+impl HasCudaKernel<f16> for Cuda {
     const MOD: &'static str = "choose_f16";
     const FNS: &'static [&'static str] = &["choose_fwd_f16", "choose_bwd_f16"];
 }
diff --git a/src/tensor_ops/clamp/cuda_kernel.rs b/src/tensor_ops/clamp/cuda_kernel.rs
index c8abf5248..73b18f535 100644
--- a/src/tensor_ops/clamp/cuda_kernel.rs
+++ b/src/tensor_ops/clamp/cuda_kernel.rs
@@ -1,8 +1,12 @@
 use super::ClampKernelOp;
+#[allow(unused_imports)]
+use crate::dtypes::*;
 use crate::tensor_ops::cuda_kernels::cuda_unary;
 
 #[cfg(feature = "f16")]
-unsafe impl cudarc::driver::DeviceRepr for ClampKernelOp<half::f16> {}
+unsafe impl cudarc::driver::DeviceRepr for ClampKernelOp<AMP<f16>> {}
+#[cfg(feature = "f16")]
+unsafe impl cudarc::driver::DeviceRepr for ClampKernelOp<f16> {}
 unsafe impl cudarc::driver::DeviceRepr for ClampKernelOp<f32> {}
 unsafe impl cudarc::driver::DeviceRepr for ClampKernelOp<f64> {}
 
@@ -10,11 +14,13 @@ const P: &str = include_str!(concat!(env!("OUT_DIR"), "/clamp.ptx"));
 
 #[cfg(feature = "f16")]
 cuda_unary!(
-    ClampKernelOp<half::f16>,
-    half::f16,
+    ClampKernelOp<AMP<f16>>,
+    AMP<f16>,
     P,
     "clamp_fwd_f16",
     "clamp_bwd_f16"
 );
+#[cfg(feature = "f16")]
+cuda_unary!(ClampKernelOp<f16>, f16, P, "clamp_fwd_f16", "clamp_bwd_f16");
 cuda_unary!(ClampKernelOp<f32>, f32, P, "clamp_fwd_f32", "clamp_bwd_f32");
 cuda_unary!(ClampKernelOp<f64>, f64, P, "clamp_fwd_f64", "clamp_bwd_f64");
diff --git a/src/tensor_ops/cmp/cuda_kernels.rs b/src/tensor_ops/cmp/cuda_kernels.rs
index 8b31b21c8..31d0d1196 100644
--- a/src/tensor_ops/cmp/cuda_kernels.rs
+++ b/src/tensor_ops/cmp/cuda_kernels.rs
@@ -1,5 +1,6 @@
 use crate::{
-    shapes::{Shape, Unit},
+    dtypes::*,
+    shapes::Shape,
     tensor::{launch_cfg, Cuda, Tensor},
 };
 use cudarc::driver::{CudaSlice, LaunchAsync};
@@ -129,17 +130,30 @@ macro_rules! cmps {
 }
 
 #[cfg(feature = "f16")]
-cmps!(EqKernelOp, half::f16, "eq_fwd_f16", "scalar_eq_fwd_f16");
+cmps!(EqKernelOp, AMP<f16>, "eq_fwd_f16", "scalar_eq_fwd_f16");
 #[cfg(feature = "f16")]
-cmps!(NeKernelOp, half::f16, "ne_fwd_f16", "scalar_ne_fwd_f16");
+cmps!(NeKernelOp, AMP<f16>, "ne_fwd_f16", "scalar_ne_fwd_f16");
 #[cfg(feature = "f16")]
-cmps!(GtKernelOp, half::f16, "gt_fwd_f16", "scalar_gt_fwd_f16");
+cmps!(GtKernelOp, AMP<f16>, "gt_fwd_f16", "scalar_gt_fwd_f16");
 #[cfg(feature = "f16")]
-cmps!(GeKernelOp, half::f16, "ge_fwd_f16", "scalar_ge_fwd_f16");
+cmps!(GeKernelOp, AMP<f16>, "ge_fwd_f16", "scalar_ge_fwd_f16");
 #[cfg(feature = "f16")]
-cmps!(LtKernelOp, half::f16, "lt_fwd_f16", "scalar_lt_fwd_f16");
+cmps!(LtKernelOp, AMP<f16>, "lt_fwd_f16", "scalar_lt_fwd_f16");
 #[cfg(feature = "f16")]
-cmps!(LeKernelOp, half::f16, "le_fwd_f16", "scalar_le_fwd_f16");
+cmps!(LeKernelOp, AMP<f16>, "le_fwd_f16", "scalar_le_fwd_f16");
+
+#[cfg(feature = "f16")]
+cmps!(EqKernelOp, f16, "eq_fwd_f16", "scalar_eq_fwd_f16");
+#[cfg(feature = "f16")]
+cmps!(NeKernelOp, f16, "ne_fwd_f16", "scalar_ne_fwd_f16");
+#[cfg(feature = "f16")]
+cmps!(GtKernelOp, f16, "gt_fwd_f16", "scalar_gt_fwd_f16");
+#[cfg(feature = "f16")]
+cmps!(GeKernelOp, f16, "ge_fwd_f16", "scalar_ge_fwd_f16");
+#[cfg(feature = "f16")]
+cmps!(LtKernelOp, f16, "lt_fwd_f16", "scalar_lt_fwd_f16");
+#[cfg(feature = "f16")]
+cmps!(LeKernelOp, f16, "le_fwd_f16", "scalar_le_fwd_f16");
 
 cmps!(EqKernelOp, f32, "eq_fwd_f32", "scalar_eq_fwd_f32");
 cmps!(NeKernelOp, f32, "ne_fwd_f32", "scalar_ne_fwd_f32");
diff --git a/src/tensor_ops/cmp/mod.rs b/src/tensor_ops/cmp/mod.rs
index 60d4b5af9..b499ef559 100644
--- a/src/tensor_ops/cmp/mod.rs
+++ b/src/tensor_ops/cmp/mod.rs
@@ -239,6 +239,20 @@ macro_rules! impl_cmp_kernel_op {
             }
         }
 
+        #[cfg(feature = "f16")]
+        impl<
+                S: Shape,
+                D: ScalarCmpKernel<$KernelOp, crate::dtypes::AMP<half::f16>>,
+                T: Tape<crate::dtypes::AMP<half::f16>, D>,
+            > $TraitName<f32> for Tensor<S, crate::dtypes::AMP<half::f16>, D, T>
+        {
+            type Output = Tensor<S, bool, D>;
+            #[doc = $doc]
+            fn $TryFnName(&self, other: f32) -> Result<Self::Output, D::Err> {
+                try_scalar_cmp_op(self, crate::dtypes::AMP(half::f16::from_f32(other)))
+            }
+        }
+
         impl<S: Shape, E: Dtype, D: ScalarCmpKernel<$KernelOp, E>, T: Tape<E, D>> Tensor<S, E, D, T> {
             #[doc = $doc]
             #[deprecated = "You can now use the non-scalar method for both tensors & scalars."]
diff --git a/src/tensor_ops/conv2d/cuda_kernel.rs b/src/tensor_ops/conv2d/cuda_kernel.rs
index 759437e2d..4836ed65c 100644
--- a/src/tensor_ops/conv2d/cuda_kernel.rs
+++ b/src/tensor_ops/conv2d/cuda_kernel.rs
@@ -2,6 +2,7 @@ use cudarc::cublas::{CudaBlas, Gemm};
 use cudarc::driver::{DeviceRepr, LaunchAsync, ValidAsZeroBits};
 
 use crate::{
+    dtypes::*,
     shapes::*,
     tensor::{launch_cfg, Cuda, Tensor, Tensorlike},
 };
@@ -18,7 +19,18 @@ trait HasCudaKernel<E> {
 }
 
 #[cfg(feature = "f16")]
-impl HasCudaKernel<half::f16> for Cuda {
+impl HasCudaKernel<AMP<f16>> for Cuda {
+    const MOD: &'static str = "conv2d_f16";
+    const FNS: &'static [&'static str] = &[
+        "unfold_input_into_patches_f16",
+        "unfold_output_into_patches_f16",
+        "transpose_filters_f16",
+        "sum_transposed_filters_f16",
+    ];
+}
+
+#[cfg(feature = "f16")]
+impl HasCudaKernel<f16> for Cuda {
     const MOD: &'static str = "conv2d_f16";
     const FNS: &'static [&'static str] = &[
         "unfold_input_into_patches_f16",
diff --git a/src/tensor_ops/conv2d/cudnn_kernel.rs b/src/tensor_ops/conv2d/cudnn_kernel.rs
index 48972b584..8e614e5eb 100644
--- a/src/tensor_ops/conv2d/cudnn_kernel.rs
+++ b/src/tensor_ops/conv2d/cudnn_kernel.rs
@@ -2,6 +2,7 @@ use cudarc::cudnn::{self, Conv2dBackwardData, Conv2dBackwardFilter, Conv2dForward
 use cudarc::driver::DeviceSlice;
 
 use crate::{
+    dtypes::*,
     shapes::*,
     tensor::{Cuda, Tensor, Tensorlike},
 };
@@ -10,7 +11,9 @@ use std::sync::Arc;
 
 trait HasCudnnKernel<E> {}
 #[cfg(feature = "f16")]
-impl HasCudnnKernel<half::f16> for Cuda {}
+impl HasCudnnKernel<f16> for Cuda {}
+#[cfg(feature = "f16")]
+impl HasCudnnKernel<AMP<f16>> for Cuda {}
 impl HasCudnnKernel<f32> for Cuda {}
 impl HasCudnnKernel<f64> for Cuda {}
diff --git a/src/tensor_ops/convtrans2d/cuda_kernel.rs b/src/tensor_ops/convtrans2d/cuda_kernel.rs
index ba619f23c..6ddcd10c7 100644
--- a/src/tensor_ops/convtrans2d/cuda_kernel.rs
+++ b/src/tensor_ops/convtrans2d/cuda_kernel.rs
@@ -2,6 +2,7 @@ use cudarc::cublas::{CudaBlas, Gemm};
 use cudarc::driver::{DeviceRepr, LaunchAsync, ValidAsZeroBits};
 
 use crate::{
+    dtypes::*,
     shapes::*,
     tensor::{launch_cfg, Cuda, Tensor, Tensorlike},
 };
@@ -18,7 +19,17 @@ trait HasCudaKernel<E> {
 }
 
 #[cfg(feature = "f16")]
-impl HasCudaKernel<half::f16> for Cuda {
+impl HasCudaKernel<AMP<f16>> for Cuda {
+    const MOD: &'static str = "convtrans2d_f16";
+    const FNS: &'static [&'static str] = &[
+        "unfold_input_into_patches_f16",
+        "unfold_output_into_patches_f16",
+        "transpose_filters_f16",
+    ];
+}
+
+#[cfg(feature = "f16")]
+impl HasCudaKernel<f16> for Cuda {
     const MOD: &'static str = "convtrans2d_f16";
     const FNS: &'static [&'static str] = &[
         "unfold_input_into_patches_f16",
diff --git a/src/tensor_ops/cos/cuda_kernel.rs b/src/tensor_ops/cos/cuda_kernel.rs
index 904ba10ef..71b398a16 100644
--- a/src/tensor_ops/cos/cuda_kernel.rs
+++ b/src/tensor_ops/cos/cuda_kernel.rs
@@ -1,13 +1,17 @@
+#[allow(unused_imports)]
+use crate::dtypes::*;
 use crate::tensor_ops::cuda_kernels::cuda_unary;
 
 unsafe impl cudarc::driver::DeviceRepr for super::CosKernelOp {}
 
 const PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/cos.ptx"));
 
+#[cfg(feature = "f16")]
+cuda_unary!(super::CosKernelOp, f16, PTX, "cos_fwd_f16", "cos_bwd_f16");
 #[cfg(feature = "f16")]
 cuda_unary!(
     super::CosKernelOp,
-    half::f16,
+    AMP<f16>,
     PTX,
     "cos_fwd_f16",
     "cos_bwd_f16"
diff --git a/src/tensor_ops/div/cuda_kernel.rs b/src/tensor_ops/div/cuda_kernel.rs
index 48407719c..7984baba1 100644
--- a/src/tensor_ops/div/cuda_kernel.rs
+++ b/src/tensor_ops/div/cuda_kernel.rs
@@ -1,8 +1,12 @@
 use super::{BinaryDivKernelOp as Binary, ScalarDivKernelOp as Scalar};
+#[allow(unused_imports)]
+use crate::dtypes::*;
 use crate::tensor_ops::cuda_kernels::{cuda_binary, cuda_unary};
 
 #[cfg(feature = "f16")]
-unsafe impl cudarc::driver::DeviceRepr for Scalar<half::f16> {}
+unsafe impl cudarc::driver::DeviceRepr for Scalar<f16> {}
+#[cfg(feature = "f16")]
+unsafe impl cudarc::driver::DeviceRepr for Scalar<AMP<f16>> {}
 unsafe impl cudarc::driver::DeviceRepr for Scalar<f32> {}
 unsafe impl cudarc::driver::DeviceRepr for Scalar<f64> {}
 unsafe impl cudarc::driver::DeviceRepr for Binary {}
@@ -11,13 +15,24 @@ const SCALAR_PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/scalar_div.ptx"
 const BINARY_PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/binary_div.ptx"));
 
 #[cfg(feature = "f16")]
-cuda_unary!(const_df() Scalar<half::f16>, half::f16, SCALAR_PTX, "sdiv_fwd_f16", "sdiv_bwd_f16");
+cuda_unary!(const_df() Scalar<f16>, f16, SCALAR_PTX, "sdiv_fwd_f16", "sdiv_bwd_f16");
+#[cfg(feature = "f16")]
+cuda_unary!(const_df() Scalar<AMP<f16>>, AMP<f16>, SCALAR_PTX, "sdiv_fwd_f16", "sdiv_bwd_f16");
 cuda_unary!(const_df() Scalar<f32>, f32, SCALAR_PTX, "sdiv_fwd_f32", "sdiv_bwd_f32");
 cuda_unary!(const_df() Scalar<f64>, f64, SCALAR_PTX, "sdiv_fwd_f64", "sdiv_bwd_f64");
 #[cfg(feature = "f16")]
 cuda_binary!(
     Binary,
-    half::f16,
+    f16,
+    BINARY_PTX,
+    "bdiv_fwd_f16",
+    "bdiv_bwd_lhs_f16",
+    "bdiv_bwd_rhs_f16"
+);
+#[cfg(feature = "f16")]
+cuda_binary!(
+    Binary,
+    AMP<f16>,
     BINARY_PTX,
     "bdiv_fwd_f16",
     "bdiv_bwd_lhs_f16",
diff --git a/src/tensor_ops/div/mod.rs b/src/tensor_ops/div/mod.rs
index e74580921..8eeb75ffa 100644
--- a/src/tensor_ops/div/mod.rs
+++ b/src/tensor_ops/div/mod.rs
@@ -83,6 +83,23 @@ impl<S: Shape, D: UnaryKernel<ScalarDivKernelOp<half::f16>, half::f16>, T: Tape<
     }
 }
 
+#[cfg(feature = "f16")]
+impl<
+        S: Shape,
+        D: UnaryKernel<
+            ScalarDivKernelOp<crate::dtypes::AMP<half::f16>>,
+            crate::dtypes::AMP<half::f16>,
+        >,
+        T: Tape<crate::dtypes::AMP<half::f16>, D>,
+    > TryDiv<f32> for Tensor<S, crate::dtypes::AMP<half::f16>, D, T>
+{
+    /// See [div]
+    fn try_div(self, rhs: f32) -> Result<Self, Self::Err> {
+        let scalar = crate::dtypes::AMP(half::f16::from_f32(rhs));
+        try_unary_op(ScalarDivKernelOp { scalar }, self)
+    }
+}
+
 impl<S: Shape, E: Dtype, D: Storage<E>, LhsTape: Tape<E, D>, Rhs> std::ops::Div<Rhs>
     for Tensor<S, E, D, LhsTape>
 where
diff --git a/src/tensor_ops/dropout/cuda_kernel.rs b/src/tensor_ops/dropout/cuda_kernel.rs
index ca848b3dd..fe66e508e 100644
--- a/src/tensor_ops/dropout/cuda_kernel.rs
+++ b/src/tensor_ops/dropout/cuda_kernel.rs
@@ -1,4 +1,5 @@
 use crate::{
+    dtypes::*,
     shapes::*,
     tensor::{launch_cfg, Cuda, Tensor},
 };
@@ -18,7 +19,13 @@ trait HasCudaKernel<E> {
 }
 
 #[cfg(feature = "f16")]
-impl HasCudaKernel<half::f16> for Cuda {
+impl HasCudaKernel<f16> for Cuda {
+    const MOD: &'static str = "dropout_f16";
+    const FNS: &'static [&'static str] = &["dropout_fwd_f16", "dropout_bwd_f16"];
+}
+
+#[cfg(feature = "f16")]
+impl HasCudaKernel<AMP<f16>> for Cuda {
     const MOD: &'static str = "dropout_f16";
     const FNS: &'static [&'static str] = &["dropout_fwd_f16", "dropout_bwd_f16"];
 }
diff --git a/src/tensor_ops/exp/cuda_kernel.rs b/src/tensor_ops/exp/cuda_kernel.rs
index c2082962f..89f3cb0ee 100644
--- a/src/tensor_ops/exp/cuda_kernel.rs
+++ b/src/tensor_ops/exp/cuda_kernel.rs
@@ -1,3 +1,5 @@
+#[allow(unused_imports)]
+use crate::dtypes::*;
 use crate::tensor_ops::cuda_kernels::cuda_unary;
 
 unsafe impl cudarc::driver::DeviceRepr for super::ExpKernelOp {}
@@ -6,5 +8,7 @@ const PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/exp.ptx"));
 
 #[cfg(feature = "f16")]
 cuda_unary!(df(f(x)) super::ExpKernelOp, half::f16, PTX, "exp_fwd_f16", "exp_bwd_f16");
+#[cfg(feature = "f16")]
+cuda_unary!(df(f(x)) super::ExpKernelOp, AMP<f16>, PTX, "exp_fwd_f16", "exp_bwd_f16");
 cuda_unary!(df(f(x)) super::ExpKernelOp, f32, PTX, "exp_fwd_f32", "exp_bwd_f32");
 cuda_unary!(df(f(x)) super::ExpKernelOp, f64, PTX, "exp_fwd_f64", "exp_bwd_f64");
diff --git a/src/tensor_ops/fast_gelu/cuda_kernel.rs b/src/tensor_ops/fast_gelu/cuda_kernel.rs
index 3e5107261..f836c6965 100644
--- a/src/tensor_ops/fast_gelu/cuda_kernel.rs
+++ b/src/tensor_ops/fast_gelu/cuda_kernel.rs
@@ -1,4 +1,6 @@
 use super::FastGeLUKernelOp;
+#[allow(unused_imports)]
+use crate::dtypes::*;
 use crate::tensor_ops::cuda_kernels::cuda_unary;
 
 unsafe impl cudarc::driver::DeviceRepr for super::FastGeLUKernelOp {}
@@ -8,7 +10,15 @@ const PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/fast_gelu.ptx"));
 #[cfg(feature = "f16")]
 cuda_unary!(
     FastGeLUKernelOp,
-    half::f16,
+    f16,
+    PTX,
+    "fast_gelu_fwd_f16",
+    "fast_gelu_bwd_f16"
+);
+#[cfg(feature = "f16")]
+cuda_unary!(
+    FastGeLUKernelOp,
+    AMP<f16>,
     PTX,
     "fast_gelu_fwd_f16",
     "fast_gelu_bwd_f16"
diff --git a/src/tensor_ops/huber_error/cuda_kernel.rs b/src/tensor_ops/huber_error/cuda_kernel.rs
index 6c936c2be..466ddb947 100644
--- a/src/tensor_ops/huber_error/cuda_kernel.rs
+++ b/src/tensor_ops/huber_error/cuda_kernel.rs
@@ -1,8 +1,12 @@
 use super::HuberErrorKernelOp as HuberError;
+#[allow(unused_imports)]
+use crate::dtypes::*;
 use crate::tensor_ops::cuda_kernels::cuda_binary;
 
 #[cfg(feature = "f16")]
-unsafe impl cudarc::driver::DeviceRepr for HuberError<half::f16> {}
+unsafe impl cudarc::driver::DeviceRepr for HuberError<f16> {}
+#[cfg(feature = "f16")]
+unsafe impl cudarc::driver::DeviceRepr for HuberError<AMP<f16>> {}
 unsafe impl cudarc::driver::DeviceRepr for HuberError<f32> {}
 unsafe impl cudarc::driver::DeviceRepr for HuberError<f64> {}
 
@@ -10,8 +14,17 @@ const PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/huber_error.ptx"));
 
 #[cfg(feature = "f16")]
 cuda_binary!(
-    HuberError<half::f16>,
-    half::f16,
+    HuberError<f16>,
+    f16,
+    PTX,
+    "huber_fwd_f16",
+    "huber_bwd_lhs_f16",
+    "huber_bwd_rhs_f16"
+);
+#[cfg(feature = "f16")]
+cuda_binary!(
+    HuberError<AMP<f16>>,
+    AMP<f16>,
     PTX,
     "huber_fwd_f16",
     "huber_bwd_lhs_f16",
diff --git a/src/tensor_ops/ln/cuda_kernel.rs b/src/tensor_ops/ln/cuda_kernel.rs
index 33a15186b..a78fbcce6 100644
--- a/src/tensor_ops/ln/cuda_kernel.rs
+++ b/src/tensor_ops/ln/cuda_kernel.rs
@@ -1,13 +1,17 @@
+#[allow(unused_imports)]
+use crate::dtypes::*;
 use crate::tensor_ops::cuda_kernels::cuda_unary;
 
 unsafe impl cudarc::driver::DeviceRepr for super::LnKernelOp {}
 
 const PTX_SRC: &str = include_str!(concat!(env!("OUT_DIR"), "/ln.ptx"));
 
+#[cfg(feature = "f16")]
+cuda_unary!(super::LnKernelOp, f16, PTX_SRC, "ln_fwd_f16", "ln_bwd_f16");
 #[cfg(feature = "f16")]
 cuda_unary!(
     super::LnKernelOp,
-    half::f16,
+    AMP<f16>,
     PTX_SRC,
     "ln_fwd_f16",
     "ln_bwd_f16"
diff --git a/src/tensor_ops/matmul/cpu_kernel.rs b/src/tensor_ops/matmul/cpu_kernel.rs
index ca205093a..9e9b497ae 100644
--- a/src/tensor_ops/matmul/cpu_kernel.rs
+++ b/src/tensor_ops/matmul/cpu_kernel.rs
@@ -49,6 +49,53 @@ pub(crate) trait MatMulImpl<E> {
     );
 }
 
+#[cfg(feature = "f16")]
+impl MatMulImpl<crate::dtypes::AMP<half::f16>> for Cpu {
+    #[inline]
+    fn matmul<M: Dim, K: Dim, N: Dim>(
+        (m, k, n): (M, K, N),
+        accum: bool,
+        ap: *const crate::dtypes::AMP<half::f16>,
+        astr: [usize; 2],
+        bp: *const crate::dtypes::AMP<half::f16>,
+        bstr: [usize; 2],
+        cp: *mut crate::dtypes::AMP<half::f16>,
+        cstr: [usize; 2],
+    ) {
+        #[cfg(not(feature = "cpu"))]
+        naive_gemm((m, k, n), accum, ap, astr, bp, bstr, cp, cstr);
+
+        #[cfg(feature = "cpu")]
+        unsafe {
+            gemm::gemm(
+                m.size(),
+                n.size(),
+                k.size(),
+                cp as *mut gemm::f16,
+                cstr[1] as isize,
+                cstr[0] as isize,
+                accum,
+                ap as *const gemm::f16,
+                astr[1] as isize,
+                astr[0] as isize,
+                bp as *const gemm::f16,
+                bstr[1] as isize,
+                bstr[0] as isize,
+                if accum {
+                    gemm::f16::ONE
+                } else {
+                    gemm::f16::ZERO
+                },
+                gemm::f16::ONE,
+                false,
+                false,
+                false,
+                gemm::Parallelism::Rayon(rayon::current_num_threads()),
+            )
+        }
+    }
+}
+
 #[cfg(feature = "f16")]
 impl MatMulImpl<half::f16> for Cpu {
     #[inline]
diff --git a/src/tensor_ops/matmul/cuda_kernel.rs b/src/tensor_ops/matmul/cuda_kernel.rs
index 7c1ffac37..b6787d848 100644
--- a/src/tensor_ops/matmul/cuda_kernel.rs
+++ b/src/tensor_ops/matmul/cuda_kernel.rs
@@ -1,4 +1,5 @@
 use crate::{
+    dtypes::*,
     shapes::*,
     tensor::{cuda::Cuda, Tensor},
 };
@@ -58,6 +59,81 @@ fn gemm_cfg(
     }
 }
 
+#[cfg(feature = "f16")]
+impl Gemm<AMP<f16>> for CudaBlas {
+    unsafe fn gemm<A: DevicePtr<AMP<f16>>, B: DevicePtr<AMP<f16>>, C: DevicePtrMut<AMP<f16>>>(
+        &self,
+        cfg: GemmConfig<AMP<f16>>,
+        a: &A,
+        b: &B,
+        c: &mut C,
+    ) -> Result<(), CublasError> {
+        let alpha: f32 = cfg.alpha.0.to_f32();
+        let beta: f32 = cfg.beta.0.to_f32();
+        cudarc::cublas::result::gemm_ex(
+            *self.handle(),
+            cfg.transa,
+            cfg.transb,
+            cfg.m,
+            cfg.n,
+            cfg.k,
+            (&alpha) as *const f32 as *const _,
+            *a.device_ptr() as *const _,
+            cudarc::cublas::sys::cudaDataType_t::CUDA_R_16F,
+            cfg.lda,
+            *b.device_ptr() as *const _,
+            cudarc::cublas::sys::cudaDataType_t::CUDA_R_16F,
+            cfg.ldb,
+            (&beta) as *const f32 as *const _,
+            *c.device_ptr_mut() as *mut _,
+            cudarc::cublas::sys::cudaDataType_t::CUDA_R_16F,
+            cfg.ldc,
+            cudarc::cublas::sys::cublasComputeType_t::CUBLAS_COMPUTE_32F,
+            cudarc::cublas::sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT,
+        )
+    }
+
+    unsafe fn gemm_strided_batched<
+        A: DevicePtr<AMP<f16>>,
+        B: DevicePtr<AMP<f16>>,
+        C: DevicePtrMut<AMP<f16>>,
+    >(
+        &self,
+        cfg: StridedBatchedConfig<AMP<f16>>,
+        a: &A,
+        b: &B,
+        c: &mut C,
+    ) -> Result<(), CublasError> {
+        let alpha: f32 = cfg.gemm.alpha.0.to_f32();
+        let beta: f32 = cfg.gemm.beta.0.to_f32();
+        cudarc::cublas::result::gemm_strided_batched_ex(
+            *self.handle(),
+            cfg.gemm.transa,
+            cfg.gemm.transb,
+            cfg.gemm.m,
+            cfg.gemm.n,
+            cfg.gemm.k,
+            (&alpha) as *const f32 as *const _,
+            *a.device_ptr() as *const _,
+            cudarc::cublas::sys::cudaDataType_t::CUDA_R_16F,
+            cfg.gemm.lda,
+            cfg.stride_a,
+            *b.device_ptr() as *const _,
+            cudarc::cublas::sys::cudaDataType_t::CUDA_R_16F,
+            cfg.gemm.ldb,
+            cfg.stride_b,
+            (&beta) as *const f32 as *const _,
+            *c.device_ptr_mut() as *mut _,
+            cudarc::cublas::sys::cudaDataType_t::CUDA_R_16F,
+            cfg.gemm.ldc,
+            cfg.stride_c,
+            cfg.batch_size,
+            cudarc::cublas::sys::cublasComputeType_t::CUBLAS_COMPUTE_32F,
+            cudarc::cublas::sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT,
+        )
+    }
+}
+
 impl Cuda {
     /// sgemm helper.
     ///
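[Context, not part of the patch: the `Gemm<AMP<f16>>` impl above asks cuBLAS for f16 storage (`CUDA_R_16F`) with f32 accumulation (`CUBLAS_COMPUTE_32F`). A CPU model of what that means for a single output element, with a hypothetical helper:]

    use half::f16;

    // Mixed-precision dot product: inputs and output are f16, but the
    // running sum is kept in f32, so long reductions don't lose precision.
    fn dot_f16_storage_f32_accum(a: &[f16], b: &[f16]) -> f16 {
        let mut acc = 0.0f32; // the f32 accumulator, like CUBLAS_COMPUTE_32F
        for (x, y) in a.iter().zip(b.iter()) {
            acc += x.to_f32() * y.to_f32();
        }
        f16::from_f32(acc) // only the final result is rounded back to f16
    }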
/// diff --git a/src/tensor_ops/max_to/cuda_kernel.rs b/src/tensor_ops/max_to/cuda_kernel.rs index 09c931392..31e384a6a 100644 --- a/src/tensor_ops/max_to/cuda_kernel.rs +++ b/src/tensor_ops/max_to/cuda_kernel.rs @@ -1,4 +1,5 @@ use crate::{ + dtypes::*, shapes::*, tensor::{launch_cfg, Cuda, Tensor}, tensor_ops::reduction_utils::*, @@ -17,8 +18,15 @@ trait HasCudaKernel { } #[cfg(feature = "f16")] -impl HasCudaKernel for Cuda { - const INIT: half::f16 = half::f16::NEG_INFINITY; +impl HasCudaKernel for Cuda { + const INIT: f16 = f16::NEG_INFINITY; + const MOD: &'static str = "max_f16"; + const FNS: &'static [&'static str] = &["max_to_fwd_f16", "max_to_bwd_f16", "fill_with_f16"]; +} + +#[cfg(feature = "f16")] +impl HasCudaKernel> for Cuda { + const INIT: AMP = AMP::::NEG_INFINITY; const MOD: &'static str = "max_f16"; const FNS: &'static [&'static str] = &["max_to_fwd_f16", "max_to_bwd_f16", "fill_with_f16"]; } diff --git a/src/tensor_ops/maximum/cuda_kernel.rs b/src/tensor_ops/maximum/cuda_kernel.rs index 7e71e6b62..9a5f56ea6 100644 --- a/src/tensor_ops/maximum/cuda_kernel.rs +++ b/src/tensor_ops/maximum/cuda_kernel.rs @@ -1,4 +1,6 @@ use super::MaximumKernelOp as Max; +#[allow(unused_imports)] +use crate::dtypes::*; use crate::tensor_ops::cuda_kernels::cuda_binary; unsafe impl cudarc::driver::DeviceRepr for Max {} @@ -8,7 +10,16 @@ const PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/maximum.ptx")); #[cfg(feature = "f16")] cuda_binary!( Max, - half::f16, + f16, + PTX, + "maximum_fwd_f16", + "maximum_bwd_lhs_f16", + "maximum_bwd_rhs_f16" +); +#[cfg(feature = "f16")] +cuda_binary!( + Max, + AMP, PTX, "maximum_fwd_f16", "maximum_bwd_lhs_f16", diff --git a/src/tensor_ops/min_to/cuda_kernel.rs b/src/tensor_ops/min_to/cuda_kernel.rs index b8d23a399..8db3c1b5a 100644 --- a/src/tensor_ops/min_to/cuda_kernel.rs +++ b/src/tensor_ops/min_to/cuda_kernel.rs @@ -1,4 +1,5 @@ use crate::{ + dtypes::*, shapes::*, tensor::{launch_cfg, Cuda, Tensor}, tensor_ops::reduction_utils::*, @@ -17,8 +18,15 @@ trait HasCudaKernel { } #[cfg(feature = "f16")] -impl HasCudaKernel for Cuda { - const INIT: half::f16 = half::f16::INFINITY; +impl HasCudaKernel for Cuda { + const INIT: f16 = f16::INFINITY; + const MOD: &'static str = "min_f16"; + const FNS: &'static [&'static str] = &["min_to_fwd_f16", "min_to_bwd_f16", "fill_with_f16"]; +} + +#[cfg(feature = "f16")] +impl HasCudaKernel> for Cuda { + const INIT: AMP = AMP::::INFINITY; const MOD: &'static str = "min_f16"; const FNS: &'static [&'static str] = &["min_to_fwd_f16", "min_to_bwd_f16", "fill_with_f16"]; } diff --git a/src/tensor_ops/minimum/cuda_kernel.rs b/src/tensor_ops/minimum/cuda_kernel.rs index deb9a8f70..eba2b8d9b 100644 --- a/src/tensor_ops/minimum/cuda_kernel.rs +++ b/src/tensor_ops/minimum/cuda_kernel.rs @@ -1,4 +1,6 @@ use super::MinimumKernelOp as Min; +#[allow(unused_imports)] +use crate::dtypes::*; use crate::tensor_ops::cuda_kernels::cuda_binary; unsafe impl cudarc::driver::DeviceRepr for super::MinimumKernelOp {} @@ -8,7 +10,16 @@ const PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/minimum.ptx")); #[cfg(feature = "f16")] cuda_binary!( Min, - half::f16, + f16, + PTX, + "minimum_fwd_f16", + "minimum_bwd_lhs_f16", + "minimum_bwd_rhs_f16" +); +#[cfg(feature = "f16")] +cuda_binary!( + Min, + AMP, PTX, "minimum_fwd_f16", "minimum_bwd_lhs_f16", diff --git a/src/tensor_ops/mul/cuda_kernel.rs b/src/tensor_ops/mul/cuda_kernel.rs index 9eca6a4fe..cf20072b8 100644 --- a/src/tensor_ops/mul/cuda_kernel.rs +++ b/src/tensor_ops/mul/cuda_kernel.rs @@ -1,8 
+1,12 @@ use super::{BinaryMulKernelOp as Binary, ScalarMulKernelOp as Scalar}; +#[allow(unused_imports)] +use crate::dtypes::*; use crate::tensor_ops::cuda_kernels::{cuda_binary, cuda_unary}; #[cfg(feature = "f16")] -unsafe impl cudarc::driver::DeviceRepr for Scalar {} +unsafe impl cudarc::driver::DeviceRepr for Scalar {} +#[cfg(feature = "f16")] +unsafe impl cudarc::driver::DeviceRepr for Scalar> {} unsafe impl cudarc::driver::DeviceRepr for Scalar {} unsafe impl cudarc::driver::DeviceRepr for Scalar {} unsafe impl cudarc::driver::DeviceRepr for Binary {} @@ -11,13 +15,24 @@ const SCALAR_PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/scalar_mul.ptx" const BINARY_PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/binary_mul.ptx")); #[cfg(feature = "f16")] -cuda_unary!(const_df() Scalar, half::f16, SCALAR_PTX, "smul_fwd_f16", "smul_bwd_f16"); +cuda_unary!(const_df() Scalar, f16, SCALAR_PTX, "smul_fwd_f16", "smul_bwd_f16"); +#[cfg(feature = "f16")] +cuda_unary!(const_df() Scalar>, AMP, SCALAR_PTX, "smul_fwd_f16", "smul_bwd_f16"); cuda_unary!(const_df() Scalar, f32, SCALAR_PTX, "smul_fwd_f32", "smul_bwd_f32"); cuda_unary!(const_df() Scalar, f64, SCALAR_PTX, "smul_fwd_f64", "smul_bwd_f64"); #[cfg(feature = "f16")] cuda_binary!( Binary, - half::f16, + f16, + BINARY_PTX, + "bmul_fwd_f16", + "bmul_bwd_lhs_f16", + "bmul_bwd_rhs_f16" +); +#[cfg(feature = "f16")] +cuda_binary!( + Binary, + AMP, BINARY_PTX, "bmul_fwd_f16", "bmul_bwd_lhs_f16", diff --git a/src/tensor_ops/mul/mod.rs b/src/tensor_ops/mul/mod.rs index 28cba5f98..350e86b67 100644 --- a/src/tensor_ops/mul/mod.rs +++ b/src/tensor_ops/mul/mod.rs @@ -78,6 +78,22 @@ impl, half::f16>, T: Tape< } } +#[cfg(feature = "f16")] +impl< + S: Shape, + D: UnaryKernel< + ScalarMulKernelOp>, + crate::dtypes::AMP, + >, + T: Tape, D>, + > TryMul for Tensor, D, T> +{ + fn try_mul(self, rhs: f32) -> Result { + let scalar = crate::dtypes::AMP(half::f16::from_f32(rhs)); + try_unary_op(ScalarMulKernelOp { scalar }, self) + } +} + impl, LhsTape: Tape, Rhs> std::ops::Mul for Tensor where diff --git a/src/tensor_ops/nans_to/cuda_kernel.rs b/src/tensor_ops/nans_to/cuda_kernel.rs index dae060434..af141a510 100644 --- a/src/tensor_ops/nans_to/cuda_kernel.rs +++ b/src/tensor_ops/nans_to/cuda_kernel.rs @@ -1,17 +1,23 @@ use super::NansToKernelOp as NansTo; +#[allow(unused_imports)] +use crate::dtypes::*; use crate::tensor_ops::cuda_kernels::cuda_unary; #[cfg(feature = "f16")] -unsafe impl cudarc::driver::DeviceRepr for NansTo {} +unsafe impl cudarc::driver::DeviceRepr for NansTo {} +#[cfg(feature = "f16")] +unsafe impl cudarc::driver::DeviceRepr for NansTo> {} unsafe impl cudarc::driver::DeviceRepr for NansTo {} unsafe impl cudarc::driver::DeviceRepr for NansTo {} const PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/nans_to.ptx")); +#[cfg(feature = "f16")] +cuda_unary!(NansTo, f16, PTX, "nans_to_fwd_f16", "nans_to_bwd_f16"); #[cfg(feature = "f16")] cuda_unary!( - NansTo, - half::f16, + NansTo>, + AMP, PTX, "nans_to_fwd_f16", "nans_to_bwd_f16" diff --git a/src/tensor_ops/negate/cuda_kernel.rs b/src/tensor_ops/negate/cuda_kernel.rs index a6065e555..456e06854 100644 --- a/src/tensor_ops/negate/cuda_kernel.rs +++ b/src/tensor_ops/negate/cuda_kernel.rs @@ -1,4 +1,6 @@ use super::NegateKernelOp; +#[allow(unused_imports)] +use crate::dtypes::*; use crate::tensor_ops::cuda_kernels::cuda_unary; unsafe impl cudarc::driver::DeviceRepr for NegateKernelOp {} @@ -6,6 +8,8 @@ unsafe impl cudarc::driver::DeviceRepr for NegateKernelOp {} const PTX: &str = 
diff --git a/src/tensor_ops/negate/cuda_kernel.rs b/src/tensor_ops/negate/cuda_kernel.rs
index a6065e555..456e06854 100644
--- a/src/tensor_ops/negate/cuda_kernel.rs
+++ b/src/tensor_ops/negate/cuda_kernel.rs
@@ -1,4 +1,6 @@
 use super::NegateKernelOp;
+#[allow(unused_imports)]
+use crate::dtypes::*;
 use crate::tensor_ops::cuda_kernels::cuda_unary;
 
 unsafe impl cudarc::driver::DeviceRepr for NegateKernelOp {}
@@ -6,6 +8,8 @@ unsafe impl cudarc::driver::DeviceRepr for NegateKernelOp {}
 const PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/negate.ptx"));
 
 #[cfg(feature = "f16")]
-cuda_unary!(const_df() NegateKernelOp, half::f16, PTX, "negate_fwd_f16", "negate_bwd_f16");
+cuda_unary!(const_df() NegateKernelOp, f16, PTX, "negate_fwd_f16", "negate_bwd_f16");
+#[cfg(feature = "f16")]
+cuda_unary!(const_df() NegateKernelOp, AMP<f16>, PTX, "negate_fwd_f16", "negate_bwd_f16");
 cuda_unary!(const_df() NegateKernelOp, f32, PTX, "negate_fwd_f32", "negate_bwd_f32");
 cuda_unary!(const_df() NegateKernelOp, f64, PTX, "negate_fwd_f64", "negate_bwd_f64");
diff --git a/src/tensor_ops/pool2d/cuda_kernel.rs b/src/tensor_ops/pool2d/cuda_kernel.rs
index 86527d9b3..720a516a0 100644
--- a/src/tensor_ops/pool2d/cuda_kernel.rs
+++ b/src/tensor_ops/pool2d/cuda_kernel.rs
@@ -1,4 +1,5 @@
 use crate::{
+    dtypes::*,
     shapes::*,
     tensor::{launch_cfg, Cuda, Tensor},
 };
@@ -25,7 +26,13 @@ trait HasCudaKernel<E> {
 }
 
 #[cfg(feature = "f16")]
-impl HasCudaKernel<half::f16> for Cuda {
+impl HasCudaKernel<f16> for Cuda {
+    const FWD: &'static str = "pool2d_fwd_f16";
+    const BWD: &'static str = "pool2d_bwd_f16";
+}
+
+#[cfg(feature = "f16")]
+impl HasCudaKernel<AMP<f16>> for Cuda {
     const FWD: &'static str = "pool2d_fwd_f16";
     const BWD: &'static str = "pool2d_bwd_f16";
 }
diff --git a/src/tensor_ops/pow/cuda_kernel.rs b/src/tensor_ops/pow/cuda_kernel.rs
index 97f3254c2..e698e2f4d 100644
--- a/src/tensor_ops/pow/cuda_kernel.rs
+++ b/src/tensor_ops/pow/cuda_kernel.rs
@@ -1,5 +1,6 @@
 use super::PowfKernelOp;
 use crate::{
+    dtypes::*,
     shapes::*,
     tensor::*,
     tensor_ops::{cuda_kernels::cuda_unary, ops::UnaryKernel},
@@ -7,16 +8,20 @@ use crate::{
 use std::borrow::Cow;
 
 #[cfg(feature = "f16")]
-unsafe impl cudarc::driver::DeviceRepr for super::PowfKernelOp<half::f16> {}
+unsafe impl cudarc::driver::DeviceRepr for super::PowfKernelOp<f16> {}
+#[cfg(feature = "f16")]
+unsafe impl cudarc::driver::DeviceRepr for super::PowfKernelOp<AMP<f16>> {}
 unsafe impl cudarc::driver::DeviceRepr for super::PowfKernelOp<f32> {}
 unsafe impl cudarc::driver::DeviceRepr for super::PowfKernelOp<f64> {}
 
 const PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/pow.ptx"));
 
+#[cfg(feature = "f16")]
+cuda_unary!(PowfKernelOp<f16>, f16, PTX, "pow_fwd_f16", "pow_bwd_f16");
 #[cfg(feature = "f16")]
 cuda_unary!(
-    PowfKernelOp,
-    half::f16,
+    PowfKernelOp<AMP<f16>>,
+    AMP<f16>,
     PTX,
     "pow_fwd_f16",
     "pow_bwd_f16"
diff --git a/src/tensor_ops/recip/cuda_kernel.rs b/src/tensor_ops/recip/cuda_kernel.rs
index 145fc0eae..864a789db 100644
--- a/src/tensor_ops/recip/cuda_kernel.rs
+++ b/src/tensor_ops/recip/cuda_kernel.rs
@@ -1,4 +1,6 @@
 use super::RecipKernelOp;
+#[allow(unused_imports)]
+use crate::dtypes::*;
 use crate::tensor_ops::cuda_kernels::cuda_unary;
 
 unsafe impl cudarc::driver::DeviceRepr for RecipKernelOp {}
@@ -6,6 +8,8 @@ unsafe impl cudarc::driver::DeviceRepr for RecipKernelOp {}
 const PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/recip.ptx"));
 
 #[cfg(feature = "f16")]
-cuda_unary!(df(f(x)) RecipKernelOp, half::f16, PTX, "recip_fwd_f16", "recip_bwd_f16");
+cuda_unary!(df(f(x)) RecipKernelOp, f16, PTX, "recip_fwd_f16", "recip_bwd_f16");
+#[cfg(feature = "f16")]
+cuda_unary!(df(f(x)) RecipKernelOp, AMP<f16>, PTX, "recip_fwd_f16", "recip_bwd_f16");
 cuda_unary!(df(f(x)) RecipKernelOp, f32, PTX, "recip_fwd_f32", "recip_bwd_f32");
 cuda_unary!(df(f(x)) RecipKernelOp, f64, PTX, "recip_fwd_f64", "recip_bwd_f64");
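
recip uses the df(f(x)) flavor of cuda_unary!, meaning the derivative is computed from the op's output rather than its input, so the input buffer does not need to be kept alive for the backward pass. A scalar illustration of why that flavor fits 1/x (function names here are illustrative, not the macro's expansion):

    // For y = 1/x, dy/dx = -1/x^2 = -y * y: the backward pass needs only y.
    fn recip_fwd(x: f32) -> f32 {
        1.0 / x
    }

    fn recip_bwd_from_output(y: f32, grad_y: f32) -> f32 {
        -y * y * grad_y
    }

    fn main() {
        let y = recip_fwd(4.0);
        // With an upstream gradient of 1.0, expect d(1/x)/dx = -1/16.
        assert_eq!(recip_bwd_from_output(y, 1.0), -0.0625);
    }
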
diff --git a/src/tensor_ops/relu/cuda_kernel.rs b/src/tensor_ops/relu/cuda_kernel.rs
index 13a6fc80c..b4f1913f3 100644
--- a/src/tensor_ops/relu/cuda_kernel.rs
+++ b/src/tensor_ops/relu/cuda_kernel.rs
@@ -1,4 +1,6 @@
 use super::ReLUKernelOp;
+#[allow(unused_imports)]
+use crate::dtypes::*;
 use crate::tensor_ops::cuda_kernels::cuda_unary;
 
 unsafe impl cudarc::driver::DeviceRepr for ReLUKernelOp {}
@@ -6,6 +8,8 @@ unsafe impl cudarc::driver::DeviceRepr for ReLUKernelOp {}
 const PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/relu.ptx"));
 
 #[cfg(feature = "f16")]
-cuda_unary!(ReLUKernelOp, half::f16, PTX, "relu_fwd_f16", "relu_bwd_f16");
+cuda_unary!(ReLUKernelOp, f16, PTX, "relu_fwd_f16", "relu_bwd_f16");
+#[cfg(feature = "f16")]
+cuda_unary!(ReLUKernelOp, AMP<f16>, PTX, "relu_fwd_f16", "relu_bwd_f16");
 cuda_unary!(ReLUKernelOp, f32, PTX, "relu_fwd_f32", "relu_bwd_f32");
 cuda_unary!(ReLUKernelOp, f64, PTX, "relu_fwd_f64", "relu_bwd_f64");
diff --git a/src/tensor_ops/rmsprop/cpu_kernel.rs b/src/tensor_ops/rmsprop/cpu_kernel.rs
index 34c276d98..f7408651c 100644
--- a/src/tensor_ops/rmsprop/cpu_kernel.rs
+++ b/src/tensor_ops/rmsprop/cpu_kernel.rs
@@ -1,8 +1,77 @@
-use crate::{shapes::Dtype, tensor::cpu::Cpu};
+use crate::{
+    dtypes::{Dtype, NotMixedPrecision},
+    tensor::cpu::Cpu,
+};
 
 use super::{RMSpropConfig, RMSpropKernel, WeightDecay};
 
-impl RMSpropKernel for Cpu {
+#[cfg(feature = "f16")]
+impl RMSpropKernel<crate::dtypes::AMP<crate::dtypes::f16>> for Cpu {
+    fn rmsprop_kernel(
+        &self,
+        cfg: &RMSpropConfig,
+        param: &mut Self::Vec,
+        momentum: &mut Self::Vec,
+        square_avg: &mut Self::Vec,
+        grad_avg: &mut Self::Vec,
+        grad: &Self::Vec,
+    ) -> Result<(), Self::Err> {
+        let alpha = cfg.alpha as f32;
+        let eps = cfg.eps as f32;
+        let lr = cfg.lr as f32;
+
+        for ((p, g), (s_avg, (g_avg, m))) in param.iter_mut().zip(grad.iter().cloned()).zip(
+            square_avg
+                .iter_mut()
+                .zip(grad_avg.iter_mut().zip(momentum.iter_mut())),
+        ) {
+            let p_f32 = p.0.to_f32();
+            let mut g_f32 = g.0.to_f32();
+            let mut s_avg_f32 = s_avg.0.to_f32();
+            let mut g_avg_f32 = g_avg.0.to_f32();
+            let mut m_f32 = m.0.to_f32();
+
+            if let Some(WeightDecay::L2(wd)) = cfg.weight_decay {
+                g_f32 += wd as f32 * p_f32;
+            }
+
+            // sa = a * sa + (1 - a) * g^2
+            s_avg_f32 += (1.0 - alpha) * (g_f32 * g_f32 - s_avg_f32);
+
+            let avg = if cfg.centered {
+                // ga = a * ga + (1 - a) * g
+                g_avg_f32 += (1.0 - alpha) * (g_f32 - g_avg_f32);
+                // NOTE: eps in sqrt
+                (s_avg_f32 - g_avg_f32.powi(2) + eps).sqrt()
+            } else {
+                // NOTE: eps in sqrt
+                (s_avg_f32 + eps).sqrt()
+            };
+
+            g_f32 /= avg;
+
+            match cfg.momentum {
+                Some(u) => {
+                    m_f32 = m_f32 * (u as f32) + g_f32;
+                    g_f32 = m_f32 * lr;
+                }
+                None => g_f32 *= lr,
+            }
+
+            if let Some(WeightDecay::Decoupled(wd)) = cfg.weight_decay {
+                g_f32 += (wd * cfg.lr) as f32 * p_f32;
+            }
+
+            p.0 = crate::dtypes::f16::from_f32(p_f32 - g_f32);
+            s_avg.0 = crate::dtypes::f16::from_f32(s_avg_f32);
+            g_avg.0 = crate::dtypes::f16::from_f32(g_avg_f32);
+            m.0 = crate::dtypes::f16::from_f32(m_f32);
+        }
+        Ok(())
+    }
+}
+
+impl<E: Dtype + NotMixedPrecision> RMSpropKernel<E> for Cpu {
     fn rmsprop_kernel(
         &self,
         cfg: &RMSpropConfig,
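
This kernel is the heart of the AMP story: parameters and optimizer state are stored as f16, but every intermediate (g*g, the running averages, the step itself) is computed in f32 and rounded back only at the end. One concrete failure mode this avoids, shown with the half crate (the gradient value is illustrative):

    use half::f16;

    fn main() {
        let g = 300.0f32;

        // Squaring a moderately large gradient directly in f16 overflows,
        // since g * g = 90_000 exceeds f16's max finite value of 65_504.
        let g_h = f16::from_f32(g);
        assert!((g_h * g_h).is_infinite());

        // The AMP kernel squares in f32, so the running-average update stays
        // finite and its result still fits back into f16 storage (alpha = 0.9).
        let s_avg = (1.0f32 - 0.9) * (g * g);
        assert!(s_avg.is_finite() && (s_avg - 9_000.0).abs() < 1.0);
        assert!(f16::from_f32(s_avg).is_finite());
    }
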
"rmsprop_update_amp_f16"; +} + impl HasCudaKernel for Cuda { const MOD: &'static str = "rmsprop_f32"; const FWD: &'static str = "rmsprop_update_f32"; diff --git a/src/tensor_ops/rmsprop/rmsprop.cu b/src/tensor_ops/rmsprop/rmsprop.cu index acdfe06ef..710191acb 100644 --- a/src/tensor_ops/rmsprop/rmsprop.cu +++ b/src/tensor_ops/rmsprop/rmsprop.cu @@ -94,3 +94,63 @@ extern "C" __global__ void FN( \ RMSPROP(__half, rmsprop_update_f16); RMSPROP(float, rmsprop_update_f32); RMSPROP(double, rmsprop_update_f64); + + +extern "C" __global__ void rmsprop_update_amp_f16( + const RMSpropConfig cfg, + const size_t numel, + __half* param, + __half* momentum, + __half* square_avg, + __half* grad_avg, + const __half* grad +) { + float lr = cfg.lr; + float alpha = cfg.alpha; + float eps = cfg.eps; + float momentum_ = cfg.momentum; + float weight_decay = cfg.weight_decay; + float one = 1.0; + + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { + float p = param[i]; + float g = grad[i]; + float s_avg = square_avg[i]; + float g_avg = grad_avg[i]; + float m = momentum[i]; + + if (cfg.weight_decay_type == L2) { + g += weight_decay * p; + } + + s_avg += (one - alpha) * (g * g - s_avg); + + float avg; + + if (cfg.centered) { + // ga = a * ga + (1 - a) * g + g_avg += (one - alpha) * (g - g_avg); + avg = sqrtg(s_avg - g_avg * g_avg + eps); + } else { + avg = sqrtg(s_avg + eps); + }; + + g /= avg; + + if (cfg.has_momentum) { + m = m * momentum_ + g; + g = m * lr; + } else { + g *= lr; + } + + if (cfg.weight_decay_type == Decoupled) { + g += weight_decay * lr * p; + } + + square_avg[i] = s_avg; + grad_avg[i] = g_avg; + momentum[i] = m; + param[i] -= g; + } +} \ No newline at end of file diff --git a/src/tensor_ops/roll/cuda_kernel.rs b/src/tensor_ops/roll/cuda_kernel.rs index 0aa2ea7e4..aaf3bcf6c 100644 --- a/src/tensor_ops/roll/cuda_kernel.rs +++ b/src/tensor_ops/roll/cuda_kernel.rs @@ -1,7 +1,4 @@ -use crate::{ - shapes::{Dtype, Shape}, - tensor::*, -}; +use crate::{dtypes::*, shapes::Shape, tensor::*}; use cudarc::driver::{DeviceRepr, LaunchAsync}; @@ -13,7 +10,11 @@ trait HasCudaKernel { const FNS: &'static [&'static str]; } #[cfg(feature = "f16")] -impl HasCudaKernel for Cuda { +impl HasCudaKernel for Cuda { + const FNS: &'static [&'static str] = &["roll_fwd_f16", "roll_bwd_f16"]; +} +#[cfg(feature = "f16")] +impl HasCudaKernel> for Cuda { const FNS: &'static [&'static str] = &["roll_fwd_f16", "roll_bwd_f16"]; } impl HasCudaKernel for Cuda { diff --git a/src/tensor_ops/select_and_gather/cuda_kernel.rs b/src/tensor_ops/select_and_gather/cuda_kernel.rs index 4157e48f8..1d4ce5b5b 100644 --- a/src/tensor_ops/select_and_gather/cuda_kernel.rs +++ b/src/tensor_ops/select_and_gather/cuda_kernel.rs @@ -1,4 +1,6 @@ +#[allow(unused_imports)] use crate::{ + dtypes::*, shapes::{RemoveDimTo, ReplaceDimTo, Shape}, tensor::{launch_cfg, Cuda, Storage, Tensor}, }; @@ -189,7 +191,18 @@ macro_rules! 
diff --git a/src/tensor_ops/select_and_gather/cuda_kernel.rs b/src/tensor_ops/select_and_gather/cuda_kernel.rs
index 4157e48f8..1d4ce5b5b 100644
--- a/src/tensor_ops/select_and_gather/cuda_kernel.rs
+++ b/src/tensor_ops/select_and_gather/cuda_kernel.rs
@@ -1,4 +1,6 @@
+#[allow(unused_imports)]
 use crate::{
+    dtypes::*,
     shapes::{RemoveDimTo, ReplaceDimTo, Shape},
     tensor::{launch_cfg, Cuda, Storage, Tensor},
 };
@@ -189,7 +191,18 @@ macro_rules! impl_cuda_kernels {
 
 #[cfg(feature = "f16")]
 impl_cuda_kernels!(
-    half::f16,
+    f16,
+    "gather_f16",
+    "gather_fwd_f16",
+    "gather_bwd_f16",
+    "select_f16",
+    "select_fwd_f16",
+    "select_bwd_f16"
+);
+
+#[cfg(feature = "f16")]
+impl_cuda_kernels!(
+    AMP<f16>,
     "gather_f16",
     "gather_fwd_f16",
     "gather_bwd_f16",
diff --git a/src/tensor_ops/sgd/cpu_kernel.rs b/src/tensor_ops/sgd/cpu_kernel.rs
index 27f35c7b6..69c5653c4 100644
--- a/src/tensor_ops/sgd/cpu_kernel.rs
+++ b/src/tensor_ops/sgd/cpu_kernel.rs
@@ -1,8 +1,61 @@
-use crate::{shapes::Dtype, tensor::cpu::*};
+use crate::{
+    dtypes::{Dtype, NotMixedPrecision},
+    tensor::cpu::*,
+};
 
 use super::{Momentum, SgdConfig, SgdKernel, WeightDecay};
 
-impl SgdKernel for Cpu {
+#[cfg(feature = "f16")]
+impl SgdKernel<crate::dtypes::AMP<crate::dtypes::f16>> for Cpu {
+    fn sgd_kernel(
+        &self,
+        cfg: &SgdConfig,
+        param: &mut Self::Vec,
+        velocity: &mut Self::Vec,
+        grad: &Self::Vec,
+    ) -> Result<(), Self::Err> {
+        let lr = cfg.lr as f32;
+
+        for ((p, g), v) in param
+            .iter_mut()
+            .zip(grad.iter().cloned())
+            .zip(velocity.iter_mut())
+        {
+            let p_f32 = p.0.to_f32();
+            let mut g_f32 = g.0.to_f32();
+            let mut v_f32 = v.0.to_f32();
+
+            if let Some(WeightDecay::L2(wd)) = cfg.weight_decay {
+                g_f32 += (wd as f32) * p_f32;
+            }
+
+            match cfg.momentum {
+                Some(Momentum::Classic(u)) => {
+                    let u = u as f32;
+                    v_f32 = g_f32 + u * v_f32;
+                    g_f32 = v_f32 * lr;
+                }
+                Some(Momentum::Nesterov(u)) => {
+                    let u = u as f32;
+                    v_f32 = g_f32 + u * v_f32;
+                    g_f32 = (g_f32 + u * v_f32) * lr;
+                }
+                None => g_f32 *= lr,
+            }
+
+            if let Some(WeightDecay::Decoupled(wd)) = cfg.weight_decay {
+                g_f32 += (wd * cfg.lr) as f32 * p_f32;
+            }
+
+            p.0 = crate::dtypes::f16::from_f32(p_f32 - g_f32);
+            v.0 = crate::dtypes::f16::from_f32(v_f32);
+        }
+
+        Ok(())
+    }
+}
+
+impl<E: Dtype + NotMixedPrecision> SgdKernel<E> for Cpu {
     fn sgd_kernel(
         &self,
         cfg: &SgdConfig,
diff --git a/src/tensor_ops/sgd/cuda_kernel.rs b/src/tensor_ops/sgd/cuda_kernel.rs
index 0b839a8a9..6d29812fc 100644
--- a/src/tensor_ops/sgd/cuda_kernel.rs
+++ b/src/tensor_ops/sgd/cuda_kernel.rs
@@ -1,7 +1,7 @@
 use super::SgdConfig;
 use crate::{
-    shapes::*,
+    dtypes::*,
     tensor::{launch_cfg, Cuda},
     tensor_ops::optim::*,
 };
@@ -40,11 +40,17 @@ trait HasCudaKernel<E> {
 }
 
 #[cfg(feature = "f16")]
-impl HasCudaKernel<half::f16> for Cuda {
+impl HasCudaKernel<f16> for Cuda {
     const MOD: &'static str = "sgd_f16";
     const FWD: &'static str = "sgd_update_f16";
 }
 
+#[cfg(feature = "f16")]
+impl HasCudaKernel<AMP<f16>> for Cuda {
+    const MOD: &'static str = "sgd_amp_f16";
+    const FWD: &'static str = "sgd_update_amp_f16";
+}
+
 impl HasCudaKernel<f32> for Cuda {
     const MOD: &'static str = "sgd_f32";
     const FWD: &'static str = "sgd_update_f32";
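
The sgd_update_amp_f16 kernel that follows mirrors the CPU loop above exactly: read f16 state, do the whole update in f32, write f16 back. For reference, the Nesterov branch as a scalar f32 step (names follow the kernel; inputs are illustrative):

    fn sgd_nesterov_step(p: f32, g: f32, v: f32, lr: f32, momentum: f32) -> (f32, f32) {
        // v' = g + u*v, then the applied step is (g + u*v') * lr,
        // matching the kernel's `momentum_type == Nesterov` branch.
        let v_new = g + momentum * v;
        let step = (g + momentum * v_new) * lr;
        (p - step, v_new)
    }

    fn main() {
        let (p, v) = sgd_nesterov_step(1.0, 0.5, 0.0, 0.1, 0.9);
        // v' = 0.5; step = (0.5 + 0.9 * 0.5) * 0.1 = 0.095.
        assert!((p - 0.905).abs() < 1e-6);
        assert!((v - 0.5).abs() < 1e-6);
    }
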
diff --git a/src/tensor_ops/sgd/sgd.cu b/src/tensor_ops/sgd/sgd.cu
index 2c666cb1d..7c3d4c909 100644
--- a/src/tensor_ops/sgd/sgd.cu
+++ b/src/tensor_ops/sgd/sgd.cu
@@ -74,3 +74,43 @@ extern "C" __global__ void FN( \
 SGD(__half, sgd_update_f16);
 SGD(float, sgd_update_f32);
 SGD(double, sgd_update_f64);
+
+
+extern "C" __global__ void sgd_update_amp_f16(
+    const SgdConfig cfg,
+    const size_t numel,
+    __half* param,
+    __half* velocity,
+    const __half* grad
+) {
+    float weight_decay = cfg.weight_decay;
+    float lr = cfg.lr;
+    float momentum = cfg.momentum;
+
+    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
+        float p = param[i];
+        float g = grad[i];
+        float v = velocity[i];
+
+        if (cfg.weight_decay_type == L2) {
+            g += weight_decay * p;
+        }
+
+        if (cfg.momentum_type == Classic) {
+            v = g + momentum * v;
+            g = v * lr;
+        } else if (cfg.momentum_type == Nesterov) {
+            v = g + momentum * v;
+            g = (g + momentum * v) * lr;
+        } else {
+            g *= lr;
+        }
+
+        if (cfg.weight_decay_type == Decoupled) {
+            g += weight_decay * lr * p;
+        }
+
+        velocity[i] = v;
+        param[i] -= g;
+    }
+}
\ No newline at end of file
diff --git a/src/tensor_ops/sigmoid/cuda_kernel.rs b/src/tensor_ops/sigmoid/cuda_kernel.rs
index 6d3b55110..2b38cb552 100644
--- a/src/tensor_ops/sigmoid/cuda_kernel.rs
+++ b/src/tensor_ops/sigmoid/cuda_kernel.rs
@@ -1,4 +1,6 @@
 use super::SigmoidKernelOp as Sigmoid;
+#[allow(unused_imports)]
+use crate::dtypes::*;
 use crate::tensor_ops::cuda_kernels::cuda_unary;
 
 unsafe impl cudarc::driver::DeviceRepr for Sigmoid {}
@@ -6,6 +8,8 @@ unsafe impl cudarc::driver::DeviceRepr for Sigmoid {}
 const PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/sigmoid.ptx"));
 
 #[cfg(feature = "f16")]
-cuda_unary!(df(f(x)) Sigmoid, half::f16, PTX, "sigmoid_fwd_f16", "sigmoid_bwd_f16");
+cuda_unary!(df(f(x)) Sigmoid, f16, PTX, "sigmoid_fwd_f16", "sigmoid_bwd_f16");
+#[cfg(feature = "f16")]
+cuda_unary!(df(f(x)) Sigmoid, AMP<f16>, PTX, "sigmoid_fwd_f16", "sigmoid_bwd_f16");
 cuda_unary!(df(f(x)) Sigmoid, f32, PTX, "sigmoid_fwd_f32", "sigmoid_bwd_f32");
 cuda_unary!(df(f(x)) Sigmoid, f64, PTX, "sigmoid_fwd_f64", "sigmoid_bwd_f64");
diff --git a/src/tensor_ops/sin/cuda_kernel.rs b/src/tensor_ops/sin/cuda_kernel.rs
index 9fd33010c..5b0e428b1 100644
--- a/src/tensor_ops/sin/cuda_kernel.rs
+++ b/src/tensor_ops/sin/cuda_kernel.rs
@@ -1,13 +1,17 @@
+#[allow(unused_imports)]
+use crate::dtypes::*;
 use crate::tensor_ops::cuda_kernels::cuda_unary;
 
 unsafe impl cudarc::driver::DeviceRepr for super::SinKernelOp {}
 
 const PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/sin.ptx"));
 
+#[cfg(feature = "f16")]
+cuda_unary!(super::SinKernelOp, f16, PTX, "sin_fwd_f16", "sin_bwd_f16");
 #[cfg(feature = "f16")]
 cuda_unary!(
     super::SinKernelOp,
-    half::f16,
+    AMP<f16>,
     PTX,
     "sin_fwd_f16",
     "sin_bwd_f16"
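
A small note on the #[allow(unused_imports)] attributes sprinkled through these files: when the crate is built without the f16 feature, every mention of f16/AMP is compiled out, and the glob import of crate::dtypes would otherwise trigger an unused-import warning. The same pattern in miniature (the feature name and import are illustrative):

    #[allow(unused_imports)]
    use std::f64::consts::PI; // stands in for `use crate::dtypes::*;`

    #[cfg(feature = "f16")]
    fn uses_the_import() -> f64 {
        PI * 2.0
    }

    fn main() {
        // Without the feature, PI is unused; the allow keeps -D warnings happy.
        #[cfg(feature = "f16")]
        println!("{}", uses_the_import());
    }
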
diff --git a/src/tensor_ops/slice/cuda_kernel.rs b/src/tensor_ops/slice/cuda_kernel.rs
index 3ffd5705a..7e2f85d1a 100644
--- a/src/tensor_ops/slice/cuda_kernel.rs
+++ b/src/tensor_ops/slice/cuda_kernel.rs
@@ -1,4 +1,5 @@
 use crate::{
+    dtypes::*,
     prelude::cpu::NdIndex,
     shapes::*,
     tensor::{launch_cfg, Cuda, Tensor},
@@ -26,7 +27,13 @@ macro_rules! has_kernels {
 has_kernels!(u8, u16, u32, u64, usize, i8, i16, i32, i64, isize, bool);
 
 #[cfg(feature = "f16")]
-impl HasCudaKernel<half::f16> for Cuda {
+impl HasCudaKernel<f16> for Cuda {
+    const MOD: &'static str = "slice_f16";
+    const FNS: &'static [&'static str] = &["slice_fwd_f16", "slice_bwd_f16"];
+}
+
+#[cfg(feature = "f16")]
+impl HasCudaKernel<AMP<f16>> for Cuda {
     const MOD: &'static str = "slice_f16";
     const FNS: &'static [&'static str] = &["slice_fwd_f16", "slice_bwd_f16"];
 }
diff --git a/src/tensor_ops/sqrt/cuda_kernel.rs b/src/tensor_ops/sqrt/cuda_kernel.rs
index 6bd0ea39c..6adf4445b 100644
--- a/src/tensor_ops/sqrt/cuda_kernel.rs
+++ b/src/tensor_ops/sqrt/cuda_kernel.rs
@@ -1,4 +1,6 @@
 use super::SqrtKernelOp;
+#[allow(unused_imports)]
+use crate::dtypes::*;
 use crate::tensor_ops::cuda_kernels::cuda_unary;
 
 unsafe impl cudarc::driver::DeviceRepr for SqrtKernelOp {}
@@ -6,6 +8,8 @@ unsafe impl cudarc::driver::DeviceRepr for SqrtKernelOp {}
 const PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/sqrt.ptx"));
 
 #[cfg(feature = "f16")]
-cuda_unary!(df(f(x)) SqrtKernelOp, half::f16, PTX, "sqrt_fwd_f16", "sqrt_bwd_f16");
+cuda_unary!(df(f(x)) SqrtKernelOp, f16, PTX, "sqrt_fwd_f16", "sqrt_bwd_f16");
+#[cfg(feature = "f16")]
+cuda_unary!(df(f(x)) SqrtKernelOp, AMP<f16>, PTX, "sqrt_fwd_f16", "sqrt_bwd_f16");
 cuda_unary!(df(f(x)) SqrtKernelOp, f32, PTX, "sqrt_fwd_f32", "sqrt_bwd_f32");
 cuda_unary!(df(f(x)) SqrtKernelOp, f64, PTX, "sqrt_fwd_f64", "sqrt_bwd_f64");
diff --git a/src/tensor_ops/square/cuda_kernel.rs b/src/tensor_ops/square/cuda_kernel.rs
index 4f8a887a3..9309eee7a 100644
--- a/src/tensor_ops/square/cuda_kernel.rs
+++ b/src/tensor_ops/square/cuda_kernel.rs
@@ -1,14 +1,18 @@
 use super::SquareKernelOp;
+#[allow(unused_imports)]
+use crate::dtypes::*;
 use crate::tensor_ops::cuda_kernels::cuda_unary;
 
 unsafe impl cudarc::driver::DeviceRepr for SquareKernelOp {}
 
 const PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/square.ptx"));
 
+#[cfg(feature = "f16")]
+cuda_unary!(SquareKernelOp, f16, PTX, "square_fwd_f16", "square_bwd_f16");
 #[cfg(feature = "f16")]
 cuda_unary!(
     SquareKernelOp,
-    half::f16,
+    AMP<f16>,
     PTX,
     "square_fwd_f16",
     "square_bwd_f16"
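
sub's scalar op (next) uses the const_df() flavor: for y = x - s the derivative with respect to x is the constant 1, so the backward kernel needs neither the input nor the output tensor. A scalar illustration (function names are illustrative, not the macro's expansion):

    // y = x - s  =>  dy/dx = 1, independent of x, s, and y.
    fn ssub_fwd(x: f32, s: f32) -> f32 {
        x - s
    }

    fn ssub_bwd(grad_y: f32) -> f32 {
        // const_df(): the upstream gradient passes through unchanged.
        grad_y
    }

    fn main() {
        assert_eq!(ssub_fwd(3.0, 1.25), 1.75);
        assert_eq!(ssub_bwd(0.5), 0.5);
    }
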
"ssub_fwd_f64", "ssub_bwd_f64"); #[cfg(feature = "f16")] cuda_binary!( const_df() Binary, - half::f16, + f16, + BINARY_PTX, + "bsub_fwd_f16", + "bsub_bwd_lhs_f16", + "bsub_bwd_rhs_f16" +); +#[cfg(feature = "f16")] +cuda_binary!( + const_df() Binary, + AMP, BINARY_PTX, "bsub_fwd_f16", "bsub_bwd_lhs_f16", diff --git a/src/tensor_ops/sub/mod.rs b/src/tensor_ops/sub/mod.rs index 1f8ab36aa..fb245fb1a 100644 --- a/src/tensor_ops/sub/mod.rs +++ b/src/tensor_ops/sub/mod.rs @@ -79,6 +79,22 @@ impl, half::f16>, T: Tape< } } +#[cfg(feature = "f16")] +impl< + S: Shape, + D: UnaryKernel< + ScalarSubKernelOp>, + crate::dtypes::AMP, + >, + T: Tape, D>, + > TrySub for Tensor, D, T> +{ + fn try_sub(self, rhs: f32) -> Result { + let scalar = crate::dtypes::AMP(half::f16::from_f32(rhs)); + try_unary_op(ScalarSubKernelOp { scalar }, self) + } +} + impl, LTape: Tape, Rhs> std::ops::Sub for Tensor where diff --git a/src/tensor_ops/sum_to/cpu_kernel.rs b/src/tensor_ops/sum_to/cpu_kernel.rs index 786c3d7e6..5daa7cced 100644 --- a/src/tensor_ops/sum_to/cpu_kernel.rs +++ b/src/tensor_ops/sum_to/cpu_kernel.rs @@ -1,10 +1,76 @@ use crate::{ - shapes::{Axes, Dtype, HasAxes, ReduceShapeTo, Shape}, + dtypes::{Dtype, NotMixedPrecision}, + shapes::{Axes, HasAxes, ReduceShapeTo, Shape}, tensor::{Cpu, Tensor, Tensorlike, ZerosTensor}, tensor_ops::utilities::reduction_utils::index_for_reductions, }; -impl super::SumKernel for Cpu { +#[cfg(feature = "f16")] +impl super::SumKernel> for Cpu { + fn forward( + &self, + dst: Dst, + inp: &Tensor, Self>, + ) -> Result, Self>, Self::Err> + where + Src: ReduceShapeTo, + { + let mut out = self.try_zeros_like(&dst)?; + if Dst::NUM_DIMS == 0 { + debug_assert_eq!(out.data.len(), 1); + + let mut tmp = 0.0f32; + for v in inp.buf_iter() { + tmp += v.0.to_f32(); + } + let scale = (inp.shape.num_elements() / inp.data.len()) as f32; + std::sync::Arc::get_mut(&mut out.data).unwrap()[0] = + crate::dtypes::AMP(crate::dtypes::f16::from_f32(tmp * scale)); + } else { + let num_elems_reduced = >::size(&inp.shape); + let inp_buf = inp.data.as_ref(); + let mut idx = index_for_reductions::(inp.shape, inp.strides); + for o in out.buf_iter_mut() { + let mut tmp = 0.0f32; + for _ in 0..num_elems_reduced { + tmp += inp_buf[idx.next().unwrap()].0.to_f32(); + } + *o = crate::dtypes::AMP(crate::dtypes::f16::from_f32(tmp)); + } + } + Ok(out) + } + fn backward( + &self, + _dst: Dst, + inp: &impl Tensorlike, Self>, + grad_inp: &mut Self::Vec, + grad_out: &Self::Vec, + ) -> Result<(), Self::Err> + where + Src: ReduceShapeTo, + { + if Dst::NUM_DIMS == 0 { + debug_assert_eq!(grad_out.len(), 1); + let v = grad_out[0].0.to_f32(); + let scale = (inp.shape().num_elements() / inp.len()) as f32; + for i in grad_inp.iter_mut() { + i.0 += crate::dtypes::f16::from_f32(v * scale); + } + } else { + let num_elems_reduced = >::size(inp.shape()); + let mut idx = index_for_reductions::(*inp.shape(), inp.strides()); + for &o in grad_out.iter() { + for _ in 0..num_elems_reduced { + grad_inp[idx.next().unwrap()] += o; + } + } + } + Ok(()) + } +} + +impl super::SumKernel for Cpu { fn forward( &self, dst: Dst, diff --git a/src/tensor_ops/sum_to/cuda_kernel.rs b/src/tensor_ops/sum_to/cuda_kernel.rs index 4be842048..05d1dec32 100644 --- a/src/tensor_ops/sum_to/cuda_kernel.rs +++ b/src/tensor_ops/sum_to/cuda_kernel.rs @@ -1,4 +1,5 @@ use crate::{ + dtypes::*, shapes::*, tensor::{launch_cfg, Cuda, Tensor, Tensorlike}, tensor_ops::reduction_utils::*, @@ -16,11 +17,17 @@ trait HasCudaKernel { } #[cfg(feature = "f16")] -impl HasCudaKernel for 
diff --git a/src/tensor_ops/sum_to/cuda_kernel.rs b/src/tensor_ops/sum_to/cuda_kernel.rs
index 4be842048..05d1dec32 100644
--- a/src/tensor_ops/sum_to/cuda_kernel.rs
+++ b/src/tensor_ops/sum_to/cuda_kernel.rs
@@ -1,4 +1,5 @@
 use crate::{
+    dtypes::*,
     shapes::*,
     tensor::{launch_cfg, Cuda, Tensor, Tensorlike},
     tensor_ops::reduction_utils::*,
@@ -16,11 +17,17 @@ trait HasCudaKernel<E> {
 }
 
 #[cfg(feature = "f16")]
-impl HasCudaKernel<half::f16> for Cuda {
+impl HasCudaKernel<f16> for Cuda {
     const MOD: &'static str = "sum_f16";
     const FNS: &'static [&'static str] = &["sum_to_fwd_f16", "sum_to_bwd_f16"];
 }
 
+#[cfg(feature = "f16")]
+impl HasCudaKernel<AMP<f16>> for Cuda {
+    const MOD: &'static str = "sum_amp_f16";
+    const FNS: &'static [&'static str] = &["sum_to_fwd_amp_f16", "sum_to_bwd_f16"];
+}
+
 impl HasCudaKernel<f32> for Cuda {
     const MOD: &'static str = "sum_f32";
     const FNS: &'static [&'static str] = &["sum_to_fwd_f32", "sum_to_bwd_f32"];
diff --git a/src/tensor_ops/sum_to/sum_to.cu b/src/tensor_ops/sum_to/sum_to.cu
index d0c9c7f43..a07dc983e 100644
--- a/src/tensor_ops/sum_to/sum_to.cu
+++ b/src/tensor_ops/sum_to/sum_to.cu
@@ -80,3 +80,74 @@ extern "C" __global__ void BWD( \
 SUM(__half, sum_to_fwd_f16, sum_to_bwd_f16);
 SUM(float, sum_to_fwd_f32, sum_to_bwd_f32);
 SUM(double, sum_to_fwd_f64, sum_to_bwd_f64);
+
+__device__ void chunk_sum_amp_f16(
+    const size_t chunk_len,
+    const __half data,
+    __half* out
+) {
+    __shared__ float buf[1024];
+
+    // assumes that threads where i >= numel have already exited
+    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
+    unsigned int block_i = threadIdx.x;
+
+    // Fall back to atomicAdd if chunk_len is small to reduce overhead
+    if (chunk_len <= 2) {
+        atomicAdd(out + i / chunk_len, data);
+        return;
+    }
+    buf[block_i] = data;
+
+    unsigned int chunk_i = i % chunk_len;
+    unsigned int chunk_start = max((int)(block_i - chunk_i), 0);
+    unsigned int chunk_end = min((unsigned int)(block_i + chunk_len - chunk_i), blockDim.x);
+
+    chunk_i = block_i - chunk_start;
+
+    size_t max_chunk_len = min(chunk_end - chunk_start, blockDim.x);
+    size_t incr = next_power_of_two(max_chunk_len) >> 1;
+
+    __syncthreads();
+
+    // Uses sequential addressing as discussed in
+    // https://developer.download.nvidia.com/assets/cuda/files/reduction.pdf
+    for (; incr > 0; incr >>= 1) {
+        unsigned int block_i_2 = block_i + incr;
+
+        if (block_i_2 < chunk_end && chunk_i < incr) {
+            // This is sound because __syncthreads and the conditions above
+            // ensure that no data races occur
+            buf[block_i] += buf[block_i_2];
+        }
+
+        __syncthreads();
+    }
+
+    if (block_i == chunk_start) {
+        __half y = buf[block_i];
+        atomicAdd(out + i / chunk_len, y);
+    }
+}
+
+extern "C" __global__ void sum_to_fwd_amp_f16(
+    const size_t numel,
+    const size_t num_dims,
+    const __half elems_per_thread,
+    const size_t chunk_len,
+    const size_t *info,
+    const __half *inp,
+    __half *out
+) {
+    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
+
+    if (i >= numel) {
+        return;
+    }
+
+    const size_t *dims = info;
+    const size_t *strides = info + num_dims;
+
+    unsigned int inp_i = get_strided_index(i, num_dims, dims, strides);
+    chunk_sum_amp_f16(chunk_len, inp[inp_i] * elems_per_thread, out);
+}
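
chunk_sum_amp_f16 above reduces each chunk with the classic sequential-addressing tree: the stride starts at half the power-of-two-rounded chunk length and halves every round, with a barrier between rounds. The same loop shape in plain Rust over one in-memory chunk (illustrative; single-threaded, so no synchronization is needed):

    // Assumes a non-empty chunk.
    fn tree_reduce(buf: &mut [f32]) -> f32 {
        let mut incr = buf.len().next_power_of_two() >> 1;
        while incr > 0 {
            for i in 0..incr {
                if i + incr < buf.len() {
                    // Same pairing as `buf[block_i] += buf[block_i_2]` above.
                    buf[i] += buf[i + incr];
                }
            }
            incr >>= 1;
        }
        buf[0]
    }

    fn main() {
        let mut chunk = vec![1.0f32; 6];
        assert_eq!(tree_reduce(&mut chunk), 6.0);
    }
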
PTX, "tanh_fwd_f16", "tanh_bwd_f16"); cuda_unary!(df(f(x)) TanhKernelOp, f32, PTX, "tanh_fwd_f32", "tanh_bwd_f32"); cuda_unary!(df(f(x)) TanhKernelOp, f64, PTX, "tanh_fwd_f64", "tanh_bwd_f64"); diff --git a/src/tensor_ops/upscale2d/cuda_kernel.rs b/src/tensor_ops/upscale2d/cuda_kernel.rs index 030abbcc1..038a86ae8 100644 --- a/src/tensor_ops/upscale2d/cuda_kernel.rs +++ b/src/tensor_ops/upscale2d/cuda_kernel.rs @@ -1,4 +1,5 @@ use crate::{ + dtypes::*, shapes::*, tensor::{launch_cfg, Cuda, Tensor}, }; @@ -26,12 +27,22 @@ trait HasCudaKernel { const BWD: &'static str; } #[cfg(feature = "f16")] -impl HasCudaKernel for Cuda { +impl HasCudaKernel for Cuda { const FWD: &'static str = "nearest_upscale2d_fwd_f16"; const BWD: &'static str = "nearest_upscale2d_bwd_f16"; } #[cfg(feature = "f16")] -impl HasCudaKernel for Cuda { +impl HasCudaKernel for Cuda { + const FWD: &'static str = "bilinear_upscale2d_fwd_f16"; + const BWD: &'static str = "bilinear_upscale2d_bwd_f16"; +} +#[cfg(feature = "f16")] +impl HasCudaKernel, NearestNeighbor> for Cuda { + const FWD: &'static str = "nearest_upscale2d_fwd_f16"; + const BWD: &'static str = "nearest_upscale2d_bwd_f16"; +} +#[cfg(feature = "f16")] +impl HasCudaKernel, Bilinear> for Cuda { const FWD: &'static str = "bilinear_upscale2d_fwd_f16"; const BWD: &'static str = "bilinear_upscale2d_bwd_f16"; } diff --git a/src/tensor_ops/utilities/device.rs b/src/tensor_ops/utilities/device.rs index d226d2e29..8a195d6a1 100644 --- a/src/tensor_ops/utilities/device.rs +++ b/src/tensor_ops/utilities/device.rs @@ -1,6 +1,6 @@ use super::super::ops::{BinaryKernel, UnaryKernel}; use crate::{ - shapes::Dtype, + dtypes::*, tensor::{CopySlice, RandomU64, Storage}, }; @@ -113,15 +113,17 @@ pub trait Device: } #[cfg(feature = "f16")] -impl Device for crate::tensor::Cpu {} +impl Device for crate::tensor::Cpu {} +#[cfg(feature = "f16")] +impl Device> for crate::tensor::Cpu {} impl Device for crate::tensor::Cpu {} impl Device for crate::tensor::Cpu {} #[cfg(all(feature = "cuda", feature = "f16"))] -impl Device for crate::tensor::Cuda {} - +impl Device for crate::tensor::Cuda {} +#[cfg(all(feature = "cuda", feature = "f16"))] +impl Device> for crate::tensor::Cuda {} #[cfg(feature = "cuda")] impl Device for crate::tensor::Cuda {} - #[cfg(feature = "cuda")] impl Device for crate::tensor::Cuda {}