diff --git a/examples/safetensors-save-load.rs b/examples/safetensors-save-load.rs
index a16aee548..d9b8ec9b4 100644
--- a/examples/safetensors-save-load.rs
+++ b/examples/safetensors-save-load.rs
@@ -3,10 +3,7 @@
 #[cfg(feature = "safetensors")]
 fn main() {
     use ::safetensors::SafeTensors;
-    use dfdx::{
-        prelude::*,
-        tensor::{AsArray, AutoDevice, Cpu},
-    };
+    use dfdx::prelude::*;
     use memmap2::MmapOptions;
 
     let dev: Cpu = Default::default();
diff --git a/src/shapes/shape.rs b/src/shapes/shape.rs
index a0130726b..1d900b339 100644
--- a/src/shapes/shape.rs
+++ b/src/shapes/shape.rs
@@ -1,5 +1,8 @@
 use super::{axes::*, ReduceShape, ReduceShapeTo};
 
+#[cfg(feature = "f16")]
+pub use half::f16;
+
 #[cfg(not(feature = "cuda"))]
 pub trait SafeZeros {}
 
@@ -48,7 +51,7 @@ unit!(u128, 1);
 unit!(i128, 1);
 unit!(bool, true);
 #[cfg(feature = "f16")]
-unit!(half::f16, half::f16::ONE);
+unit!(f16, f16::ONE);
 
 /// Represents something that has a [Unit].
 pub trait HasUnitType {
@@ -88,7 +91,7 @@ impl Dtype for u64 {}
 impl Dtype for u128 {}
 impl Dtype for usize {}
 #[cfg(feature = "f16")]
-impl Dtype for half::f16 {}
+impl Dtype for f16 {}
 
 /// Represents something that has a [Dtype].
 pub trait HasDtype {
diff --git a/src/tensor/cpu/device.rs b/src/tensor/cpu/device.rs
index b5a81f840..ebc380dcd 100644
--- a/src/tensor/cpu/device.rs
+++ b/src/tensor/cpu/device.rs
@@ -220,7 +220,7 @@ impl Cache for Cpu {
         debug_assert_eq!(std::alloc::Layout::new::<u32>().align(), 4);
         debug_assert_eq!(std::alloc::Layout::new::<u64>().align(), 8);
         match key.alignment {
-            1 => unsafe { drop(Vec::from_raw_parts(alloc.0 as *mut u8, len, cap)) },
+            1 => unsafe { drop(Vec::from_raw_parts(alloc.0, len, cap)) },
             2 => unsafe { drop(Vec::from_raw_parts(alloc.0 as *mut u16, len, cap)) },
             4 => unsafe { drop(Vec::from_raw_parts(alloc.0 as *mut u32, len, cap)) },
             8 => unsafe { drop(Vec::from_raw_parts(alloc.0 as *mut u64, len, cap)) },
diff --git a/src/tensor/cpu/index.rs b/src/tensor/cpu/index.rs
index 3bd7a1f1e..ff7683bc6 100644
--- a/src/tensor/cpu/index.rs
+++ b/src/tensor/cpu/index.rs
@@ -11,11 +11,7 @@ pub(crate) fn index_to_i<S: Shape>(shape: &S, strides: &S::Concrete, index: S::C
             panic!("Index out of bounds: index={index:?} shape={shape:?}");
         }
     }
-    strides
-        .into_iter()
-        .zip(index.into_iter())
-        .map(|(a, b)| a * b)
-        .sum()
+    strides.into_iter().zip(index).map(|(a, b)| a * b).sum()
 }
 
 impl<S: Shape, E, T> std::ops::Index<S::Concrete> for Tensor<S, E, Cpu, T> {
diff --git a/src/tensor/ghost.rs b/src/tensor/ghost.rs
index 67cd9cce2..c05854ab6 100644
--- a/src/tensor/ghost.rs
+++ b/src/tensor/ghost.rs
@@ -17,7 +17,7 @@ pub struct GhostTensor<S: Shape, E, D: Storage<E>> {
 impl<S: Shape, E, D: Storage<E>, T> Tensor<S, E, D, T> {
     /// Creates a ghost tensor that doesn't hold a reference
     /// to the tensor's data.
-    pub(crate) fn ghost(&self) -> GhostTensor<S, E, D> {
+    pub fn ghost(&self) -> GhostTensor<S, E, D> {
         GhostTensor {
             id: self.id,
             len: self.device.len(&self.data),
diff --git a/src/tensor/gradients.rs b/src/tensor/gradients.rs
index 317636245..dde7ba99f 100644
--- a/src/tensor/gradients.rs
+++ b/src/tensor/gradients.rs
@@ -44,20 +44,16 @@ impl<E, D: Storage<E>> Gradients<E, D> {
 
 impl<E, D: Storage<E>> Gradients<E, D> {
     /// Retrieves mutable gradient for `t`, allocating one if it isn't present.
-    pub(crate) fn get_or_alloc_mut<S: Shape>(
+    pub fn get_or_alloc_mut<S: Shape>(
         &mut self,
-        t: &Tensor<S, E, D>,
+        t: &impl Tensorlike<S, E, D>,
     ) -> Result<&mut D::Vec, D::Err> {
-        let ghost = t.ghost();
-        self.try_alloc_for(&ghost)?;
-        Ok(self.get_mut(&ghost))
+        self.try_alloc_for(t)?;
+        Ok(self.get_mut(t))
     }
 
     /// Inserts a gradient for `t`
-    pub(crate) fn try_alloc_for<S: Shape>(
-        &mut self,
-        t: &impl Tensorlike<S, E, D>,
-    ) -> Result<(), D::Err> {
+    pub fn try_alloc_for<S: Shape>(&mut self, t: &impl Tensorlike<S, E, D>) -> Result<(), D::Err> {
         if let std::collections::btree_map::Entry::Vacant(e) = self.gradient_by_id.entry(t.id()) {
             e.insert(t.try_alloc_grad()?);
         }
@@ -92,7 +88,7 @@ impl<E, D: Storage<E>> Gradients<E, D> {
         self.gradient_by_id.get_mut(&t.id()).unwrap()
     }
 
-    /// Returns a mutable reference to the data associated with `t`.
+    /// Returns an immutable reference to the data associated with `t`.
     ///
     /// **Panics** if data associated with `t` is not found. This indicates an unrecoverable bug.
     pub(crate) fn get_ref<S: Shape>(&mut self, t: &impl Tensorlike<S, E, D>) -> &D::Vec {
@@ -104,14 +100,14 @@ impl<E, D: Storage<E>> Gradients<E, D> {
     /// # Panics
     /// If no data is associated with `t` yet, this will panic due to an unwrap()
     /// on a .get() to the underlying hashmap.
-    pub fn get<S: Shape, T>(&self, t: &Tensor<S, E, D, T>) -> Tensor<S, E, D> {
-        let buf = self.gradient_by_id.get(&t.id).unwrap().clone();
+    pub fn get<S: Shape>(&self, t: &impl Tensorlike<S, E, D>) -> Tensor<S, E, D> {
+        let buf = self.gradient_by_id.get(&t.id()).unwrap().clone();
         Tensor {
             id: unique_id(),
             data: std::sync::Arc::new(buf),
-            shape: t.shape,
-            strides: t.strides,
-            device: t.device.clone(),
+            shape: *t.shape(),
+            strides: t.strides(),
+            device: t.dev().clone(),
             tape: Default::default(),
         }
     }
diff --git a/src/tensor_ops/matmul/cpu_kernel.rs b/src/tensor_ops/matmul/cpu_kernel.rs
index 772329e69..ca205093a 100644
--- a/src/tensor_ops/matmul/cpu_kernel.rs
+++ b/src/tensor_ops/matmul/cpu_kernel.rs
@@ -66,6 +66,7 @@ impl MatMulImpl<half::f16> for Cpu {
         naive_gemm((m, k, n), accum, ap, astr, bp, bstr, cp, cstr);
 
         #[cfg(feature = "cpu")]
+        #[allow(clippy::unnecessary_cast)]
         unsafe {
             gemm::gemm(
                 m.size(),
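Reviewer note, not part of the patch: the visibility changes in `ghost.rs` and `gradients.rs` open up gradient inspection from user code. Below is a minimal sketch of that flow. It assumes `dfdx::prelude` re-exports `Gradients`, `Cpu`, and the tensor traits (the safetensors example above relies on the same re-exports), and that `Gradients::leaky()` is available as the public constructor for an empty gradient store; treat it as illustrative rather than a tested snippet.

```rust
use dfdx::prelude::*;

fn main() {
    let dev: Cpu = Default::default();
    let t: Tensor<Rank1<3>, f32, Cpu> = dev.tensor([1.0, 2.0, 3.0]);

    // Assumed public constructor: an empty gradient store with no owned tensors.
    let mut grads: Gradients<f32, Cpu> = Gradients::leaky();

    // Newly public `get_or_alloc_mut`: allocates storage for `t`'s gradient
    // if none is present, then hands back a mutable reference to it.
    grads.get_or_alloc_mut(&t).expect("gradient allocation failed");

    // Newly public `ghost()` gives a data-free handle, and `get` now accepts
    // any `impl Tensorlike`, so a GhostTensor works where a Tensor did before.
    let g: Tensor<Rank1<3>, f32, Cpu> = grads.get(&t.ghost());
    println!("gradient: {:?}", g.array()); // freshly allocated gradients are zero-filled
}
```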