diff --git a/Cargo.toml b/Cargo.toml index 127efeeb9..130d73df0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "dfdx" -version = "0.11.2" +version = "0.12.1" edition = "2021" license = "MIT OR Apache-2.0" rust-version = "1.65" diff --git a/README.md b/README.md index 9079bcf1b..a51889ed5 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ Features at a glance: `dfdx` is on [crates.io](https://crates.io/crates/dfdx)! Use by adding this to your `Cargo.toml`: ```toml -dfdx = "0.11.2" +dfdx = "0.12.1" ``` See the documentation at [docs.rs/dfdx](https://docs.rs/dfdx). diff --git a/examples/safetensors-save-load.rs b/examples/safetensors-save-load.rs index a16aee548..d9b8ec9b4 100644 --- a/examples/safetensors-save-load.rs +++ b/examples/safetensors-save-load.rs @@ -3,10 +3,7 @@ #[cfg(feature = "safetensors")] fn main() { use ::safetensors::SafeTensors; - use dfdx::{ - prelude::*, - tensor::{AsArray, AutoDevice, Cpu}, - }; + use dfdx::prelude::*; use memmap2::MmapOptions; let dev: Cpu = Default::default(); diff --git a/src/nn/npz.rs b/src/nn/npz.rs index ae0cb574c..3b3faa9ae 100644 --- a/src/nn/npz.rs +++ b/src/nn/npz.rs @@ -1,4 +1,5 @@ use crate::{ + prelude::numpy::NpyError, shapes::{Dtype, Shape}, tensor::{ numpy::{read_from_npz, write_to_npz, NpzError, NumpyDtype}, @@ -162,18 +163,26 @@ impl> TensorVisitor fn visit_scalar( &mut self, - _opts: ScalarOptions, + opts: ScalarOptions, (n, full_path): (&mut N, String), ) -> Result, Self::Err> { - let buf: Vec = read_from_npz(self, &[], full_path)?; - *n = N::from(buf[0]).unwrap_or_else(|| { - panic!( - "Failed to convert f64 value {} to {} when reading from npz!", - buf[0], - std::any::type_name::() - ) - }); - Ok(None) + match read_from_npz::<_, f64>(self, &[], full_path) { + Ok(buf) => { + *n = N::from(buf[0]).unwrap_or_else(|| { + panic!( + "Failed to convert f64 value {} to {} when reading from npz!", + buf[0], + std::any::type_name::() + ) + }); + Ok(None) + } + Err(NpyError::IoError(e)) if e.kind() == std::io::ErrorKind::NotFound => { + *n = opts.default; + Ok(None) + } + Err(x) => Err(x.into()), + } } } diff --git a/src/nn/safetensors.rs b/src/nn/safetensors.rs index 46f04a80d..5055dd448 100644 --- a/src/nn/safetensors.rs +++ b/src/nn/safetensors.rs @@ -171,20 +171,29 @@ impl<'data, E: Dtype + SafeDtype, D: Device> TensorVisitor for SafeTens fn visit_scalar( &mut self, - _: ScalarOptions, + opts: ScalarOptions, (n, full_path): (&mut N, String), ) -> Result, Self::Err> { - let data = self.tensor(&full_path)?.data(); - let mut array = [0; 8]; - array.copy_from_slice(data); - let val = f64::from_le_bytes(array); - *n = N::from(val).unwrap_or_else(|| { - panic!( - "Failed to convert f64 value {val} at {full_path} to {} when reading from safetensors!", - std::any::type_name::() - ) - }); - Ok(None) + match self.tensor(&full_path) { + Ok(tensor) => { + let data = tensor.data(); + let mut array = [0; 8]; + array.copy_from_slice(data); + let val = f64::from_le_bytes(array); + *n = N::from(val).unwrap_or_else(|| { + panic!( + "Failed to convert f64 value {val} at {full_path} to {} when reading from safetensors!", + std::any::type_name::() + ) + }); + Ok(None) + } + Err(SafeTensorError::TensorNotFound(_)) => { + *n = opts.default; + Ok(None) + } + Err(x) => Err(Error::SafeTensorError(x)), + } } } diff --git a/src/shapes/mod.rs b/src/shapes/mod.rs index 0168da817..4c0cf0009 100644 --- a/src/shapes/mod.rs +++ b/src/shapes/mod.rs @@ -18,18 +18,17 @@ mod same_numel; mod shape; mod slice; -pub(crate) use axes::Axes; 
-pub(crate) use broadcasts::{ +pub use broadcasts::{ BroadcastShapeTo, BroadcastStridesTo, ReduceShape, ReduceShapeTo, ReduceStridesTo, }; -pub(crate) use permutes::{PermuteShapeTo, PermuteStridesTo}; -pub(crate) use realize::RealizeShapeTo; -pub(crate) use replace_dim::{RemoveDimTo, ReplaceDimTo}; +pub use permutes::{PermuteShapeTo, PermuteStridesTo}; +pub use realize::RealizeShapeTo; +pub use replace_dim::{RemoveDimTo, ReplaceDimTo}; -pub(crate) use same_numel::AssertSameNumel; -pub(crate) use slice::SliceShape; +pub use same_numel::AssertSameNumel; +pub use slice::SliceShape; -pub use axes::{Axes2, Axes3, Axes4, Axes5, Axes6, Axis, HasAxes}; +pub use axes::{Axes, Axes2, Axes3, Axes4, Axes5, Axes6, Axis, HasAxes}; pub use shape::{Array, Const, ConstDim, Dim}; pub use shape::{ConstShape, HasShape, Shape}; pub use shape::{Dtype, HasDtype, HasUnitType, Unit}; diff --git a/src/shapes/shape.rs b/src/shapes/shape.rs index a0130726b..1d900b339 100644 --- a/src/shapes/shape.rs +++ b/src/shapes/shape.rs @@ -1,5 +1,8 @@ use super::{axes::*, ReduceShape, ReduceShapeTo}; +#[cfg(feature = "f16")] +pub use half::f16; + #[cfg(not(feature = "cuda"))] pub trait SafeZeros {} @@ -48,7 +51,7 @@ unit!(u128, 1); unit!(i128, 1); unit!(bool, true); #[cfg(feature = "f16")] -unit!(half::f16, half::f16::ONE); +unit!(f16, f16::ONE); /// Represents something that has a [Unit]. pub trait HasUnitType { @@ -88,7 +91,7 @@ impl Dtype for u64 {} impl Dtype for u128 {} impl Dtype for usize {} #[cfg(feature = "f16")] -impl Dtype for half::f16 {} +impl Dtype for f16 {} /// Represents something that has a [Dtype]. pub trait HasDtype { diff --git a/src/tensor/cpu/device.rs b/src/tensor/cpu/device.rs index b5a81f840..ebc380dcd 100644 --- a/src/tensor/cpu/device.rs +++ b/src/tensor/cpu/device.rs @@ -220,7 +220,7 @@ impl Cache for Cpu { debug_assert_eq!(std::alloc::Layout::new::().align(), 4); debug_assert_eq!(std::alloc::Layout::new::().align(), 8); match key.alignment { - 1 => unsafe { drop(Vec::from_raw_parts(alloc.0 as *mut u8, len, cap)) }, + 1 => unsafe { drop(Vec::from_raw_parts(alloc.0, len, cap)) }, 2 => unsafe { drop(Vec::from_raw_parts(alloc.0 as *mut u16, len, cap)) }, 4 => unsafe { drop(Vec::from_raw_parts(alloc.0 as *mut u32, len, cap)) }, 8 => unsafe { drop(Vec::from_raw_parts(alloc.0 as *mut u64, len, cap)) }, diff --git a/src/tensor/cpu/index.rs b/src/tensor/cpu/index.rs index 3bd7a1f1e..ff7683bc6 100644 --- a/src/tensor/cpu/index.rs +++ b/src/tensor/cpu/index.rs @@ -11,11 +11,7 @@ pub(crate) fn index_to_i(shape: &S, strides: &S::Concrete, index: S::C panic!("Index out of bounds: index={index:?} shape={shape:?}"); } } - strides - .into_iter() - .zip(index.into_iter()) - .map(|(a, b)| a * b) - .sum() + strides.into_iter().zip(index).map(|(a, b)| a * b).sum() } impl std::ops::Index for Tensor { diff --git a/src/tensor/ghost.rs b/src/tensor/ghost.rs index 67cd9cce2..c05854ab6 100644 --- a/src/tensor/ghost.rs +++ b/src/tensor/ghost.rs @@ -17,7 +17,7 @@ pub struct GhostTensor> { impl, T> Tensor { /// Creates a ghost tensor that doesn't hold a reference /// to the tensor's data. - pub(crate) fn ghost(&self) -> GhostTensor { + pub fn ghost(&self) -> GhostTensor { GhostTensor { id: self.id, len: self.device.len(&self.data), diff --git a/src/tensor/gradients.rs b/src/tensor/gradients.rs index 317636245..dde7ba99f 100644 --- a/src/tensor/gradients.rs +++ b/src/tensor/gradients.rs @@ -44,20 +44,16 @@ impl> Gradients { impl> Gradients { /// Retrieves mutable gradient for `t`, allocating one if it isn't present. 
- pub(crate) fn get_or_alloc_mut( + pub fn get_or_alloc_mut( &mut self, - t: &Tensor, + t: &impl Tensorlike, ) -> Result<&mut D::Vec, D::Err> { - let ghost = t.ghost(); - self.try_alloc_for(&ghost)?; - Ok(self.get_mut(&ghost)) + self.try_alloc_for(t)?; + Ok(self.get_mut(t)) } /// Inserts a gradient for `t` - pub(crate) fn try_alloc_for( - &mut self, - t: &impl Tensorlike, - ) -> Result<(), D::Err> { + pub fn try_alloc_for(&mut self, t: &impl Tensorlike) -> Result<(), D::Err> { if let std::collections::btree_map::Entry::Vacant(e) = self.gradient_by_id.entry(t.id()) { e.insert(t.try_alloc_grad()?); } @@ -92,7 +88,7 @@ impl> Gradients { self.gradient_by_id.get_mut(&t.id()).unwrap() } - /// Returns a mutable reference to the data associated with `t`. + /// Returns an immutable reference to the data associated with `t`. /// /// **Panics** if data associated with `t` is not found. This indicates an unrecoverable bug. pub(crate) fn get_ref(&mut self, t: &impl Tensorlike) -> &D::Vec { @@ -104,14 +100,14 @@ impl> Gradients { /// # Panics /// If no data is associated with `t` yet, this will panic due to an unwrap() /// on a .get() to the underlying hashmap. - pub fn get(&self, t: &Tensor) -> Tensor { - let buf = self.gradient_by_id.get(&t.id).unwrap().clone(); + pub fn get(&self, t: &impl Tensorlike) -> Tensor { + let buf = self.gradient_by_id.get(&t.id()).unwrap().clone(); Tensor { id: unique_id(), data: std::sync::Arc::new(buf), - shape: t.shape, - strides: t.strides, - device: t.device.clone(), + shape: *t.shape(), + strides: t.strides(), + device: t.dev().clone(), tape: Default::default(), } } diff --git a/src/tensor/mod.rs b/src/tensor/mod.rs index 908249b61..b7a93061d 100644 --- a/src/tensor/mod.rs +++ b/src/tensor/mod.rs @@ -170,7 +170,7 @@ pub use cuda::{Cuda, CudaError}; #[cfg(feature = "cuda")] pub type AutoDevice = Cuda; -pub use storage_traits::{AsArray, CopySlice, TensorFrom, TensorFromVec}; +pub use storage_traits::{AsArray, CopySlice, TensorFrom, TensorFromVec, TensorToArray}; pub use storage_traits::{Cache, HasErr, RandomU64, Storage, Synchronize}; pub use storage_traits::{OnesTensor, SampleTensor, TriangleTensor, ZerosTensor}; diff --git a/src/tensor/numpy.rs b/src/tensor/numpy.rs index a75404f3d..b362be2ad 100644 --- a/src/tensor/numpy.rs +++ b/src/tensor/numpy.rs @@ -76,9 +76,16 @@ pub(crate) fn read_from_npz( filename.push_str(".npy"); } - let mut f = r - .by_name(&filename) - .unwrap_or_else(|_| panic!("'{filename}' not found")); + let mut f = match r.by_name(&filename) { + Ok(f) => f, + Err(ZipError::FileNotFound) => { + return Err(NpyError::IoError(io::Error::new( + io::ErrorKind::NotFound, + ZipError::FileNotFound, + ))) + } + Err(e) => panic!("Uncaught zip error: {e}"), + }; read_from_npy(&mut f, shape) } diff --git a/src/tensor_ops/matmul/cpu_kernel.rs b/src/tensor_ops/matmul/cpu_kernel.rs index 772329e69..ca205093a 100644 --- a/src/tensor_ops/matmul/cpu_kernel.rs +++ b/src/tensor_ops/matmul/cpu_kernel.rs @@ -66,6 +66,7 @@ impl MatMulImpl for Cpu { naive_gemm((m, k, n), accum, ap, astr, bp, bstr, cp, cstr); #[cfg(feature = "cpu")] + #[allow(clippy::unnecessary_cast)] unsafe { gemm::gemm( m.size(),
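
A few usage sketches for the API changes above (illustrations, not part of the diff).

The npz and safetensors loaders no longer fail when a scalar entry is absent from the archive: `visit_scalar` now falls back to `ScalarOptions::default`, so checkpoints written before a scalar field was serialized still load. A minimal sketch, assuming the `safetensors` feature and a hypothetical checkpoint path:

```rust
use dfdx::prelude::*;

fn main() {
    let dev: Cpu = Default::default();

    // BatchNorm2D serializes scalars (epsilon, momentum) alongside tensors.
    let mut model = dev.build_module::<BatchNorm2D<3>, f32>();

    // "old-checkpoint.safetensors" is a placeholder path. If a scalar entry
    // is missing from the file, the loader now writes the module's default
    // instead of surfacing SafeTensorError::TensorNotFound.
    model
        .load_safetensors("old-checkpoint.safetensors")
        .expect("checkpoints lacking scalar entries now load with defaults");
}
```

The npz path behaves the same way: a missing `.npy` member is mapped to `ErrorKind::NotFound` in `read_from_npz` and caught in `visit_scalar`.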
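The shape traits in `src/shapes/mod.rs` went from `pub(crate)` to `pub`, so downstream code can now spell out the same generic bounds dfdx uses internally. A sketch with a made-up helper, assuming `BroadcastShapeTo` resolves for user code the way it does inside the crate:

```rust
use dfdx::prelude::*;
use dfdx::shapes::{Axis, BroadcastShapeTo};

// Hypothetical helper: broadcast a bias vector across the batch dimension.
fn add_bias<const B: usize, const N: usize>(
    x: Tensor<Rank2<B, N>, f32, Cpu>,
    bias: Tensor<Rank1<N>, f32, Cpu>,
) -> Tensor<Rank2<B, N>, f32, Cpu>
where
    Rank1<N>: BroadcastShapeTo<Rank2<B, N>, Axis<0>>,
{
    x + bias.broadcast()
}

fn main() {
    let dev: Cpu = Default::default();
    let x: Tensor<Rank2<2, 3>, f32, _> = dev.sample_normal();
    let b: Tensor<Rank1<3>, f32, _> = dev.sample_normal();
    println!("{:?}", add_bias(x, b).array());
}
```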
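`Tensor::ghost` and the `Tensorlike`-based `Gradients::get` are public now, so a gradient can be looked up without keeping the original tensor's data `Arc` alive. A sketch of the flow:

```rust
use dfdx::prelude::*;

fn main() {
    let dev: Cpu = Default::default();
    let x: Tensor<Rank1<3>, f32, _> = dev.tensor([1.0, 2.0, 3.0]);

    // A ghost captures id/shape/strides without referencing the data.
    let ghost = x.ghost();

    let loss = x.leaky_trace().square().sum::<Rank0, _>();
    let grads = loss.backward();

    // `get` takes `&impl Tensorlike`, so the ghost is enough here.
    let gx = grads.get(&ghost);
    println!("{:?}", gx.array()); // d(sum(x^2))/dx = 2x
}
```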
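`Gradients::get_or_alloc_mut` and `try_alloc_for` are likewise public, which opens the door to custom optimizers and gradient-accumulation schemes. A sketch under the same assumptions (gradient buffers are zero-filled on allocation):

```rust
use dfdx::prelude::*;

fn main() {
    let dev: Cpu = Default::default();
    let w: Tensor<Rank2<2, 2>, f32, _> = dev.sample_normal();

    // An empty gradient store that never drops gradients.
    let mut grads: Gradients<f32, Cpu> = Gradients::leaky();

    // Now public: allocate w's gradient buffer (via the ghost, so no data
    // Arc is cloned), then read it back out as a tensor.
    grads
        .try_alloc_for(&w.ghost())
        .expect("allocation on the Cpu device");
    let g = grads.get(&w.ghost());
    assert_eq!(g.array(), [[0.0; 2]; 2]);
}
```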