Skip to content

Commit

Permalink
Split TryConcatAlong into two traits
Browse files Browse the repository at this point in the history
- Deprecated `TryConcatAlong` in favor of `TryConcatTensorAlong` or `TryConcatShapeAlong`.
- Created `concat_tensor_along/` and `concat_shape_along/`.
  - Copied relevant sections and files from `concat_along`, adjusting where necessary.
  - Moved `concat_along/` kernels to `concat_tensor_along/`.
- Adjusted the issue's integration test to the new trait, which runs successfully.
  • Loading branch information
swfsql committed Nov 17, 2023
1 parent b292254 commit 6d8ab56
Show file tree
Hide file tree
Showing 8 changed files with 395 additions and 38 deletions.
29 changes: 5 additions & 24 deletions dfdx-core/src/tensor_ops/concat_along/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
use super::concat_tensor_along::ConcatAlongKernel;
use crate::{shapes::*, tensor::*};

mod cpu_kernel;
#[cfg(feature = "cuda")]
mod cuda_kernel;

/// Concatenate two tensors along a given axis.
///
/// **Pytorch equivalent** `torch.concat`.
Expand Down Expand Up @@ -46,6 +43,7 @@ mod cuda_kernel;
/// let b: Tensor<(Const<2>, usize), f32, _> = dev.zeros_like(&(Const, 4));
/// let _: Tensor<Rank2<2, 6>, f32, _> = (a, b).concat_along(Axis::<1>).realize();
/// ```
#[deprecated = "Use TryConcatTensorAlong or TryConcatShapeAlong instead"]
pub trait TryConcatAlong<Ax>: Sized {
type Output;

Expand All @@ -57,26 +55,7 @@ pub trait TryConcatAlong<Ax>: Sized {
fn try_concat_along(self, ax: Ax) -> Result<Self::Output, Error>;
}

pub trait ConcatAlongKernel<E: Dtype>: Storage<E> {
fn forward<A: Shape, B: Shape, C: Shape>(
&self,
ax: usize,
a: &Tensor<A, E, Self>,
b: &Tensor<B, E, Self>,
c: &mut Tensor<C, E, Self>,
) -> Result<(), Error>;

fn backward<A: Shape, B: Shape>(
&self,
ax: usize,
a: &GhostTensor<A, E, Self>,
grad_a: &mut Self::Vec,
b: &GhostTensor<B, E, Self>,
grad_b: &mut Self::Vec,
grad_out: &Self::Vec,
) -> Result<(), Error>;
}

#[allow(deprecated)]
impl<A, B, Ax, E: Dtype, D, T: Tape<E, D>, R: Tape<E, D>> TryConcatAlong<Ax>
for (Tensor<A, E, D, T>, Tensor<B, E, D, R>)
where
Expand Down Expand Up @@ -121,6 +100,7 @@ where

macro_rules! impl_concat {
($Ax:expr, $NumDims:expr, [$($Head:tt),*], [$($Tail:tt),*]) => {
#[allow(deprecated)]
impl<A: Dim, B: Dim, $($Head: Dim, )* $($Tail: Dim, )*> TryConcatAlong<Axis<$Ax>>
for (
($($Head, )* A, $($Tail, )*),
Expand Down Expand Up @@ -181,6 +161,7 @@ impl_concat!(4, 6, [D0, D1, D2, D3], [D5]);
impl_concat!(5, 6, [D0, D1, D2, D3, D4], []);

#[cfg(test)]
#[allow(deprecated)]
mod tests {
use super::*;
use crate::{tensor_ops::*, tests::*};
Expand Down
153 changes: 153 additions & 0 deletions dfdx-core/src/tensor_ops/concat_shape_along/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
use crate::{shapes::*, tensor::*};

/// Concatenate two shapes along a given axis.
///
/// # [Const] dims **requires nightly**
///
/// Along Axis 0:
/// ```ignore
/// # use dfdx_core::prelude::*;
/// # let dev: Cpu = Default::default();
/// let a: Rank2<3, 4> = (Const, Const);
/// let b: Rank2<3, 4> = (Const, Const);
/// let _: Rank2<6, 4> = (a, b).concat_shape_along(Axis::<0>);
/// ```
///
/// Along Axis 1:
/// ```ignore
/// # use dfdx_core::prelude::*;
/// # let dev: Cpu = Default::default();
/// let a: Rank2<3, 4> = (Const, Const);
/// let b: Rank2<3, 4> = (Const, Const);
/// let _: Rank2<3, 8> = (a, b).concat_shape_along(Axis::<1>);
/// ```
///
/// # [usize] dims
/// Along Axis 0:
/// ```rust
/// # use dfdx_core::prelude::*;
/// # let dev: Cpu = Default::default();
/// let a: (usize, Const<3>) = (2, Const);
/// let b: (usize, Const<3>) = (4, Const);
/// let c: (usize, Const<3>) = (a, b).concat_shape_along(Axis::<0>);
/// assert_eq!(c, (6, Const::<3>));
/// ```
///
/// Along Axis 1:
/// ```rust
/// # use dfdx_core::prelude::*;
/// # let dev: Cpu = Default::default();
/// let a: (Const<2>, usize) = (Const, 2);
/// let b: (Const<2>, usize) = (Const, 4);
/// let c: (Const<2>, usize) = (a, b).concat_shape_along(Axis::<1>);
/// assert_eq!(c, (Const::<2>, 6));
/// ```
pub trait TryConcatShapeAlong<Ax>: Sized {
type Output;

/// Concatenates self along the given axis.
fn concat_shape_along(self, ax: Ax) -> Self::Output {
self.try_concat_shape_along(ax).unwrap()
}
/// Fallibly concatenates self along the given axis.
fn try_concat_shape_along(self, ax: Ax) -> Result<Self::Output, Error>;
}

macro_rules! impl_concat {
($Ax:expr, $NumDims:expr, [$($Head:tt),*], [$($Tail:tt),*]) => {
impl<A: Dim, B: Dim, $($Head: Dim, )* $($Tail: Dim, )*> TryConcatShapeAlong<Axis<$Ax>>
for (
($($Head, )* A, $($Tail, )*),
($($Head, )* B, $($Tail, )*),
)
where
A: std::ops::Add<B>,
<A as std::ops::Add<B>>::Output: Dim,
{
type Output = (
$($Head, )*
<A as std::ops::Add<B>>::Output,
$($Tail, )*
);

fn try_concat_shape_along(self, _: Axis<$Ax>) -> Result<Self::Output, Error> {
let (lhs, rhs) = self;
let lhs_dims = lhs.concrete();
let rhs_dims = rhs.concrete();
for i in 0..$NumDims {
if i != $Ax {
assert_eq!(lhs_dims[i], rhs_dims[i]);
}
}
let mut out_dims = lhs_dims;
out_dims[$Ax] += rhs_dims[$Ax];
Ok(Self::Output::from_concrete(&out_dims).unwrap())
}
}
};
}

impl_concat!(0, 1, [], []);
impl_concat!(0, 2, [], [D1]);
impl_concat!(0, 3, [], [D1, D2]);
impl_concat!(0, 4, [], [D1, D2, D3]);
impl_concat!(0, 5, [], [D1, D2, D3, D4]);
impl_concat!(0, 6, [], [D1, D2, D3, D4, D5]);

impl_concat!(1, 2, [D0], []);
impl_concat!(1, 3, [D0], [D2]);
impl_concat!(1, 4, [D0], [D2, D3]);
impl_concat!(1, 5, [D0], [D2, D3, D4]);
impl_concat!(1, 6, [D0], [D2, D3, D4, D5]);

impl_concat!(2, 3, [D0, D1], []);
impl_concat!(2, 4, [D0, D1], [D3]);
impl_concat!(2, 5, [D0, D1], [D3, D4]);
impl_concat!(2, 6, [D0, D1], [D3, D4, D5]);

impl_concat!(3, 4, [D0, D1, D2], []);
impl_concat!(3, 5, [D0, D1, D2], [D4]);
impl_concat!(3, 6, [D0, D1, D2], [D4, D5]);

impl_concat!(4, 5, [D0, D1, D2, D3], []);
impl_concat!(4, 6, [D0, D1, D2, D3], [D5]);

impl_concat!(5, 6, [D0, D1, D2, D3, D4], []);

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_concat_shape() {
let a: (usize, Const<5>) = (5, Const);
let b: (usize, Const<5>) = (3, Const);
assert_eq!((a, b).concat_shape_along(Axis::<0>), (8, Const::<5>));

let a: (Const<5>, Const<5>) = (Const, Const);
let b: (usize, Const<5>) = (3, Const);
assert_eq!((a, b).concat_shape_along(Axis::<0>), (8, Const::<5>));

let a: (usize, Const<5>) = (5, Const);
let b: (Const<3>, Const<5>) = (Const, Const);
assert_eq!((a, b).concat_shape_along(Axis::<0>), (8, Const::<5>));

#[cfg(feature = "nightly")]
{
let a: (Const<5>, Const<5>) = (Const, Const);
let b: (Const<3>, Const<5>) = (Const, Const);
assert_eq!(
(a, b).concat_shape_along(Axis::<0>),
(Const::<8>, Const::<5>)
);
}
}

#[test]
#[should_panic = "left: 10\n right: 7"]
fn test_concat_shape_fails() {
let a = (5, 10);
let b = (3, 7);
(a, b).concat_shape_along(Axis::<0>);
}
}
Loading

0 comments on commit 6d8ab56

Please sign in to comment.