From d845e86155f62db88dfaef549bef905860b390bf Mon Sep 17 00:00:00 2001 From: Nathan Koppel Date: Tue, 11 Jul 2023 11:29:08 -0500 Subject: [PATCH] allow models without scalars to be read without errors --- src/nn/npz.rs | 29 +++++++++++++++++++---------- src/nn/safetensors.rs | 33 +++++++++++++++++++++------------ src/tensor/numpy.rs | 13 ++++++++++--- 3 files changed, 50 insertions(+), 25 deletions(-) diff --git a/src/nn/npz.rs b/src/nn/npz.rs index ae0cb574c..3b3faa9ae 100644 --- a/src/nn/npz.rs +++ b/src/nn/npz.rs @@ -1,4 +1,5 @@ use crate::{ + prelude::numpy::NpyError, shapes::{Dtype, Shape}, tensor::{ numpy::{read_from_npz, write_to_npz, NpzError, NumpyDtype}, @@ -162,18 +163,26 @@ impl> TensorVisitor fn visit_scalar( &mut self, - _opts: ScalarOptions, + opts: ScalarOptions, (n, full_path): (&mut N, String), ) -> Result, Self::Err> { - let buf: Vec = read_from_npz(self, &[], full_path)?; - *n = N::from(buf[0]).unwrap_or_else(|| { - panic!( - "Failed to convert f64 value {} to {} when reading from npz!", - buf[0], - std::any::type_name::() - ) - }); - Ok(None) + match read_from_npz::<_, f64>(self, &[], full_path) { + Ok(buf) => { + *n = N::from(buf[0]).unwrap_or_else(|| { + panic!( + "Failed to convert f64 value {} to {} when reading from npz!", + buf[0], + std::any::type_name::() + ) + }); + Ok(None) + } + Err(NpyError::IoError(e)) if e.kind() == std::io::ErrorKind::NotFound => { + *n = opts.default; + Ok(None) + } + Err(x) => Err(x.into()), + } } } diff --git a/src/nn/safetensors.rs b/src/nn/safetensors.rs index 46f04a80d..5055dd448 100644 --- a/src/nn/safetensors.rs +++ b/src/nn/safetensors.rs @@ -171,20 +171,29 @@ impl<'data, E: Dtype + SafeDtype, D: Device> TensorVisitor for SafeTens fn visit_scalar( &mut self, - _: ScalarOptions, + opts: ScalarOptions, (n, full_path): (&mut N, String), ) -> Result, Self::Err> { - let data = self.tensor(&full_path)?.data(); - let mut array = [0; 8]; - array.copy_from_slice(data); - let val = f64::from_le_bytes(array); - *n = N::from(val).unwrap_or_else(|| { - panic!( - "Failed to convert f64 value {val} at {full_path} to {} when reading from safetensors!", - std::any::type_name::() - ) - }); - Ok(None) + match self.tensor(&full_path) { + Ok(tensor) => { + let data = tensor.data(); + let mut array = [0; 8]; + array.copy_from_slice(data); + let val = f64::from_le_bytes(array); + *n = N::from(val).unwrap_or_else(|| { + panic!( + "Failed to convert f64 value {val} at {full_path} to {} when reading from safetensors!", + std::any::type_name::() + ) + }); + Ok(None) + } + Err(SafeTensorError::TensorNotFound(_)) => { + *n = opts.default; + Ok(None) + } + Err(x) => Err(Error::SafeTensorError(x)), + } } } diff --git a/src/tensor/numpy.rs b/src/tensor/numpy.rs index a75404f3d..b362be2ad 100644 --- a/src/tensor/numpy.rs +++ b/src/tensor/numpy.rs @@ -76,9 +76,16 @@ pub(crate) fn read_from_npz( filename.push_str(".npy"); } - let mut f = r - .by_name(&filename) - .unwrap_or_else(|_| panic!("'{filename}' not found")); + let mut f = match r.by_name(&filename) { + Ok(f) => f, + Err(ZipError::FileNotFound) => { + return Err(NpyError::IoError(io::Error::new( + io::ErrorKind::NotFound, + ZipError::FileNotFound, + ))) + } + Err(e) => panic!("Uncaught zip error: {e}"), + }; read_from_npy(&mut f, shape) }