Skip to content

Commit

Permalink
add scalars to safetensor serialization
Browse files Browse the repository at this point in the history
  • Loading branch information
nkoppel committed Jun 29, 2023
1 parent 6f0092a commit 95c3d05
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 9 deletions.
30 changes: 23 additions & 7 deletions src/nn/safetensors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ use safetensors::{
tensor::{Dtype as SDtype, SafeTensors, TensorView},
SafeTensorError,
};
use std::collections::BTreeMap;

use super::tensor_collection::*;

Expand All @@ -25,12 +24,12 @@ struct TensorData {
}

pub struct Writer {
tensors: BTreeMap<String, TensorData>,
tensors: Vec<(String, TensorData)>,
}

impl Writer {
pub fn new() -> Self {
let tensors = BTreeMap::new();
let tensors = Vec::new();
Self { tensors }
}

Expand All @@ -44,11 +43,11 @@ impl Writer {
let data = tensor.as_vec();
let data: Vec<u8> = data.iter().flat_map(|f| f.to_le_bytes()).collect();
let tdata = TensorData { dtype, shape, data };
self.tensors.insert(key, tdata);
self.tensors.push((key, tdata));
}

pub fn save(&self, path: &Path) -> Result<(), SafeTensorError> {
let views: BTreeMap<String, TensorView> = self
let (names, views): (Vec<String>, Vec<TensorView>) = self
.tensors
.iter()
.map(|(k, tensor)| {
Expand All @@ -57,8 +56,11 @@ impl Writer {
TensorView::new(tensor.dtype, tensor.shape.clone(), &tensor.data).unwrap(),
)
})
.collect();
serialize_to_file(&views, &None, path)
.unzip();

let data = names.into_iter().zip(views.iter());

serialize_to_file(data, &None, path)
}
}

Expand All @@ -76,6 +78,20 @@ impl<E: Dtype + SafeDtype, D: Device<E>> TensorVisitor<E, D> for Writer {
self.add(full_path, t);
Ok(None)
}

fn visit_scalar<N: num_traits::NumCast>(
&mut self,
_: ScalarOptions<N>,
(n, full_path): (&N, String),
) -> Result<Option<N>, Self::Err> {
let data = TensorData {
dtype: safetensors::Dtype::F64,
shape: Vec::new(),
data: n.to_f64().unwrap().to_le_bytes().to_vec(),
};
self.tensors.push((full_path, data));
Ok(None)
}
}

/// Something that can be saved to a `.safetensors`.
Expand Down
4 changes: 2 additions & 2 deletions src/tensor/numpy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ pub(crate) fn read_from_npz<R: Read + Seek, E: Dtype + NumpyDtype>(
read_from_npy(&mut f, shape)
}

pub(crate) fn write_to_npy<W: Write, E: Dtype + NumpyDtype>(
fn write_to_npy<W: Write, E: Dtype + NumpyDtype>(
w: &mut W,
shape: &[usize],
data: &[E],
Expand All @@ -98,7 +98,7 @@ pub(crate) fn write_to_npy<W: Write, E: Dtype + NumpyDtype>(
Ok(())
}

pub(crate) fn read_from_npy<R: Read, E: Dtype + NumpyDtype>(
fn read_from_npy<R: Read, E: Dtype + NumpyDtype>(
r: &mut R,
shape: &[usize],
) -> Result<Vec<E>, NpyError> {
Expand Down

0 comments on commit 95c3d05

Please sign in to comment.