Skip to content

Commit

Permalink
allow to load safetensors from a byte array
Browse files Browse the repository at this point in the history
  • Loading branch information
swfsql committed Feb 20, 2024
1 parent e883b28 commit 45226f5
Showing 1 changed file with 15 additions and 0 deletions.
15 changes: 15 additions & 0 deletions dfdx-core/src/nn_traits/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,21 @@ pub trait LoadSafeTensors {
) -> Result<(), safetensors::SafeTensorError> {
self.load_safetensors_with(path, false, &mut core::convert::identity)
}
fn load_safetensors_from_bytes_with<F: FnMut(String) -> String>(
&mut self,
bytes: &[u8],
skip_missing: bool,
key_map: &mut F,
) -> Result<(), safetensors::SafeTensorError> {
let tensors = safetensors::SafeTensors::deserialize(&bytes)?;
self.read_safetensors_with("", &tensors, skip_missing, key_map)
}
fn load_safetensors_from_bytes(
&mut self,
bytes: &[u8],
) -> Result<(), safetensors::SafeTensorError> {
self.load_safetensors_from_bytes_with(bytes, false, &mut core::convert::identity)
}

fn read_safetensors_with<F: FnMut(String) -> String>(
&mut self,
Expand Down

0 comments on commit 45226f5

Please sign in to comment.