Skip to content

Commit

Permalink
unstack fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
swfsql committed Feb 9, 2024
1 parent 5ffff2d commit c695a15
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions dfdx-core/src/tensor_ops/unstack/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::{shapes::*, tensor::*};
use std::vec::Vec;

mod cpu_kernel;
#[cfg(feature = "cuda")]
Expand All @@ -21,15 +22,15 @@ mod webgpu_kernel;
/// # use dfdx_core::prelude::*;
/// # let dev: Cpu = Default::default();
/// let stack: Tensor<Rank3<2, 3, 4>, f32, _> = dev.zeros();
/// let [a, b]: [Tensor<Rank2<3, 4>, f32, _>; 2] = stack.unstack();
/// let ([a, b], _tape): ([Tensor<Rank2<3, 4>, f32, _>; 2], _) = stack.unstack();
/// ```
///
/// Unstacking to a vec:
/// ```rust
/// # use dfdx_core::prelude::*;
/// # let dev: Cpu = Default::default();
/// let stack: Tensor<(usize, Const::<3>, Const::<4>>, f32, _> = dev.zeros_like(&(2, Const, Const));
/// let unstack: Vec<Tensor<Rank2<3, 4>, f32, _>> = stack.unstack();
/// let stack: Tensor<(usize, Const::<3>, Const::<4>), f32, _> = dev.zeros_like(&(2, Const, Const));
/// let (unstack, _tape): (Vec<Tensor<Rank2<3, 4>, f32, _>>, _) = stack.unstack();
/// ```
pub trait TryUnstack<Head: Dim>: Sized {
type Unstacked;
Expand Down

0 comments on commit c695a15

Please sign in to comment.