From 95c3d05b4631cd6171bc089f2a8dc18ce60e9108 Mon Sep 17 00:00:00 2001 From: Nathan Koppel Date: Wed, 28 Jun 2023 23:46:35 -0500 Subject: [PATCH] add scalars to safetensor serialization --- src/nn/safetensors.rs | 30 +++++++++++++++++++++++------- src/tensor/numpy.rs | 4 ++-- 2 files changed, 25 insertions(+), 9 deletions(-) diff --git a/src/nn/safetensors.rs b/src/nn/safetensors.rs index f6acf02f7..be3d5ac61 100644 --- a/src/nn/safetensors.rs +++ b/src/nn/safetensors.rs @@ -12,7 +12,6 @@ use safetensors::{ tensor::{Dtype as SDtype, SafeTensors, TensorView}, SafeTensorError, }; -use std::collections::BTreeMap; use super::tensor_collection::*; @@ -25,12 +24,12 @@ struct TensorData { } pub struct Writer { - tensors: BTreeMap, + tensors: Vec<(String, TensorData)>, } impl Writer { pub fn new() -> Self { - let tensors = BTreeMap::new(); + let tensors = Vec::new(); Self { tensors } } @@ -44,11 +43,11 @@ impl Writer { let data = tensor.as_vec(); let data: Vec = data.iter().flat_map(|f| f.to_le_bytes()).collect(); let tdata = TensorData { dtype, shape, data }; - self.tensors.insert(key, tdata); + self.tensors.push((key, tdata)); } pub fn save(&self, path: &Path) -> Result<(), SafeTensorError> { - let views: BTreeMap = self + let (names, views): (Vec, Vec) = self .tensors .iter() .map(|(k, tensor)| { @@ -57,8 +56,11 @@ impl Writer { TensorView::new(tensor.dtype, tensor.shape.clone(), &tensor.data).unwrap(), ) }) - .collect(); - serialize_to_file(&views, &None, path) + .unzip(); + + let data = names.into_iter().zip(views.iter()); + + serialize_to_file(data, &None, path) } } @@ -76,6 +78,20 @@ impl> TensorVisitor for Writer { self.add(full_path, t); Ok(None) } + + fn visit_scalar( + &mut self, + _: ScalarOptions, + (n, full_path): (&N, String), + ) -> Result, Self::Err> { + let data = TensorData { + dtype: safetensors::Dtype::F64, + shape: Vec::new(), + data: n.to_f64().unwrap().to_le_bytes().to_vec(), + }; + self.tensors.push((full_path, data)); + Ok(None) + } } /// Something that can be saved to a `.safetensors`. diff --git a/src/tensor/numpy.rs b/src/tensor/numpy.rs index 9d6ac77cb..a75404f3d 100644 --- a/src/tensor/numpy.rs +++ b/src/tensor/numpy.rs @@ -83,7 +83,7 @@ pub(crate) fn read_from_npz( read_from_npy(&mut f, shape) } -pub(crate) fn write_to_npy( +fn write_to_npy( w: &mut W, shape: &[usize], data: &[E], @@ -98,7 +98,7 @@ pub(crate) fn write_to_npy( Ok(()) } -pub(crate) fn read_from_npy( +fn read_from_npy( r: &mut R, shape: &[usize], ) -> Result, NpyError> {