Skip to content

Commit

Permalink
implement npz scalar serialization
Browse files Browse the repository at this point in the history
  • Loading branch information
nkoppel committed Jun 29, 2023
1 parent d2dd8ba commit 1b51fab
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 22 deletions.
30 changes: 29 additions & 1 deletion src/nn/npz.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::{
shapes::{Dtype, Shape},
tensor::{
numpy::{NpzError, NumpyDtype},
numpy::{read_from_npz, write_to_npz, NpzError, NumpyDtype},
Tensor,
},
tensor_ops::Device,
Expand Down Expand Up @@ -129,6 +129,18 @@ impl<W: Write + Seek, E: Dtype + NumpyDtype, D: Device<E>> TensorVisitor<E, D>
t.write_to_npz(self, full_path)?;
Ok(None)
}

/// Serializes a single scalar into the archive as a zero-dimensional `f64` array.
///
/// The value is converted to `f64` and written under `full_path` with an empty
/// shape and exactly one element. `_opts` is not consulted.
///
/// Always returns `Ok(None)` on success; a failure from `write_to_npz` is
/// propagated as `Self::Err`.
///
/// # Panics
///
/// Panics if the scalar cannot be represented as an `f64`.
fn visit_scalar<N: num_traits::NumCast>(
    &mut self,
    _opts: ScalarOptions<N>,
    (n, full_path): (&N, String),
) -> Result<Option<N>, Self::Err> {
    // Scalars are always stored as f64, whatever their in-memory numeric type.
    let as_f64 = match n.to_f64() {
        Some(value) => value,
        None => panic!("Failed to convert scalar value at {full_path} to f64!"),
    };
    // Empty shape + one element: a single value with no dimensions.
    write_to_npz(self, &[], &[as_f64], full_path)?;
    Ok(None)
}
}

impl<R: Read + Seek, E: Dtype + NumpyDtype, D: Device<E>> TensorVisitor<E, D>
Expand All @@ -147,6 +159,22 @@ impl<R: Read + Seek, E: Dtype + NumpyDtype, D: Device<E>> TensorVisitor<E, D>
t.read_from_npz(self, full_path)?;
Ok(None)
}

/// Restores a single scalar from the archive entry named by `full_path`.
///
/// The entry is read as an `f64` array with an empty shape, and its first
/// element is cast to `N` and stored through `n`. `_opts` is not consulted.
///
/// Always returns `Ok(None)` on success; a failure from `read_from_npz` is
/// propagated as `Self::Err`.
///
/// # Panics
///
/// Panics if the stored `f64` cannot be converted to `N`.
fn visit_scalar<N: num_traits::NumCast>(
    &mut self,
    _opts: ScalarOptions<N>,
    (n, full_path): (&mut N, String),
) -> Result<Option<N>, Self::Err> {
    let values: Vec<f64> = read_from_npz(self, &[], full_path)?;
    // An empty shape has an element count of one (empty product), so index 0 is the scalar.
    let raw = values[0];
    *n = match N::from(raw) {
        Some(converted) => converted,
        None => panic!(
            "Failed to convert f64 value {} to {}!",
            raw,
            std::any::type_name::<N>()
        ),
    };
    Ok(None)
}
}

#[cfg(test)]
Expand Down
89 changes: 68 additions & 21 deletions src/tensor/numpy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,34 +58,81 @@ impl<S: Shape, E: Dtype + NumpyDtype, D: CopySlice<E>, T> Tensor<S, E, D, T> {
self.write_to(&mut f)
}

