Skip to content

Commit

Permalink
Merge branch 'unstack' into new-base
Browse files Browse the repository at this point in the history
  • Loading branch information
swfsql committed Mar 1, 2024
2 parents cd33ed7 + 664a907 commit 09d4237
Show file tree
Hide file tree
Showing 7 changed files with 401 additions and 0 deletions.
15 changes: 15 additions & 0 deletions dfdx-core/src/shapes/shape.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,18 +145,33 @@ where
/// A fixed- or dynamically-sized collection of `T` values whose length is
/// described by a [Dim] type.
pub trait Array<T>: IntoIterator<Item = T> {
    /// The dimension type describing this array's length
    /// (e.g. `Const<N>` for `[T; N]`, `usize` for `Vec<T>`).
    type Dim: Dim;
    /// Returns this array's length as its [Dim] type.
    fn dim(&self) -> Self::Dim;
    /// Builds an array of length `len` by calling `cb` with each index in `0..len`.
    ///
    /// For fixed-size arrays the `len` argument carries no extra information
    /// (the length is a compile-time constant) and may be ignored by the impl.
    fn from_fn<F>(cb: F, len: Self::Dim) -> Self
    where
        F: FnMut(usize) -> T;
}
/// Fixed-size arrays have a compile-time constant dimension.
impl<T, const N: usize> Array<T> for [T; N] {
    type Dim = Const<N>;

    fn dim(&self) -> Self::Dim {
        Const
    }

    fn from_fn<F>(generator: F, _len: Self::Dim) -> Self
    where
        F: FnMut(usize) -> T,
    {
        // The length is fixed by `N`, so the `_len` argument is unused here.
        std::array::from_fn(generator)
    }
}
/// Vectors have a runtime-determined dimension.
impl<T> Array<T> for std::vec::Vec<T> {
    type Dim = usize;

    fn dim(&self) -> Self::Dim {
        self.len()
    }

    fn from_fn<F>(mut generator: F, len: Self::Dim) -> Self
    where
        F: FnMut(usize) -> T,
    {
        // Preallocate the exact capacity so pushing never reallocates.
        let mut items = Vec::with_capacity(len);
        for index in 0..len {
            items.push(generator(index));
        }
        items
    }
}

/// A collection of dimensions ([Dim]) that change how a multi-dimensional
Expand Down
2 changes: 2 additions & 0 deletions dfdx-core/src/tensor_ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ mod sum_to;
mod tanh;
mod to_dtype;
mod tri;
mod unstack;
mod upscale2d;
mod var_to;

Expand Down Expand Up @@ -284,6 +285,7 @@ pub use sum_to::SumTo;
pub use tanh::tanh;
pub use to_dtype::{to_dtype, ToDtypeKernel};
pub use tri::{lower_tri, upper_tri};
pub use unstack::{SubDim, TryUnstack};
pub use upscale2d::{
Bilinear, GenericUpscale2D, NearestNeighbor, TryUpscale2D, Upscale2DKernel, UpscaleMethod,
};
Expand Down
63 changes: 63 additions & 0 deletions dfdx-core/src/tensor_ops/unstack/cpu_kernel.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
use crate::{
prelude::NoneTape,
shapes::*,
tensor::{unique_id, Cpu, Error, Tensor},
};

// note: in order to return NoneTape items and not require a tape type information T,
// each element must be optional.
impl<E: Dtype> super::UnstackKernel<E> for Cpu {
    /// Splits `stack` along its leading dimension, producing one sub-tensor
    /// per leading-axis index. Each element is wrapped in `Option` so the
    /// caller can move tensors out without knowing a tape type.
    fn forward<S: Shape, OptionalItems>(
        &self,
        stack: Tensor<S, E, Self, NoneTape>,
    ) -> Result<OptionalItems, Error>
    where
        S: super::SubDim,
        OptionalItems: Array<Option<Tensor<S::Tail, E, Self, NoneTape>>, Dim = S::Head>,
    {
        let (head, tail) = stack.shape().sub_dim();
        let src = stack.data.as_slice();
        let item_len = tail.num_elements();
        Ok(OptionalItems::from_fn(
            |idx| {
                // Allocate a default-initialized buffer for one sub-tensor,
                // then copy the contiguous slice for index `idx` into it.
                // TODO: remove unwrap (needs try_from_fn)
                // https://github.com/rust-lang/rust/issues/89379
                let mut buf = self.try_alloc_elem(item_len, E::default()).unwrap();
                let start = idx * item_len;
                buf.copy_from_slice(&src[start..start + item_len]);

                Some(Tensor {
                    id: unique_id(),
                    data: std::sync::Arc::new(buf),
                    shape: *tail.shape(),
                    strides: tail.strides(),
                    device: self.clone(),
                    tape: NoneTape,
                })
            },
            head,
        ))
    }

    /// Accumulates the gradient of one unstacked item back into the stacked
    /// gradient buffer, at the slot selected by `unstack_idx`.
    fn backward(
        &self,
        grad_stack: &mut Self::Vec,
        grad_unstack: &Self::Vec,
        unstack_idx: usize,
    ) -> Result<(), Error> {
        let item_len = grad_unstack.len();
        let offset = unstack_idx * item_len;
        // Zipping the two iterators avoids a per-element indexed lookup
        // into `grad_unstack`.
        for (dst, src) in grad_stack
            .iter_mut()
            .skip(offset)
            .take(item_len)
            .zip(grad_unstack.iter())
        {
            *dst += *src;
        }
        Ok(())
    }
}
27 changes: 27 additions & 0 deletions dfdx-core/src/tensor_ops/unstack/cuda_kernel.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
use crate::{
prelude::NoneTape,
shapes::*,
tensor::{Cuda, Error, Tensor},
};
use cudarc::types::CudaTypeName;

impl<E: Dtype + CudaTypeName> super::UnstackKernel<E> for Cuda {
    /// CUDA forward pass for unstack.
    ///
    /// NOTE(review): not yet implemented — always panics via `todo!()`.
    fn forward<S: Shape, OptionalItems>(
        &self,
        _stack: Tensor<S, E, Self, NoneTape>,
    ) -> Result<OptionalItems, Error>
    where
        S: super::SubDim,
        OptionalItems: Array<Option<Tensor<S::Tail, E, Self, NoneTape>>, Dim = S::Head>,
    {
        todo!()
    }
    /// CUDA backward pass for unstack.
    ///
    /// NOTE(review): not yet implemented — always panics via `todo!()`.
    fn backward(
        &self,
        _grad_stack: &mut Self::Vec,
        _grad_unstack: &Self::Vec,
        _unstack_idx: usize,
    ) -> Result<(), Error> {
        todo!()
    }
}
Loading

0 comments on commit 09d4237

Please sign in to comment.