From 5c532ec5dc51cd17cd4bb9ae940ecf2c9baf89f6 Mon Sep 17 00:00:00 2001
From: rainiwu
Date: Fri, 26 Jan 2024 00:29:35 -0800
Subject: [PATCH 1/6] remove deprecated ftz intrinsics

---
 dfdx-core/src/lib.rs      | 38 --------------------------------------
 dfdx/examples/12-mnist.rs |  3 ---
 2 files changed, 41 deletions(-)

diff --git a/dfdx-core/src/lib.rs b/dfdx-core/src/lib.rs
index 31e61643..c126db2c 100644
--- a/dfdx-core/src/lib.rs
+++ b/dfdx-core/src/lib.rs
@@ -128,44 +128,6 @@ pub mod prelude {
     pub use crate::tensor_ops::*;
 }
 
-/// Sets a CPU `sse` flag to flush denormal floating point numbers to zero. The opposite of this is [keep_denormals()].
-///
-/// Some resources:
-/// 1. [Effects of Flush-To-Zero mode](https://developer.arm.com/documentation/dui0473/c/neon-and-vfp-programming/the-effects-of-using-flush-to-zero-mode?lang=en)
-/// 2. [When to use Flush-To-Zero mode](https://developer.arm.com/documentation/dui0473/c/neon-and-vfp-programming/when-to-use-flush-to-zero-mode?lang=en)
-pub fn flush_denormals_to_zero() {
-    #[cfg(all(target_arch = "x86", target_feature = "sse"))]
-    {
-        use std::arch::x86::{_MM_FLUSH_ZERO_ON, _MM_SET_FLUSH_ZERO_MODE};
-        unsafe { _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON) }
-    }
-
-    #[cfg(all(target_arch = "x86_64", target_feature = "sse"))]
-    {
-        use std::arch::x86_64::{_MM_FLUSH_ZERO_ON, _MM_SET_FLUSH_ZERO_MODE};
-        unsafe { _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON) }
-    }
-}
-
-/// Sets a CPU flag to keep denormal floating point numbers. The opposite of this is [flush_denormals_to_zero()].
-///
-/// Some resources:
-/// 1. [Effects of Flush-To-Zero mode](https://developer.arm.com/documentation/dui0473/c/neon-and-vfp-programming/the-effects-of-using-flush-to-zero-mode?lang=en)
-/// 2. [When to use Flush-To-Zero mode](https://developer.arm.com/documentation/dui0473/c/neon-and-vfp-programming/when-to-use-flush-to-zero-mode?lang=en)
-pub fn keep_denormals() {
-    #[cfg(all(target_arch = "x86", target_feature = "sse"))]
-    {
-        use std::arch::x86::{_MM_FLUSH_ZERO_OFF, _MM_SET_FLUSH_ZERO_MODE};
-        unsafe { _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_OFF) }
-    }
-
-    #[cfg(all(target_arch = "x86_64", target_feature = "sse"))]
-    {
-        use std::arch::x86_64::{_MM_FLUSH_ZERO_OFF, _MM_SET_FLUSH_ZERO_MODE};
-        unsafe { _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_OFF) }
-    }
-}
-
 #[cfg(test)]
 pub(crate) mod tests {
     pub use num_traits::{Float, NumCast, Zero};

diff --git a/dfdx/examples/12-mnist.rs b/dfdx/examples/12-mnist.rs
index 705d14c8..00d43452 100644
--- a/dfdx/examples/12-mnist.rs
+++ b/dfdx/examples/12-mnist.rs
@@ -62,9 +62,6 @@ type Mlp = (
 const BATCH_SIZE: usize = 32;
 
 fn main() {
-    // ftz substantially improves performance
-    dfdx::flush_denormals_to_zero();
-
     let mnist_path = std::env::args()
         .nth(1)
         .unwrap_or_else(|| "./datasets/MNIST/raw".to_string());

From fb91f13314fb24a67c2d8e14ad40345d2d334805 Mon Sep 17 00:00:00 2001
From: rainiwu
Date: Fri, 26 Jan 2024 00:55:48 -0800
Subject: [PATCH 2/6] suppress spurious cargo clippy warning

---
 dfdx-core/src/data/collate.rs | 1 +
 1 file changed, 1 insertion(+)

diff --git a/dfdx-core/src/data/collate.rs b/dfdx-core/src/data/collate.rs
index d38a2a67..5f52d636 100644
--- a/dfdx-core/src/data/collate.rs
+++ b/dfdx-core/src/data/collate.rs
@@ -55,6 +55,7 @@ impl<A, B> Collate for Vec<(A, B)> {
 impl<'a, A, B> Collate for Vec<&'a (A, B)> {
     type Collated = (Vec<&'a A>, Vec<&'a B>);
     fn collated(self) -> Self::Collated {
+        #[allow(clippy::map_identity)]
         self.into_iter().map(|(a, b)| (a, b)).unzip()
     }
 }

From 4e3f7c7a24728668f72cf3617a66f4476280f6fb Mon Sep 17 00:00:00 2001
From: Thiago Machado
Date: Tue, 6 Feb 2024 18:27:46 -0500
Subject: [PATCH 3/6] avoid conv1d bound for cudnn

---
 dfdx-core/src/tensor_ops/utilities/device.rs | 50 +++++++++++++++-----
 1 file changed, 39 insertions(+), 11 deletions(-)

diff --git a/dfdx-core/src/tensor_ops/utilities/device.rs b/dfdx-core/src/tensor_ops/utilities/device.rs
index 8cbc2137..91f87cf6 100644
--- a/dfdx-core/src/tensor_ops/utilities/device.rs
+++ b/dfdx-core/src/tensor_ops/utilities/device.rs
@@ -114,25 +114,49 @@ pub trait Device<E: Dtype>:
     + crate::tensor_ops::axpy::AxpyKernel<E>
 
     // conv1d
-    + super::super::conv1d::Conv1DKernel<E>
+    + NonCudnnCuda<E>
+{
+}
+
+#[cfg(feature = "cudnn")]
+pub trait NonCudnnCuda<E: Dtype> {}
+
+#[cfg(not(feature = "cudnn"))]
+pub trait NonCudnnCuda<E: Dtype>:
+    // conv1d
+    super::super::conv1d::Conv1DKernel<E>
 {
 }
 
 #[cfg(feature = "f16")]
-impl Device<f16> for crate::tensor::Cpu {}
-#[cfg(feature = "f16")]
-impl Device<AMP<f16>> for crate::tensor::Cpu {}
+mod f16_ {
+    use super::*;
+    impl Device<f16> for crate::tensor::Cpu {}
+    impl NonCudnnCuda<f16> for crate::tensor::Cpu {}
+    impl Device<AMP<f16>> for crate::tensor::Cpu {}
+    impl NonCudnnCuda<AMP<f16>> for crate::tensor::Cpu {}
+}
 impl Device<f32> for crate::tensor::Cpu {}
+impl NonCudnnCuda<f32> for crate::tensor::Cpu {}
 impl Device<f64> for crate::tensor::Cpu {}
+impl NonCudnnCuda<f64> for crate::tensor::Cpu {}
 
 #[cfg(all(feature = "cuda", feature = "f16"))]
-impl Device<f16> for crate::tensor::Cuda {}
-#[cfg(all(feature = "cuda", feature = "f16"))]
-impl Device<AMP<f16>> for crate::tensor::Cuda {}
-#[cfg(feature = "cuda")]
-impl Device<f32> for crate::tensor::Cuda {}
+mod cuda_f16 {
+    use super::*;
+    impl Device<f16> for crate::tensor::Cuda {}
+    impl NonCudnnCuda<f16> for crate::tensor::Cuda {}
+    impl Device<AMP<f16>> for crate::tensor::Cuda {}
+    impl NonCudnnCuda<AMP<f16>> for crate::tensor::Cuda {}
+}
 #[cfg(feature = "cuda")]
-impl Device<f64> for crate::tensor::Cuda {}
+mod cuda {
+    use super::*;
+    impl Device<f32> for crate::tensor::Cuda {}
+    impl NonCudnnCuda<f32> for crate::tensor::Cuda {}
+    impl Device<f64> for crate::tensor::Cuda {}
+    impl NonCudnnCuda<f64> for crate::tensor::Cuda {}
+}
 
 // TODO: How can we implement this for f16 when WGSL doesn't support f16 yet?
 // #[cfg(all(feature = "webgpu", feature = "f16"))]
 // impl Device<f16> for crate::tensor::Webgpu {}
 // #[cfg(all(feature = "webgpu", feature = "f16"))]
 // impl Device<AMP<f16>> for crate::tensor::Webgpu {}
 #[cfg(feature = "webgpu")]
-impl Device<f32> for crate::tensor::Webgpu {}
+mod webgpu {
+    use super::*;
+    impl Device<f32> for crate::tensor::Webgpu {}
+    impl NonCudnnCuda<f32> for crate::tensor::Webgpu {}
+}
 
 // TODO: How can we implement this for f64 when WGSL doesn't support f64 yet?
 // #[cfg(feature = "webgpu")]

From a8bc54c5c8e02c68fe09e72fc94ba0a8b3273b9a Mon Sep 17 00:00:00 2001
From: Thiago Machado
Date: Fri, 9 Feb 2024 11:53:40 -0500
Subject: [PATCH 4/6] bump gemm

---
 dfdx-core/Cargo.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/dfdx-core/Cargo.toml b/dfdx-core/Cargo.toml
index 5309ef7c..0f6cd5c6 100644
--- a/dfdx-core/Cargo.toml
+++ b/dfdx-core/Cargo.toml
@@ -35,7 +35,7 @@ num-traits = { workspace = true }
 safetensors = { workspace = true, optional = true }
 memmap2 = { workspace = true, optional = true }
 half = { version = "2.3.1", optional = true, features = ["num-traits", "rand_distr"] }
-gemm = { version = "0.16.14", default-features = false, optional = true, features = ["rayon"] }
+gemm = { version = "0.17.1", default-features = false, optional = true, features = ["rayon"] }
 rayon = { version = "1.7.0", optional = true }
 libm = { workspace = true }
 wgpu = { version = "0.18.0", features = ["glsl", "spirv"], optional = true }

From 557687c0a9e29dfba2311fe67414863c6c5137bf Mon Sep 17 00:00:00 2001
From: Thiago Machado
Date: Fri, 9 Feb 2024 12:52:05 -0500
Subject: [PATCH 5/6] clippy fix

---
 dfdx-core/src/tensor/gradients.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/dfdx-core/src/tensor/gradients.rs b/dfdx-core/src/tensor/gradients.rs
index 86974ec6..d24e2e32 100644
--- a/dfdx-core/src/tensor/gradients.rs
+++ b/dfdx-core/src/tensor/gradients.rs
@@ -153,7 +153,7 @@ impl<E, D: Storage<E>> Gradients<E, D> {
     #[inline]
     pub(crate) fn many_and_ref<L: Shape, R: Shape>(
         &mut self,
-        ls: &Vec<impl Tensorlike<L, E, D>>,
+        ls: &[impl Tensorlike<L, E, D>],
         r: &impl Tensorlike<R, E, D>,
     ) -> (Vec<&mut D::Vec>, &D::Vec) {
         for i in 0..ls.len() {

From aa44eb5165dffc305081d16f9a66129ca4969e78 Mon Sep 17 00:00:00 2001
From: Thiago Machado
Date: Thu, 1 Feb 2024 08:04:45 -0500
Subject: [PATCH 6/6] Add split_tensor_along method

- Add `TrySplitShapeAlong` and `TrySplitTensorAlong` (see the usage sketch
  below).
- Minor linting and docs fixes.

TODO
- Check whether the tape should be returned. If not, it can be removed from
  the interface.
- Add a cuda kernel.
- Consider a different interface that could split into more than two
  tensors, possibly specified as a vec. That would be closer to the PyTorch
  interface (`torch.chunk`).
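Usage sketch, adapted from the new doc-tests (Cpu device and `usize` dims;
the `_tape` return value is the subject of the first TODO above, so this
interface may still change):

    use dfdx_core::prelude::*;

    let dev: Cpu = Default::default();

    // Shape-level: (7, 3) splits into (3, 3) and (4, 3) along axis 0.
    let (a_shape, b_shape) = (7, Const::<3>).split_shape_along(Axis::<0>, 3, 4);
    assert_eq!(a_shape, (3, Const::<3>));
    assert_eq!(b_shape, (4, Const::<3>));

    // Tensor-level: returns both halves plus the tape.
    let ab: Tensor<(usize, Const<3>), f32, _> = dev.zeros_like(&(7, Const));
    let (a, b, _tape) = ab.split_tensor_along(Axis::<0>, 3, 4);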
---
 .../concat_tensor_along/cpu_kernel.rs         |   8 +-
 .../src/tensor_ops/concat_tensor_along/mod.rs |   8 +-
 dfdx-core/src/tensor_ops/mod.rs               |   4 +
 .../src/tensor_ops/split_shape_along/mod.rs   | 158 ++++++++++
 .../split_tensor_along/cpu_kernel.rs          |  99 +++++++
 .../split_tensor_along/cuda_kernel.rs         |  31 ++
 .../src/tensor_ops/split_tensor_along/mod.rs  | 275 ++++++++++++++++++
 .../split_tensor_along/webgpu_kernel.rs       |  26 ++
 dfdx-core/src/tensor_ops/utilities/device.rs  |   3 +
 9 files changed, 604 insertions(+), 8 deletions(-)
 create mode 100644 dfdx-core/src/tensor_ops/split_shape_along/mod.rs
 create mode 100644 dfdx-core/src/tensor_ops/split_tensor_along/cpu_kernel.rs
 create mode 100644 dfdx-core/src/tensor_ops/split_tensor_along/cuda_kernel.rs
 create mode 100644 dfdx-core/src/tensor_ops/split_tensor_along/mod.rs
 create mode 100644 dfdx-core/src/tensor_ops/split_tensor_along/webgpu_kernel.rs

diff --git a/dfdx-core/src/tensor_ops/concat_tensor_along/cpu_kernel.rs b/dfdx-core/src/tensor_ops/concat_tensor_along/cpu_kernel.rs
index e6ab2eb2..25efc27e 100644
--- a/dfdx-core/src/tensor_ops/concat_tensor_along/cpu_kernel.rs
+++ b/dfdx-core/src/tensor_ops/concat_tensor_along/cpu_kernel.rs
@@ -26,11 +26,11 @@ impl<E: Dtype> super::ConcatAlongKernel<E> for Cpu {
         let buf = std::sync::Arc::get_mut(&mut c.data).unwrap();
         while i < n {
             for _ in 0..a_n {
-                buf[i] = a.data[a_idx.next().unwrap()];
+                (*buf)[i] = a.data[a_idx.next().unwrap()];
                 i += 1;
             }
             for _ in 0..b_n {
-                buf[i] = b.data[b_idx.next().unwrap()];
+                (*buf)[i] = b.data[b_idx.next().unwrap()];
                 i += 1;
             }
         }
@@ -59,11 +59,11 @@ impl<E: Dtype> super::ConcatAlongKernel<E> for Cpu {
         let n = grad_out.len();
         while i < n {
             for _ in 0..a_n {
-                grad_a[a_idx.next().unwrap()] += grad_out[i];
+                (*grad_a)[a_idx.next().unwrap()] += grad_out[i];
                 i += 1;
             }
             for _ in 0..b_n {
-                grad_b[b_idx.next().unwrap()] += grad_out[i];
+                (*grad_b)[b_idx.next().unwrap()] += grad_out[i];
                 i += 1;
             }
         }

diff --git a/dfdx-core/src/tensor_ops/concat_tensor_along/mod.rs b/dfdx-core/src/tensor_ops/concat_tensor_along/mod.rs
index 7462fd2b..9165efba 100644
--- a/dfdx-core/src/tensor_ops/concat_tensor_along/mod.rs
+++ b/dfdx-core/src/tensor_ops/concat_tensor_along/mod.rs
@@ -19,7 +19,7 @@ mod webgpu_kernel;
 /// # let dev: Cpu = Default::default();
 /// let a: Tensor<Rank2<3, 4>, f32, _> = dev.zeros();
 /// let b: Tensor<Rank2<3, 4>, f32, _> = dev.zeros();
-/// let _: Tensor<Rank2<6, 4>, f32, _> = (a, b).concat_along(Axis::<0>);
+/// let _: Tensor<Rank2<6, 4>, f32, _> = (a, b).concat_tensor_along(Axis::<0>);
 /// ```
 ///
 /// Along Axis 1:
@@ -28,7 +28,7 @@ mod webgpu_kernel;
 /// # let dev: Cpu = Default::default();
 /// let a: Tensor<Rank2<3, 4>, f32, _> = dev.zeros();
 /// let b: Tensor<Rank2<3, 4>, f32, _> = dev.zeros();
-/// let _: Tensor<Rank2<3, 8>, f32, _> = (a, b).concat_along(Axis::<1>);
+/// let _: Tensor<Rank2<3, 8>, f32, _> = (a, b).concat_tensor_along(Axis::<1>);
 /// ```
 ///
 /// # [usize] dims
@@ -38,7 +38,7 @@ mod webgpu_kernel;
 /// # let dev: Cpu = Default::default();
 /// let a: Tensor<(usize, Const<3>), f32, _> = dev.zeros_like(&(2, Const));
 /// let b: Tensor<(usize, Const<3>), f32, _> = dev.zeros_like(&(4, Const));
-/// let _: Tensor<Rank2<6, 3>, f32, _> = (a, b).concat_along(Axis::<0>).realize();
+/// let _: Tensor<Rank2<6, 3>, f32, _> = (a, b).concat_tensor_along(Axis::<0>).realize();
 /// ```
 ///
 /// Along Axis 1:
@@ -47,7 +47,7 @@ mod webgpu_kernel;
 /// # let dev: Cpu = Default::default();
 /// let a: Tensor<(Const<2>, usize), f32, _> = dev.zeros_like(&(Const, 2));
 /// let b: Tensor<(Const<2>, usize), f32, _> = dev.zeros_like(&(Const, 4));
-/// let _: Tensor<Rank2<2, 6>, f32, _> = (a, b).concat_along(Axis::<1>).realize();
+/// let _: Tensor<Rank2<2, 6>, f32, _> = (a, b).concat_tensor_along(Axis::<1>).realize();
 /// ```
 pub trait TryConcatTensorAlong<Ax>: Sized {
     type Output;

diff --git a/dfdx-core/src/tensor_ops/mod.rs b/dfdx-core/src/tensor_ops/mod.rs
index 453457f4..38a03d14 100644
--- a/dfdx-core/src/tensor_ops/mod.rs
+++ b/dfdx-core/src/tensor_ops/mod.rs
@@ -200,6 +200,8 @@ mod sigmoid;
 mod sin;
 mod slice;
 mod softmax;
+mod split_shape_along;
+mod split_tensor_along;
 mod sqrt;
 mod square;
 mod stack;
@@ -267,6 +269,8 @@ pub use sigmoid::sigmoid;
 pub use sin::sin;
 pub use slice::slice;
 pub use softmax::softmax;
+pub use split_shape_along::TrySplitShapeAlong;
+pub use split_tensor_along::TrySplitTensorAlong;
 pub use sqrt::sqrt;
 pub use square::square;
 pub use stack::{AddDim, TryStack};

diff --git a/dfdx-core/src/tensor_ops/split_shape_along/mod.rs b/dfdx-core/src/tensor_ops/split_shape_along/mod.rs
new file mode 100644
index 00000000..1421e12f
--- /dev/null
+++ b/dfdx-core/src/tensor_ops/split_shape_along/mod.rs
@@ -0,0 +1,158 @@
+use crate::{shapes::*, tensor::*};
+
+/// Split a shape in two along a given axis.
+///
+/// # [Const] dims **requires nightly**
+///
+/// Along Axis 0:
+/// ```ignore
+/// # use dfdx_core::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// let (a, b): (Rank2<3, 3>, Rank2<4, 3>) = (Const::<7>, Const::<3>).split_shape_along(Axis::<0>, Const::<3>, Const::<4>);
+/// ```
+///
+/// Along Axis 1:
+/// ```ignore
+/// # use dfdx_core::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// let (a, b): (Rank2<7, 2>, Rank2<7, 1>) = (Const::<7>, Const::<3>).split_shape_along(Axis::<1>, Const::<2>, Const::<1>);
+/// ```
+///
+/// # [usize] dims
+/// Along Axis 0:
+/// ```rust
+/// # use dfdx_core::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// let (a, b) = (7, Const::<3>).split_shape_along(Axis::<0>, 3, 4);
+/// assert_eq!(a, (3, Const::<3>));
+/// assert_eq!(b, (4, Const::<3>));
+/// ```
+///
+/// Along Axis 1:
+/// ```rust
+/// # use dfdx_core::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// let (a, b) = (Const::<7>, 3).split_shape_along(Axis::<1>, 2, 1);
+/// assert_eq!(a, (Const::<7>, 2));
+/// assert_eq!(b, (Const::<7>, 1));
+/// ```
+pub trait TrySplitShapeAlong<Ax, A: Dim, B: Dim>: Shape {
+    type Output;
+
+    /// Splits self along the given axis.
+    fn split_shape_along(self, ax: Ax, a: A, b: B) -> Self::Output {
+        self.try_split_shape_along(ax, a, b).unwrap()
+    }
+    /// Fallibly splits self along the given axis.
+    fn try_split_shape_along(self, ax: Ax, a: A, b: B) -> Result<Self::Output, Error>;
+}
+
+macro_rules! impl_split {
+    ($Ax:expr, $NumDims:expr, [$($Head:tt),*], [$($Tail:tt),*]) => {
+        impl<A: Dim, B: Dim, AB: Dim, $($Head: Dim, )* $($Tail: Dim, )*> TrySplitShapeAlong<Axis<$Ax>, A, B>
+            for
+            (
+                $($Head, )*
+                AB,
+                $($Tail, )*
+            )
+        where
+            ($($Head, )* A, $($Tail, )*): Shape<Concrete = <Self as Shape>::Concrete>,
+            ($($Head, )* B, $($Tail, )*): Shape<Concrete = <Self as Shape>::Concrete>,
+        {
+            type Output =
+            (
+                ($($Head, )* A, $($Tail, )*),
+                ($($Head, )* B, $($Tail, )*),
+            );
+
+            fn try_split_shape_along(self, _: Axis<$Ax>, a: A, b: B) -> Result<Self::Output, Error> {
+                let dims = self.concrete();
+                let mut lhs_dims = dims;
+                let mut rhs_dims = dims;
+                lhs_dims[$Ax] = a.size();
+                rhs_dims[$Ax] = b.size();
+                assert_eq!(dims[$Ax], lhs_dims[$Ax] + rhs_dims[$Ax]);
+
+                Ok((
+                    <($($Head, )* A, $($Tail, )*)>::from_concrete(&lhs_dims).unwrap(),
+                    <($($Head, )* B, $($Tail, )*)>::from_concrete(&rhs_dims).unwrap(),
+                ))
+            }
+        }
+    };
+}
+
+impl_split!(0, 1, [], []);
+impl_split!(0, 2, [], [D1]);
+impl_split!(0, 3, [], [D1, D2]);
+impl_split!(0, 4, [], [D1, D2, D3]);
+impl_split!(0, 5, [], [D1, D2, D3, D4]);
+impl_split!(0, 6, [], [D1, D2, D3, D4, D5]);
+
+impl_split!(1, 2, [D0], []);
+impl_split!(1, 3, [D0], [D2]);
+impl_split!(1, 4, [D0], [D2, D3]);
+impl_split!(1, 5, [D0], [D2, D3, D4]);
+impl_split!(1, 6, [D0], [D2, D3, D4, D5]);
+
+impl_split!(2, 3, [D0, D1], []);
+impl_split!(2, 4, [D0, D1], [D3]);
+impl_split!(2, 5, [D0, D1], [D3, D4]);
+impl_split!(2, 6, [D0, D1], [D3, D4, D5]);
+
+impl_split!(3, 4, [D0, D1, D2], []);
+impl_split!(3, 5, [D0, D1, D2], [D4]);
+impl_split!(3, 6, [D0, D1, D2], [D4, D5]);
+
+impl_split!(4, 5, [D0, D1, D2, D3], []);
+impl_split!(4, 6, [D0, D1, D2, D3], [D5]);
+
+impl_split!(5, 6, [D0, D1, D2, D3, D4], []);
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_split_shape() {
+        let a: (usize, Const<5>) = (5, Const);
+        let b: (usize, Const<5>) = (3, Const);
+        assert_eq!(
+            (8, Const::<5>).split_shape_along(Axis::<0>, a.0, b.0),
+            (a, b)
+        );
+
+        let a: (Const<5>, Const<5>) = (Const, Const);
+        let b: (usize, Const<5>) = (3, Const);
+        assert_eq!(
+            (8, Const::<5>).split_shape_along(Axis::<0>, a.0, b.0),
+            (a, b)
+        );
+
+        let a: (usize, Const<5>) = (5, Const);
+        let b: (Const<3>, Const<5>) = (Const, Const);
+        assert_eq!(
+            (8, Const::<5>).split_shape_along(Axis::<0>, a.0, b.0),
+            (a, b)
+        );
+
+        #[cfg(feature = "nightly")]
+        {
+            let a: (Const<5>, Const<5>) = (Const, Const);
+            let b: (Const<3>, Const<5>) = (Const, Const);
+            assert_eq!(
+                (Const::<8>, Const::<5>).split_shape_along(Axis::<0>, a.0, b.0),
+                (a, b)
+            );
+        }
+    }
+
+    #[test]
+    #[should_panic = "left: 8\n right: 7"]
+    fn test_split_shape_fails() {
+        let a: (usize, Const<5>) = (4, Const);
+        let b: (usize, Const<5>) = (3, Const);
+        (8, Const::<5>).split_shape_along(Axis::<0>, a.0, b.0);
+    }
+}

diff --git a/dfdx-core/src/tensor_ops/split_tensor_along/cpu_kernel.rs b/dfdx-core/src/tensor_ops/split_tensor_along/cpu_kernel.rs
new file mode 100644
index 00000000..3e2fa5e1
--- /dev/null
+++ b/dfdx-core/src/tensor_ops/split_tensor_along/cpu_kernel.rs
@@ -0,0 +1,99 @@
+use super::AorB;
+use crate::{
+    shapes::*,
+    tensor::{cpu::NdIndex, *},
+};
+
+impl<E: Dtype> super::SplitAlongKernel<E> for Cpu {
+    fn forward<AB: Shape, A: Shape, B: Shape>(
+        &self,
+        ax: usize,
+        ab: &Tensor<AB, E, Self>,
+        a: &mut Tensor<A, E, Self>,
+        b: &mut Tensor<B, E, Self>,
+    ) -> Result<(), Error> {
+        let mut a_n = 1;
+        let mut b_n = 1;
+        {
+            let a_idx = NdIndex::new(a.shape, a.strides);
+            let b_idx = NdIndex::new(b.shape, b.strides);
+            for i in ax..A::NUM_DIMS {
+                a_n *= a_idx.shape[i];
+                b_n *= b_idx.shape[i];
+            }
+        }
+
+        let n_ab = ab.data.len();
+
+        let buf_a = std::sync::Arc::get_mut(&mut a.data).unwrap();
+        let buf_b = std::sync::Arc::get_mut(&mut b.data).unwrap();
+
+        let mut i = 0;
+        let mut k = 0;
+        let mut ab_idx = NdIndex::new(ab.shape, ab.strides);
+        while i < n_ab {
+            for j in 0..a_n {
+                (*buf_a)[j + k * a_n] = ab.data[ab_idx.next().unwrap()];
+                i += 1;
+            }
+            for j in 0..b_n {
+                (*buf_b)[j + k * b_n] = ab.data[ab_idx.next().unwrap()];
+                i += 1;
+            }
+            k += 1;
+        }
+        Ok(())
+    }
+
+    fn backward<AB: Shape, A: Shape, B: Shape>(
+        &self,
+        ax: usize,
+        ab: &GhostTensor<AB, E, Self>,
+        grad_ab: &mut Self::Vec,
+        a: &GhostTensor<A, E, Self>,
+        b: &GhostTensor<B, E, Self>,
+        a_or_b: AorB,
+        grad_out: &Self::Vec,
+    ) -> Result<(), Error> {
+        let a_idx = NdIndex::new(a.shape, a.strides);
+        let b_idx = NdIndex::new(b.shape, b.strides);
+
+        let mut a_n = 1;
+        let mut b_n = 1;
+        for i in ax..A::NUM_DIMS {
+            a_n *= a_idx.shape[i];
+            b_n *= b_idx.shape[i];
+        }
+
+        let mut i = 0;
+        let mut j = 0;
+        let n = grad_ab.len();
+        let mut ab_idx = NdIndex::new(ab.shape, ab.strides);
+        while i + j < n {
+            match a_or_b {
+                AorB::A => {
+                    for _ in 0..a_n {
+                        (*grad_ab)[ab_idx.next().unwrap()] = grad_out[i];
+                        i += 1;
+                    }
+                    for _ in 0..b_n {
+                        ab_idx.next().unwrap();
+                        j += 1;
+                    }
+                }
+                AorB::B => {
+                    for _ in 0..a_n {
+                        ab_idx.next().unwrap();
+                        j += 1;
+                    }
+                    for _ in 0..b_n {
+                        (*grad_ab)[ab_idx.next().unwrap()] = grad_out[i];
+                        i += 1;
+                    }
+                }
+            };
+        }
+
+        Ok(())
+    }
+}

diff --git a/dfdx-core/src/tensor_ops/split_tensor_along/cuda_kernel.rs b/dfdx-core/src/tensor_ops/split_tensor_along/cuda_kernel.rs
new file mode 100644
index 00000000..515f0365
--- /dev/null
+++ b/dfdx-core/src/tensor_ops/split_tensor_along/cuda_kernel.rs
@@ -0,0 +1,31 @@
+use super::AorB;
+use crate::{
+    shapes::*,
+    tensor::{Cuda, Error, GhostTensor, Tensor},
+};
+use cudarc::types::CudaTypeName;
+
+impl<E: Dtype + CudaTypeName> super::SplitAlongKernel<E> for Cuda {
+    fn forward<AB: Shape, A: Shape, B: Shape>(
+        &self,
+        _ax: usize,
+        _ab: &Tensor<AB, E, Self>,
+        _a: &mut Tensor<A, E, Self>,
+        _b: &mut Tensor<B, E, Self>,
+    ) -> Result<(), Error> {
+        todo!()
+    }
+
+    fn backward<AB: Shape, A: Shape, B: Shape>(
+        &self,
+        _ax: usize,
+        _ab: &GhostTensor<AB, E, Self>,
+        _grad_ab: &mut Self::Vec,
+        _a: &GhostTensor<A, E, Self>,
+        _b: &GhostTensor<B, E, Self>,
+        _a_or_b: AorB,
+        _grad_out: &Self::Vec,
+    ) -> Result<(), Error> {
+        todo!()
+    }
+}

diff --git a/dfdx-core/src/tensor_ops/split_tensor_along/mod.rs b/dfdx-core/src/tensor_ops/split_tensor_along/mod.rs
new file mode 100644
index 00000000..ac619301
--- /dev/null
+++ b/dfdx-core/src/tensor_ops/split_tensor_along/mod.rs
@@ -0,0 +1,275 @@
+use super::split_shape_along::TrySplitShapeAlong;
+use crate::{shapes::*, tensor::*};
+
+pub(crate) mod cpu_kernel;
+#[cfg(feature = "cuda")]
+pub(crate) mod cuda_kernel;
+#[cfg(feature = "webgpu")]
+mod webgpu_kernel;
+
+/// Split a tensor in two along a given axis.
+///
+/// This is the reverse of [TryConcatTensorAlong::concat_tensor_along].
+///
+/// # [Const] dims **requires nightly**
+///
+/// Along Axis 0:
+/// ```ignore
+/// # use dfdx_core::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// let ab: Tensor<Rank2<5, 3>, f32, _> = dev.zeros();
+/// let (a, b, _tape): (
+///     Tensor<Rank2<2, 3>, f32, _>,
+///     Tensor<Rank2<3, 3>, f32, _>,
+///     _
+/// ) = ab.split_tensor_along(Axis::<0>, Const::<2>, Const::<3>);
+/// ```
+///
+/// Along Axis 1:
+/// ```ignore
+/// # use dfdx_core::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// let ab: Tensor<Rank2<3, 5>, f32, _> = dev.zeros();
+/// let (a, b, _tape): (
+///     Tensor<Rank2<3, 2>, f32, _>,
+///     Tensor<Rank2<3, 3>, f32, _>,
+///     _
+/// ) = ab.split_tensor_along(Axis::<1>, Const::<2>, Const::<3>);
+/// ```
+///
+/// # [usize] dims
+/// Along Axis 0:
+/// ```rust
+/// # use dfdx_core::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// let ab: Tensor<(usize, Const::<4>), f32, _> = dev.zeros_like(&(5, Const));
+/// let (a, b, _tape): (
+///     Tensor<(usize, Const::<4>), f32, _>,
+///     Tensor<(usize, Const::<4>), f32, _>,
+///     _
+/// ) = ab.split_tensor_along(Axis::<0>, 2, 3);
+/// let a: Tensor<Rank2<2, 4>, f32, _> = a.realize();
+/// let b: Tensor<Rank2<3, 4>, f32, _> = b.realize();
+/// ```
+///
+/// Along Axis 1:
+/// ```rust
+/// # use dfdx_core::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// let ab: Tensor<(Const::<4>, usize), f32, _> = dev.zeros_like(&(Const, 5));
+/// let (a, b, _tape): (
+///     Tensor<(Const::<4>, usize), f32, _>,
+///     Tensor<(Const::<4>, usize), f32, _>,
+///     _
+/// ) = ab.split_tensor_along(Axis::<1>, 2, 3);
+/// let a: Tensor<Rank2<4, 2>, f32, _> = a.realize();
+/// let b: Tensor<Rank2<4, 3>, f32, _> = b.realize();
+/// ```
+pub trait TrySplitTensorAlong<Ax, A: Dim, B: Dim>: Sized {
+    type Output;
+
+    /// Splits self along the given axis.
+    fn split_tensor_along(self, ax: Ax, a: A, b: B) -> Self::Output {
+        self.try_split_tensor_along(ax, a, b).unwrap()
+    }
+    /// Fallibly splits self along the given axis.
+    fn try_split_tensor_along(self, ax: Ax, a: A, b: B) -> Result<Self::Output, Error>;
+}
+
+#[derive(Debug, Clone)]
+pub enum AorB {
+    A,
+    B,
+}
+
+pub trait SplitAlongKernel<E: Dtype>: Storage<E> {
+    fn forward<AB: Shape, A: Shape, B: Shape>(
+        &self,
+        ax: usize,
+        ab: &Tensor<AB, E, Self>,
+        a: &mut Tensor<A, E, Self>,
+        b: &mut Tensor<B, E, Self>,
+    ) -> Result<(), Error>;
+
+    #[allow(clippy::too_many_arguments)]
+    fn backward<AB: Shape, A: Shape, B: Shape>(
+        &self,
+        ax: usize,
+        ab: &GhostTensor<AB, E, Self>,
+        grad_ab: &mut Self::Vec,
+        a: &GhostTensor<A, E, Self>,
+        b: &GhostTensor<B, E, Self>,
+        a_or_b: AorB,
+        grad_out: &Self::Vec,
+    ) -> Result<(), Error>;
+}
+
+impl<A, B, AS, BS, AB, Ax, E: Dtype, D, T: Tape<E, D>> TrySplitTensorAlong<Ax, A, B>
+    for Tensor<AB, E, D, T>
+where
+    Ax: Axes<Array = [isize; 1]>,
+    A: Dim,
+    B: Dim,
+    AS: Shape,
+    BS: Shape,
+    AB: Shape + TrySplitShapeAlong<Ax, A, B, Output = (AS, BS)>,
+    D: SplitAlongKernel<E> + ZerosTensor<E>,
+{
+    type Output = (Tensor<AS, E, D, T>, Tensor<BS, E, D, T>, T);
+
+    fn try_split_tensor_along(self, ax: Ax, a: A, b: B) -> Result<Self::Output, Error> {
+        let device = self.device.clone();
+        let (a_shape, b_shape) = (*self.shape()).try_split_shape_along(ax, a, b)?;
+        let ax = Ax::as_array()[0] as usize;
+
+        let (ab, tape) = self.split_tape();
+
+        let mut at: Tensor<AS, E, D> = device.try_zeros_like(&a_shape)?;
+        let mut bt: Tensor<BS, E, D> = device.try_zeros_like(&b_shape)?;
+
+        ab.device.forward(ax, &ab, &mut at, &mut bt)?;
+
+        let mut ta = T::default();
+        let mut tb = T::default();
+
+        let device_b = device.clone();
+
+        let ab_ghost = ab.ghost();
+        let a_ghost = at.ghost();
+        let b_ghost = bt.ghost();
+        ta.add_backward_op(move |grads| {
+            grads.try_alloc_for(&ab_ghost)?;
+            grads.try_alloc_for(&a_ghost)?;
+            let (ab_grad, a_grad) = grads.mut_and_ref(&ab_ghost, &a_ghost);
+            device.backward(ax, &ab_ghost, ab_grad, &a_ghost, &b_ghost, AorB::A, a_grad)
+        });
+
+        let ab_ghost = ab.ghost();
+        let a_ghost = at.ghost();
+        let b_ghost = bt.ghost();
+        tb.add_backward_op(move |grads| {
+            grads.try_alloc_for(&ab_ghost)?;
+            grads.try_alloc_for(&b_ghost)?;
+            let (ab_grad, b_grad) = grads.mut_and_ref(&ab_ghost, &b_ghost);
+            device_b.backward(ax, &ab_ghost, ab_grad, &a_ghost, &b_ghost, AorB::B, b_grad)
+        });
+
+        let at = at.put_tape(ta);
+        let bt = bt.put_tape(tb);
+        Ok((at, bt, tape))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::{tensor_ops::*, tests::*};
+
+    #[test]
+    fn test_split_ax_0() {
+        let dev: TestDevice = Default::default();
+        let ab: Tensor<Rank3<5, 3, 4>, TestDtype, _> = dev.sample_normal();
+        let ab_dyn = ab
+            .leaky_trace()
+            .try_realize::<(usize, Const<3>, Const<4>)>()
+            .unwrap();
+        let (a, b, _tape) = ab_dyn.split_tensor_along(Axis::<0>, 2, 3);
+        let a = a.try_realize::<(Const<2>, Const<3>, Const<4>)>().unwrap();
+        let b = b.try_realize::<(Const<3>, Const<3>, Const<4>)>().unwrap();
+        let ab_arr = ab.array();
+        let a_arr = a.array();
+        let b_arr = b.array();
+        println!("{a_arr:?}");
+        println!("{b_arr:?}");
+        println!("{ab_arr:?}");
+
+        assert_eq!(ab_arr[0], a_arr[0]);
+        assert_eq!(ab_arr[1], a_arr[1]);
+        assert_eq!(ab_arr[2], b_arr[0]);
+        assert_eq!(ab_arr[3], b_arr[1]);
+        assert_eq!(ab_arr[4], b_arr[2]);
+
+        let ab_concat = (a, b).concat_tensor_along(Axis::<0>);
+        assert_eq!(ab.array(), ab_concat.array());
+        let concat_grads = ab_concat.exp().sum().backward();
+        let ab_grads = ab.leaky_trace().exp().sum().backward();
+
+        assert_close_to_tensor!(concat_grads.get(&ab), ab_grads.get(&ab));
+    }
+
+    #[test]
+    fn test_split_ax_1() {
+        let dev: TestDevice = Default::default();
+        let ab: Tensor<Rank3<2, 5, 4>, TestDtype, _> = dev.sample_normal();
+        let ab_dyn = ab
+            .leaky_trace()
+            .try_realize::<(Const<2>, usize, Const<4>)>()
+            .unwrap();
+        let (a, b, _tape) = ab_dyn.split_tensor_along(Axis::<1>, 2, 3);
+        let a = a.try_realize::<(Const<2>, Const<2>, Const<4>)>().unwrap();
+        let b = b.try_realize::<(Const<2>, Const<3>, Const<4>)>().unwrap();
+        let ab_arr = ab.array();
+        let a_arr = a.array();
+        let b_arr = b.array();
+        println!("{a_arr:?}");
+        println!("{b_arr:?}");
+        println!("{ab_arr:?}");
+
+        for i in 0..2 {
+            assert_eq!(ab_arr[i][0], a_arr[i][0]);
+            assert_eq!(ab_arr[i][1], a_arr[i][1]);
+            assert_eq!(ab_arr[i][2], b_arr[i][0]);
+            assert_eq!(ab_arr[i][3], b_arr[i][1]);
+            assert_eq!(ab_arr[i][4], b_arr[i][2]);
+        }
+
+        let ab_concat = (a, b).concat_tensor_along(Axis::<1>);
+        assert_eq!(ab.array(), ab_concat.array());
+        let concat_grads = ab_concat.exp().sum().backward();
+        let ab_grads = ab.leaky_trace().exp().sum().backward();
+
+        println!("{:?}", concat_grads.get(&ab).array());
+        println!("{:?}", ab_grads.get(&ab).array());
+
+        assert_close_to_tensor!(concat_grads.get(&ab), ab_grads.get(&ab));
+    }
+
+    #[test]
+    fn test_split_ax_2() {
+        let dev: TestDevice = Default::default();
+        let ab: Tensor<Rank3<2, 3, 5>, TestDtype, _> = dev.sample_normal();
+        let ab_dyn = ab
+            .leaky_trace()
+            .try_realize::<(Const<2>, Const<3>, usize)>()
+            .unwrap();
+        let (a, b, _tape) = ab_dyn.split_tensor_along(Axis::<2>, 2, 3);
+        let a = a.try_realize::<(Const<2>, Const<3>, Const<2>)>().unwrap();
+        let b = b.try_realize::<(Const<2>, Const<3>, Const<3>)>().unwrap();
+        let ab_arr = ab.array();
+        let a_arr = a.array();
+        let b_arr = b.array();
+        println!("{a_arr:?}");
+        println!("{b_arr:?}");
+        println!("{ab_arr:?}");
+
+        for i in 0..2 {
+            for j in 0..3 {
+                assert_eq!(ab_arr[i][j][0], a_arr[i][j][0]);
+                assert_eq!(ab_arr[i][j][1], a_arr[i][j][1]);
+                assert_eq!(ab_arr[i][j][2], b_arr[i][j][0]);
+                assert_eq!(ab_arr[i][j][3], b_arr[i][j][1]);
+                assert_eq!(ab_arr[i][j][4], b_arr[i][j][2]);
+            }
+        }
+
+        let ab_concat = (a, b).concat_tensor_along(Axis::<2>);
+        assert_eq!(ab.array(), ab_concat.array());
+        let concat_grads = ab_concat.exp().sum().backward();
+        let ab_grads = ab.leaky_trace().exp().sum().backward();
+
+        println!("{:?}", concat_grads.get(&ab).array());
+        println!("{:?}", ab_grads.get(&ab).array());
+
+        assert_close_to_tensor!(concat_grads.get(&ab), ab_grads.get(&ab));
+    }
+}

diff --git a/dfdx-core/src/tensor_ops/split_tensor_along/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/split_tensor_along/webgpu_kernel.rs
new file mode 100644
index 00000000..be1923dd
--- /dev/null
+++ b/dfdx-core/src/tensor_ops/split_tensor_along/webgpu_kernel.rs
@@ -0,0 +1,26 @@
+use crate::{shapes::*, tensor::*};
+
+impl<E: Dtype> super::SplitAlongKernel<E> for Webgpu {
+    fn forward<AB: Shape, A: Shape, B: Shape>(
+        &self,
+        _ax: usize,
+        _ab: &Tensor<AB, E, Self>,
+        _a: &mut Tensor<A, E, Self>,
+        _b: &mut Tensor<B, E, Self>,
+    ) -> Result<(), Error> {
+        todo!()
+    }
+
+    fn backward<AB: Shape, A: Shape, B: Shape>(
+        &self,
+        _ax: usize,
+        _ab: &GhostTensor<AB, E, Self>,
+        _grad_ab: &mut Self::Vec,
+        _a: &GhostTensor<A, E, Self>,
+        _b: &GhostTensor<B, E, Self>,
+        _a_or_b: super::AorB,
+        _grad_out: &Self::Vec,
+    ) -> Result<(), Error> {
+        todo!()
+    }
+}

diff --git a/dfdx-core/src/tensor_ops/utilities/device.rs b/dfdx-core/src/tensor_ops/utilities/device.rs
index 91f87cf6..e7fbe641 100644
--- a/dfdx-core/src/tensor_ops/utilities/device.rs
+++ b/dfdx-core/src/tensor_ops/utilities/device.rs
@@ -21,6 +21,9 @@ pub trait Device<E: Dtype>:
     + super::super::concat_tensor_along::ConcatAlongKernel<E>
     + super::super::concat_tensor_along::ConcatAlongKernel<bool>
 
+    // splits
+    + super::super::split_tensor_along::SplitAlongKernel<E>
+
     // optimizers
     + super::super::adam::AdamKernel<E>
     + super::super::sgd::SgdKernel<E>