From 6d8ab56264bca092483e8f6caa375a3b10fd128b Mon Sep 17 00:00:00 2001
From: Thiago Machado
Date: Thu, 16 Nov 2023 20:57:49 -0500
Subject: [PATCH] Split `TryConcatAlong` into two traits

- Deprecated `TryConcatAlong` in favor of `TryConcatTensorAlong` or `TryConcatShapeAlong`.
- Created `concat_tensor_along/` and `concat_shape_along/`.
- Copied relevant sections and files from `concat_along`, adjusting where necessary.
- Moved `concat_along/` kernels to `concat_tensor_along/`.
- Adjusted the issue's integration test to the new trait, which runs successfully.
---
 dfdx-core/src/tensor_ops/concat_along/mod.rs   |  29 +--
 .../src/tensor_ops/concat_shape_along/mod.rs   | 153 ++++++++++++
 .../cpu_kernel.rs                              |   0
 .../cuda_kernel.rs                             |   0
 .../src/tensor_ops/concat_tensor_along/mod.rs  | 226 ++++++++++++++++++
 dfdx-core/src/tensor_ops/mod.rs                |   5 +
 dfdx-core/src/tensor_ops/utilities/device.rs   |   4 +-
 dfdx/tests/issue_tests.rs                      |  16 +-
 8 files changed, 395 insertions(+), 38 deletions(-)
 create mode 100644 dfdx-core/src/tensor_ops/concat_shape_along/mod.rs
 rename dfdx-core/src/tensor_ops/{concat_along => concat_tensor_along}/cpu_kernel.rs (100%)
 rename dfdx-core/src/tensor_ops/{concat_along => concat_tensor_along}/cuda_kernel.rs (100%)
 create mode 100644 dfdx-core/src/tensor_ops/concat_tensor_along/mod.rs

diff --git a/dfdx-core/src/tensor_ops/concat_along/mod.rs b/dfdx-core/src/tensor_ops/concat_along/mod.rs
index 0c796d6e..b20cbc6e 100644
--- a/dfdx-core/src/tensor_ops/concat_along/mod.rs
+++ b/dfdx-core/src/tensor_ops/concat_along/mod.rs
@@ -1,9 +1,6 @@
+use super::concat_tensor_along::ConcatAlongKernel;
 use crate::{shapes::*, tensor::*};
 
-mod cpu_kernel;
-#[cfg(feature = "cuda")]
-mod cuda_kernel;
-
 /// Concatenate two tensors along a given axis.
 ///
 /// **Pytorch equivalent** `torch.concat`.
 ///
 /// # [Const] dims **requires nightly**
 ///
 /// Along Axis 0:
 /// ```ignore
@@ -46,6 +43,7 @@ mod cuda_kernel;
 /// let b: Tensor<(Const<2>, usize), f32, _> = dev.zeros_like(&(Const, 4));
 /// let _: Tensor<Rank2<2, 6>, f32, _> = (a, b).concat_along(Axis::<1>).realize();
 /// ```
+#[deprecated = "Use TryConcatTensorAlong or TryConcatShapeAlong instead"]
 pub trait TryConcatAlong<Ax>: Sized {
     type Output;
 
@@ -57,26 +55,7 @@ pub trait TryConcatAlong<Ax>: Sized {
     fn try_concat_along(self, ax: Ax) -> Result<Self::Output, Error>;
 }
 
-pub trait ConcatAlongKernel<E: Dtype>: Storage<E> {
-    fn forward<A: Shape, B: Shape, C: Shape>(
-        &self,
-        ax: usize,
-        a: &Tensor<A, E, Self>,
-        b: &Tensor<B, E, Self>,
-        c: &mut Tensor<C, E, Self>,
-    ) -> Result<(), Error>;
-
-    fn backward<A: Shape, B: Shape>(
-        &self,
-        ax: usize,
-        a: &GhostTensor<A, E, Self>,
-        grad_a: &mut Self::Vec,
-        b: &GhostTensor<B, E, Self>,
-        grad_b: &mut Self::Vec,
-        grad_out: &Self::Vec,
-    ) -> Result<(), Error>;
-}
-
+#[allow(deprecated)]
 impl<A, B, Ax, E: Dtype, D, T: Tape<E, D>, R: Tape<E, D>> TryConcatAlong<Ax>
     for (Tensor<A, E, D, T>, Tensor<B, E, D, R>)
 where
@@ -121,6 +100,7 @@ where
 macro_rules! impl_concat {
     ($Ax:expr, $NumDims:expr, [$($Head:tt),*], [$($Tail:tt),*]) => {
+        #[allow(deprecated)]
         impl<A: Dim, B: Dim, $($Head: Dim, )* $($Tail: Dim, )*> TryConcatAlong<Axis<$Ax>>
             for (
                 ($($Head, )* A, $($Tail, )*),
@@ -181,6 +161,7 @@ impl_concat!(4, 6, [D0, D1, D2, D3], [D5]);
 impl_concat!(5, 6, [D0, D1, D2, D3, D4], []);
 
 #[cfg(test)]
+#[allow(deprecated)]
 mod tests {
     use super::*;
     use crate::{tensor_ops::*, tests::*};
diff --git a/dfdx-core/src/tensor_ops/concat_shape_along/mod.rs b/dfdx-core/src/tensor_ops/concat_shape_along/mod.rs
new file mode 100644
index 00000000..6c2630c1
--- /dev/null
+++ b/dfdx-core/src/tensor_ops/concat_shape_along/mod.rs
@@ -0,0 +1,153 @@
+use crate::{shapes::*, tensor::*};
+
+/// Concatenate two shapes along a given axis.
+///
+/// # [Const] dims **requires nightly**
+///
+/// Along Axis 0:
+/// ```ignore
+/// # use dfdx_core::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// let a: Rank2<3, 4> = (Const, Const);
+/// let b: Rank2<3, 4> = (Const, Const);
+/// let _: Rank2<6, 4> = (a, b).concat_shape_along(Axis::<0>);
+/// ```
+///
+/// Along Axis 1:
+/// ```ignore
+/// # use dfdx_core::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// let a: Rank2<3, 4> = (Const, Const);
+/// let b: Rank2<3, 4> = (Const, Const);
+/// let _: Rank2<3, 8> = (a, b).concat_shape_along(Axis::<1>);
+/// ```
+///
+/// # [usize] dims
+/// Along Axis 0:
+/// ```rust
+/// # use dfdx_core::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// let a: (usize, Const<3>) = (2, Const);
+/// let b: (usize, Const<3>) = (4, Const);
+/// let c: (usize, Const<3>) = (a, b).concat_shape_along(Axis::<0>);
+/// assert_eq!(c, (6, Const::<3>));
+/// ```
+///
+/// Along Axis 1:
+/// ```rust
+/// # use dfdx_core::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// let a: (Const<2>, usize) = (Const, 2);
+/// let b: (Const<2>, usize) = (Const, 4);
+/// let c: (Const<2>, usize) = (a, b).concat_shape_along(Axis::<1>);
+/// assert_eq!(c, (Const::<2>, 6));
+/// ```
+pub trait TryConcatShapeAlong<Ax>: Sized {
+    type Output;
+
+    /// Concatenates self along the given axis.
+    fn concat_shape_along(self, ax: Ax) -> Self::Output {
+        self.try_concat_shape_along(ax).unwrap()
+    }
+    /// Fallibly concatenates self along the given axis.
+    fn try_concat_shape_along(self, ax: Ax) -> Result<Self::Output, Error>;
+}
+
+macro_rules! impl_concat {
+    ($Ax:expr, $NumDims:expr, [$($Head:tt),*], [$($Tail:tt),*]) => {
+        impl<A: Dim, B: Dim, $($Head: Dim, )* $($Tail: Dim, )*> TryConcatShapeAlong<Axis<$Ax>>
+            for (
+                ($($Head, )* A, $($Tail, )*),
+                ($($Head, )* B, $($Tail, )*),
+            )
+        where
+            A: std::ops::Add<B>,
+            <A as std::ops::Add<B>>::Output: Dim,
+        {
+            type Output = (
+                $($Head, )*
+                <A as std::ops::Add<B>>::Output,
+                $($Tail, )*
+            );
+
+            fn try_concat_shape_along(self, _: Axis<$Ax>) -> Result<Self::Output, Error> {
+                let (lhs, rhs) = self;
+                let lhs_dims = lhs.concrete();
+                let rhs_dims = rhs.concrete();
+                for i in 0..$NumDims {
+                    if i != $Ax {
+                        assert_eq!(lhs_dims[i], rhs_dims[i]);
+                    }
+                }
+                let mut out_dims = lhs_dims;
+                out_dims[$Ax] += rhs_dims[$Ax];
+                Ok(Self::Output::from_concrete(&out_dims).unwrap())
+            }
+        }
+    };
+}
+
+impl_concat!(0, 1, [], []);
+impl_concat!(0, 2, [], [D1]);
+impl_concat!(0, 3, [], [D1, D2]);
+impl_concat!(0, 4, [], [D1, D2, D3]);
+impl_concat!(0, 5, [], [D1, D2, D3, D4]);
+impl_concat!(0, 6, [], [D1, D2, D3, D4, D5]);
+
+impl_concat!(1, 2, [D0], []);
+impl_concat!(1, 3, [D0], [D2]);
+impl_concat!(1, 4, [D0], [D2, D3]);
+impl_concat!(1, 5, [D0], [D2, D3, D4]);
+impl_concat!(1, 6, [D0], [D2, D3, D4, D5]);
+
+impl_concat!(2, 3, [D0, D1], []);
+impl_concat!(2, 4, [D0, D1], [D3]);
+impl_concat!(2, 5, [D0, D1], [D3, D4]);
+impl_concat!(2, 6, [D0, D1], [D3, D4, D5]);
+
+impl_concat!(3, 4, [D0, D1, D2], []);
+impl_concat!(3, 5, [D0, D1, D2], [D4]);
+impl_concat!(3, 6, [D0, D1, D2], [D4, D5]);
+
+impl_concat!(4, 5, [D0, D1, D2, D3], []);
+impl_concat!(4, 6, [D0, D1, D2, D3], [D5]);
+
+impl_concat!(5, 6, [D0, D1, D2, D3, D4], []);
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_concat_shape() {
+        let a: (usize, Const<5>) = (5, Const);
+        let b: (usize, Const<5>) = (3, Const);
+        assert_eq!((a, b).concat_shape_along(Axis::<0>), (8, Const::<5>));
+
+        let a: (Const<5>, Const<5>) = (Const, Const);
+        let b: (usize, Const<5>) = (3, Const);
+        assert_eq!((a, b).concat_shape_along(Axis::<0>), (8, Const::<5>));
+
+        let a: (usize, Const<5>) = (5, Const);
+        let b: (Const<3>, Const<5>) = (Const, Const);
+        assert_eq!((a, b).concat_shape_along(Axis::<0>), (8, Const::<5>));
+
+        #[cfg(feature = "nightly")]
+        {
+            let a: (Const<5>, Const<5>) = (Const, Const);
+            let b: (Const<3>, Const<5>) = (Const, Const);
+            assert_eq!(
+                (a, b).concat_shape_along(Axis::<0>),
+                (Const::<8>, Const::<5>)
+            );
+        }
+    }
+
+    #[test]
+    #[should_panic = "left: 10\n right: 7"]
+    fn test_concat_shape_fails() {
+        let a = (5, 10);
+        let b = (3, 7);
+        (a, b).concat_shape_along(Axis::<0>);
+    }
+}
diff --git a/dfdx-core/src/tensor_ops/concat_along/cpu_kernel.rs b/dfdx-core/src/tensor_ops/concat_tensor_along/cpu_kernel.rs
similarity index 100%
rename from dfdx-core/src/tensor_ops/concat_along/cpu_kernel.rs
rename to dfdx-core/src/tensor_ops/concat_tensor_along/cpu_kernel.rs
diff --git a/dfdx-core/src/tensor_ops/concat_along/cuda_kernel.rs b/dfdx-core/src/tensor_ops/concat_tensor_along/cuda_kernel.rs
similarity index 100%
rename from dfdx-core/src/tensor_ops/concat_along/cuda_kernel.rs
rename to dfdx-core/src/tensor_ops/concat_tensor_along/cuda_kernel.rs
diff --git a/dfdx-core/src/tensor_ops/concat_tensor_along/mod.rs b/dfdx-core/src/tensor_ops/concat_tensor_along/mod.rs
new file mode 100644
index 00000000..8e980739
--- /dev/null
+++ b/dfdx-core/src/tensor_ops/concat_tensor_along/mod.rs
@@ -0,0 +1,226 @@
+use super::concat_shape_along::TryConcatShapeAlong;
+use crate::{shapes::*, tensor::*};
+
+pub(crate) mod cpu_kernel;
+#[cfg(feature = "cuda")]
+pub(crate) mod cuda_kernel;
+
+/// Concatenate two tensors along a given axis.
+///
+/// **Pytorch equivalent** `torch.concat`.
+///
+/// # [Const] dims **requires nightly**
+///
+/// Along Axis 0:
+/// ```ignore
+/// # use dfdx_core::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// let a: Tensor<Rank2<3, 4>, f32, _> = dev.zeros();
+/// let b: Tensor<Rank2<3, 4>, f32, _> = dev.zeros();
+/// let _: Tensor<Rank2<6, 4>, f32, _> = (a, b).concat_tensor_along(Axis::<0>);
+/// ```
+///
+/// Along Axis 1:
+/// ```ignore
+/// # use dfdx_core::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// let a: Tensor<Rank2<3, 4>, f32, _> = dev.zeros();
+/// let b: Tensor<Rank2<3, 4>, f32, _> = dev.zeros();
+/// let _: Tensor<Rank2<3, 8>, f32, _> = (a, b).concat_tensor_along(Axis::<1>);
+/// ```
+///
+/// # [usize] dims
+/// Along Axis 0:
+/// ```rust
+/// # use dfdx_core::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// let a: Tensor<(usize, Const<3>), f32, _> = dev.zeros_like(&(2, Const));
+/// let b: Tensor<(usize, Const<3>), f32, _> = dev.zeros_like(&(4, Const));
+/// let _: Tensor<Rank2<6, 3>, f32, _> = (a, b).concat_tensor_along(Axis::<0>).realize();
+/// ```
+///
+/// Along Axis 1:
+/// ```rust
+/// # use dfdx_core::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// let a: Tensor<(Const<2>, usize), f32, _> = dev.zeros_like(&(Const, 2));
+/// let b: Tensor<(Const<2>, usize), f32, _> = dev.zeros_like(&(Const, 4));
+/// let _: Tensor<Rank2<2, 6>, f32, _> = (a, b).concat_tensor_along(Axis::<1>).realize();
+/// ```
+pub trait TryConcatTensorAlong<Ax>: Sized {
+    type Output;
+
+    /// Concatenates self along the given axis.
+    fn concat_tensor_along(self, ax: Ax) -> Self::Output {
+        self.try_concat_tensor_along(ax).unwrap()
+    }
+    /// Fallibly concatenates self along the given axis.
+    fn try_concat_tensor_along(self, ax: Ax) -> Result<Self::Output, Error>;
+}
+
+pub trait ConcatAlongKernel<E: Dtype>: Storage<E> {
+    fn forward<A: Shape, B: Shape, C: Shape>(
+        &self,
+        ax: usize,
+        a: &Tensor<A, E, Self>,
+        b: &Tensor<B, E, Self>,
+        c: &mut Tensor<C, E, Self>,
+    ) -> Result<(), Error>;
+
+    fn backward<A: Shape, B: Shape>(
+        &self,
+        ax: usize,
+        a: &GhostTensor<A, E, Self>,
+        grad_a: &mut Self::Vec,
+        b: &GhostTensor<B, E, Self>,
+        grad_b: &mut Self::Vec,
+        grad_out: &Self::Vec,
+    ) -> Result<(), Error>;
+}
+
+impl<A, B, Ax, E: Dtype, D, T: Tape<E, D>, R: Tape<E, D>> TryConcatTensorAlong<Ax>
+    for (Tensor<A, E, D, T>, Tensor<B, E, D, R>)
+where
+    Ax: Axes<Array = [isize; 1]>,
+    D: ConcatAlongKernel<E> + ZerosTensor<E>,
+    A: Shape + HasAxes<Ax>,
+    B: Shape + HasAxes<Ax>,
+    (A, B): TryConcatShapeAlong<Ax>,
+    <(A, B) as TryConcatShapeAlong<Ax>>::Output: Shape,
+    T: Merge<R>,
+{
+    type Output = Tensor<<(A, B) as TryConcatShapeAlong<Ax>>::Output, E, D, T>;
+
+    fn try_concat_tensor_along(self, ax: Ax) -> Result<Self::Output, Error> {
+        let (lhs, rhs) = self;
+
+        let out_shape = (*lhs.shape(), *rhs.shape()).concat_shape_along(ax);
+        let ax = Ax::as_array()[0] as usize;
+
+        let (lhs, tape) = lhs.split_tape();
+        let (rhs, rtape) = rhs.split_tape();
+        let mut tape = tape.merge(rtape);
+
+        let mut out = lhs.device.try_zeros_like(&out_shape)?;
+        lhs.device.forward(ax, &lhs, &rhs, &mut out)?;
+
+        let lhs_ghost = lhs.ghost();
+        let rhs_ghost = rhs.ghost();
+        let out_ghost = out.ghost();
+        tape.add_backward_op(move |grads| {
+            grads.try_alloc_for(&lhs_ghost)?;
+            grads.try_alloc_for(&rhs_ghost)?;
+            grads.try_alloc_for(&out_ghost)?;
+            let (lhs_grad, rhs_grad, out_grad) =
+                grads.muts_and_ref(&lhs_ghost, &rhs_ghost, &out_ghost);
+            lhs.device
+                .backward(ax, &lhs_ghost, lhs_grad, &rhs_ghost, rhs_grad, out_grad)
+        });
+        Ok(out.put_tape(tape))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::{tensor_ops::*, tests::*};
+
+    #[test]
+    fn test_concat_ax_0() {
+        let dev: TestDevice = Default::default();
+        let a: Tensor<Rank3<2, 3, 4>, TestDtype, _> = dev.sample_normal();
+        let b: Tensor<Rank3<3, 3, 4>, TestDtype, _> = dev.sample_normal();
+        let a_dyn = a
+            .leaky_trace()
+            .try_realize::<(usize, Const<3>, Const<4>)>()
+            .unwrap();
+        let b_dyn = b
+            .clone()
+            .try_realize::<(usize, Const<3>, Const<4>)>()
+            .unwrap();
+        let c = (a_dyn, b_dyn).concat_tensor_along(Axis::<0>);
+        let c = c.try_realize::<(Const<5>, Const<3>, Const<4>)>().unwrap();
+        let a_arr = a.array();
+        let b_arr = b.array();
+        let c_arr = c.array();
+        println!("{a_arr:?}");
+        println!("{b_arr:?}");
+        println!("{c_arr:?}");
+        assert_eq!(c_arr[0], a_arr[0]);
+        assert_eq!(c_arr[1], a_arr[1]);
+        assert_eq!(c_arr[2], b_arr[0]);
+        assert_eq!(c_arr[3], b_arr[1]);
+        assert_eq!(c_arr[4], b_arr[2]);
+        let concat_grads = c.exp().sum().backward();
+        let a_grads = a.leaky_trace().exp().sum().backward();
+        let b_grads = b.leaky_trace().exp().sum().backward();
+        assert_close_to_tensor!(concat_grads.get(&a), a_grads.get(&a));
+        assert_close_to_tensor!(concat_grads.get(&b), b_grads.get(&b));
+    }
+
+    #[test]
+    fn test_concat_ax_1() {
+        let dev: TestDevice = Default::default();
+        let a: Tensor<Rank3<2, 2, 4>, TestDtype, _> = dev.sample_normal();
+        let b: Tensor<Rank3<2, 3, 4>, TestDtype, _> = dev.sample_normal();
+        let a_dyn = a
+            .leaky_trace()
+            .try_realize::<(Const<2>, usize, Const<4>)>()
+            .unwrap();
+        let b_dyn = b
+            .clone()
+            .try_realize::<(Const<2>, usize, Const<4>)>()
+            .unwrap();
+        let c = (a_dyn, b_dyn).concat_tensor_along(Axis::<1>);
+        let c = c.try_realize::<(Const<2>, Const<5>, Const<4>)>().unwrap();
+        let a_arr = a.array();
+        let b_arr = b.array();
+        let c_arr = c.array();
+        for i in 0..2 {
+            assert_eq!(c_arr[i][0], a_arr[i][0]);
+            assert_eq!(c_arr[i][1], a_arr[i][1]);
+            assert_eq!(c_arr[i][2], b_arr[i][0]);
+            assert_eq!(c_arr[i][3], b_arr[i][1]);
+            assert_eq!(c_arr[i][4], b_arr[i][2]);
+        }
+        let concat_grads = c.exp().sum().backward();
+        let a_grads = a.leaky_trace().exp().sum().backward();
+        let b_grads = b.leaky_trace().exp().sum().backward();
+        assert_close_to_tensor!(concat_grads.get(&a), a_grads.get(&a));
+        assert_close_to_tensor!(concat_grads.get(&b), b_grads.get(&b));
+    }
+
+    #[test]
+    fn test_concat_ax_2() {
+        let dev: TestDevice = Default::default();
+        let a: Tensor<Rank3<2, 3, 2>, TestDtype, _> = dev.sample_normal();
+        let b: Tensor<Rank3<2, 3, 3>, TestDtype, _> = dev.sample_normal();
+        let a_dyn = a
+            .leaky_trace()
+            .try_realize::<(Const<2>, Const<3>, usize)>()
+            .unwrap();
+        let b_dyn = b
+            .clone()
+            .try_realize::<(Const<2>, Const<3>, usize)>()
+            .unwrap();
+        let c = (a_dyn, b_dyn).concat_tensor_along(Axis::<2>);
+        let c = c.try_realize::<(Const<2>, Const<3>, Const<5>)>().unwrap();
+        let a_arr = a.array();
+        let b_arr = b.array();
+        let c_arr = c.array();
+        for i in 0..2 {
+            for j in 0..3 {
+                assert_eq!(c_arr[i][j][0], a_arr[i][j][0]);
+                assert_eq!(c_arr[i][j][1], a_arr[i][j][1]);
+                assert_eq!(c_arr[i][j][2], b_arr[i][j][0]);
+                assert_eq!(c_arr[i][j][3], b_arr[i][j][1]);
+                assert_eq!(c_arr[i][j][4], b_arr[i][j][2]);
+            }
+        }
+        let concat_grads = c.exp().sum().backward();
+        let a_grads = a.leaky_trace().exp().sum().backward();
+        let b_grads = b.leaky_trace().exp().sum().backward();
+        assert_close_to_tensor!(concat_grads.get(&a), a_grads.get(&a));
+        assert_close_to_tensor!(concat_grads.get(&b), b_grads.get(&b));
+    }
+}
diff --git a/dfdx-core/src/tensor_ops/mod.rs b/dfdx-core/src/tensor_ops/mod.rs
index c51040ee..d934b678 100644
--- a/dfdx-core/src/tensor_ops/mod.rs
+++ b/dfdx-core/src/tensor_ops/mod.rs
@@ -163,6 +163,8 @@ mod clamp;
 mod cmp;
 mod concat;
 mod concat_along;
+mod concat_shape_along;
+mod concat_tensor_along;
 mod cos;
 mod div;
 mod dropout;
@@ -224,7 +226,10 @@ pub use clamp::clamp;
 pub use cmp::{eq, ge, gt, le, lt, ne, TryEq, TryGe, TryGt, TryLe, TryLt, TryNe};
 #[allow(deprecated)]
 pub use concat::TryConcat;
+#[allow(deprecated)]
 pub use concat_along::TryConcatAlong;
+pub use concat_shape_along::TryConcatShapeAlong;
+pub use concat_tensor_along::TryConcatTensorAlong;
 pub use cos::cos;
 pub use div::{div, TryDiv};
 pub use dropout::dropout;
diff --git a/dfdx-core/src/tensor_ops/utilities/device.rs b/dfdx-core/src/tensor_ops/utilities/device.rs
index 2504185f..7e5740d5 100644
--- a/dfdx-core/src/tensor_ops/utilities/device.rs
+++ b/dfdx-core/src/tensor_ops/utilities/device.rs
@@ -18,8 +18,8 @@ pub trait Device<E: Dtype>:
     + super::super::stack::StackKernel<E>
     + super::super::concat::ConcatKernel<E>
     + super::super::concat::ConcatKernel<usize>
-    + super::super::concat_along::ConcatAlongKernel<E>
-    + super::super::concat_along::ConcatAlongKernel<usize>
+    + super::super::concat_tensor_along::ConcatAlongKernel<E>
+    + super::super::concat_tensor_along::ConcatAlongKernel<usize>
 
     // optimizers
     + super::super::adam::AdamKernel<E>
diff --git a/dfdx/tests/issue_tests.rs b/dfdx/tests/issue_tests.rs
index 069f46fb..23f53e71 100644
--- a/dfdx/tests/issue_tests.rs
+++ b/dfdx/tests/issue_tests.rs
@@ -18,12 +18,12 @@ fn test_issue_891() {
 impl Module for ConcatTensorAlong>
 where
-    Input: TryConcatAlong>,
+    Input: TryConcatTensorAlong>,
 {
-    type Output = >>::Output;
+    type Output = >>::Output;
 
     fn try_forward(&self, x: Input) -> Result {
-        x.try_concat_along(Axis)
+        x.try_concat_tensor_along(Axis)
     }
 }
 
@@ -32,13 +32,5 @@ fn test_issue_891() {
     let dev = Cpu::default();
     let x = dev.tensor([1.]);
     let m = dev.build_module::(Arch::default());
-    let y = m.forward(x);
-    /*
-    error[E0275]: overflow evaluating the requirement `((_, _, _, _), (..., ..., ..., ...)): dfdx::prelude::TryConcatAlong<...>`
-      --> dfdx/tests/issue_tests.rs:36:15
-       |
-    36 |     let y = m.forward(x);
-       |               ^^^^^^^
-       |
-    */
+    let _y: Tensor, _, _, _> = m.forward(x);
 }
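A minimal sketch of how the two new traits are intended to compose, adapted from the doc examples added in this patch. It assumes both traits are reachable through `dfdx_core::prelude` (per the `tensor_ops/mod.rs` re-exports above); the shapes and dtype are illustrative only:

```rust
use dfdx_core::prelude::*;

fn main() {
    let dev: Cpu = Default::default();

    // Shape-level concatenation: combine two shapes without touching any tensor data.
    let a_shape: (usize, Const<3>) = (2, Const);
    let b_shape: (usize, Const<3>) = (4, Const);
    let c_shape = (a_shape, b_shape).concat_shape_along(Axis::<0>);
    assert_eq!(c_shape, (6, Const::<3>));

    // Tensor-level concatenation: allocates the output and copies both inputs into it.
    let a: Tensor<(usize, Const<3>), f32, _> = dev.zeros_like(&a_shape);
    let b: Tensor<(usize, Const<3>), f32, _> = dev.zeros_like(&b_shape);
    let c: Tensor<Rank2<6, 3>, f32, _> = (a, b).concat_tensor_along(Axis::<0>).realize();
    assert_eq!(c.shape().concrete(), [6, 3]);
}
```

The tensor-level impl delegates its output-shape computation to the shape-level trait through the `(A, B): TryConcatShapeAlong<Ax>` bound shown in `concat_tensor_along/mod.rs`, which is the split that lets the `test_issue_891` integration test above compile and run.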