From c695a15eb3472157ea4467884e33b69dce0756a6 Mon Sep 17 00:00:00 2001 From: Thiago Machado Date: Fri, 9 Feb 2024 12:37:46 -0500 Subject: [PATCH] unstack fixes --- dfdx-core/src/tensor_ops/unstack/mod.rs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/dfdx-core/src/tensor_ops/unstack/mod.rs b/dfdx-core/src/tensor_ops/unstack/mod.rs index 21204f99..528ee1b7 100644 --- a/dfdx-core/src/tensor_ops/unstack/mod.rs +++ b/dfdx-core/src/tensor_ops/unstack/mod.rs @@ -1,4 +1,5 @@ use crate::{shapes::*, tensor::*}; +use std::vec::Vec; mod cpu_kernel; #[cfg(feature = "cuda")] @@ -21,15 +22,15 @@ mod webgpu_kernel; /// # use dfdx_core::prelude::*; /// # let dev: Cpu = Default::default(); /// let stack: Tensor, f32, _> = dev.zeros(); -/// let [a, b]: [Tensor, f32, _>; 2] = stack.unstack(); +/// let ([a, b], _tape): ([Tensor, f32, _>; 2], _) = stack.unstack(); /// ``` /// /// Unstacking to a vec: /// ```rust /// # use dfdx_core::prelude::*; /// # let dev: Cpu = Default::default(); -/// let stack: Tensor<(usize, Const::<3>, Const::<4>>, f32, _> = dev.zeros_like(&(2, Const, Const)); -/// let unstack: Vec, f32, _>> = stack.unstack(); +/// let stack: Tensor<(usize, Const::<3>, Const::<4>), f32, _> = dev.zeros_like(&(2, Const, Const)); +/// let (unstack, _tape): (Vec, f32, _>>, _) = stack.unstack(); /// ``` pub trait TryUnstack: Sized { type Unstacked;