pub(crate) fn read_from<R: Read>(&mut self, r: &mut R) -> Result<(), NpyError> {
let endian = read_header::<R, E>(r, self.shape().concrete().into_iter().collect())?;
let numel = self.shape().num_elements();
let mut buf = Vec::with_capacity(numel);
for _ in 0..numel {
buf.push(E::read_endian(r, endian)?);
}
D::copy_from(self, &buf);
fn read_from<R: Read>(&mut self, r: &mut R) -> Result<(), NpyError> {
let buf = read_from_npy(r, self.shape().concrete().as_ref())?;
self.copy_from(&buf);
Ok(())
}

pub(crate) fn write_to<W: Write>(&self, w: &mut W) -> io::Result<()> {
let endian = Endian::Little;
write_header::<W, E>(w, endian, self.shape().concrete().into_iter().collect())?;
let numel = self.shape().num_elements();
let mut buf = std::vec![Default::default(); numel];
D::copy_into(self, &mut buf);
for v in buf.iter() {
v.write_endian(w, endian)?;
}
Ok(())
fn write_to<W: Write>(&self, w: &mut W) -> io::Result<()> {
let buf = self.as_vec();
write_to_npy(w, self.shape().concrete().as_ref(), &buf)
}
}

/// Writes `data` as one `.npy` entry inside the zip archive `w`.
///
/// `filename` is used as the entry name; the `.npy` suffix is appended when it
/// is not already present. The entry is started with default file options and
/// its contents are produced by [`write_to_npy`] from `shape` and `data`.
///
/// # Errors
///
/// Returns an error if the entry cannot be started or if writing the header or
/// any element fails.
pub(crate) fn write_to_npz<W: Write + Seek, E: Dtype + NumpyDtype>(
    w: &mut zip::ZipWriter<W>,
    shape: &[usize],
    data: &[E],
    mut filename: String,
) -> io::Result<()> {
    let has_extension = filename.ends_with(".npy");
    if !has_extension {
        filename += ".npy";
    }

    w.start_file(filename, Default::default())?;
    write_to_npy(w, shape, data)
}

pub(crate) fn read_from_npz<R: Read + Seek, E: Dtype + NumpyDtype>(
r: &mut zip::ZipArchive<R>,
shape: &[usize],
mut filename: String,
) -> Result<Vec<E>, NpyError> {
if !filename.ends_with(".npy") {
filename.push_str(".npy");
}

let mut f = r
.by_name(&filename)
.unwrap_or_else(|_| panic!("'{filename}' not found"));

read_from_npy(&mut f, shape)
}

/// Writes an `.npy` payload to `w`: a header followed by every element of `data`.
///
/// The header is emitted by `write_header` for element type `E` and the given
/// `shape`. Elements are then written one after another, in slice order, using
/// little-endian byte order.
///
/// # Errors
///
/// Returns the first error produced while writing the header or an element.
pub(crate) fn write_to_npy<W: Write, E: Dtype + NumpyDtype>(
    w: &mut W,
    shape: &[usize],
    data: &[E],
) -> io::Result<()> {
    // Output is always little-endian; the same value goes to header and elements.
    let endian = Endian::Little;
    write_header::<W, E>(w, endian, shape)?;

    // Stops at the first element that fails to write.
    data.iter().try_for_each(|v| v.write_endian(w, endian))?;

    Ok(())
}

/// Reads an `.npy` payload from `r` and returns its elements.
///
/// The header is consumed by `read_header`, which is given the expected `shape`
/// and yields the byte order to use for the element data. Exactly
/// `shape.iter().product()` elements are then read; an empty `shape` therefore
/// reads a single element.
///
/// # Errors
///
/// Returns an error if the header cannot be read or is rejected by
/// `read_header`, or if any element fails to read.
pub(crate) fn read_from_npy<R: Read, E: Dtype + NumpyDtype>(
    r: &mut R,
    shape: &[usize],
) -> Result<Vec<E>, NpyError> {
    let endian = read_header::<R, E>(r, shape)?;
    let numel = shape.iter().product::<usize>();

    // The element count is known up front, so allocate once instead of
    // growing the vector repeatedly while reading.
    let mut out = Vec::with_capacity(numel);

    for _ in 0..numel {
        out.push(E::read_endian(r, endian)?);
    }

    Ok(out)
}

fn write_header<W: Write, E: NumpyDtype>(
w: &mut W,
endian: Endian,
shape: Vec<usize>,
shape: &[usize],
) -> io::Result<()> {
let shape_str = to_shape_str(shape);

Expand Down Expand Up @@ -121,7 +168,7 @@ fn write_header<W: Write, E: NumpyDtype>(
Ok(())
}

fn read_header<R: Read, E: NumpyDtype>(r: &mut R, shape: Vec<usize>) -> Result<Endian, NpyError> {
fn read_header<R: Read, E: NumpyDtype>(r: &mut R, shape: &[usize]) -> Result<Endian, NpyError> {
let mut magic = [0; 6];
r.read_exact(&mut magic)?;
if magic != MAGIC_NUMBER {
Expand Down Expand Up @@ -311,7 +358,7 @@ impl From<std::string::FromUtf8Error> for NpyError {
}
}

fn to_shape_str(shape: Vec<usize>) -> String {
fn to_shape_str(shape: &[usize]) -> String {
shape
.iter()
.map(|v| v.to_string())
Expand Down

0 comments on commit 1b51fab

Please sign in to comment.