Skip to content

Commit

Permalink
Updating safetensors depe ndency
Browse files Browse the repository at this point in the history
  • Loading branch information
coreylowman committed Sep 5, 2023
1 parent 1255690 commit 761c2eb
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 12 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ resolver = "2"

[workspace.dependencies]
num-traits = { version = "0.2.15", default-features = false }
safetensors = { version = "0.3", default-features = false }
safetensors = { version = "0.3.3", default-features = false }
memmap2 = { version = "0.5", default-features = false }
rand = { version = "0.8.5", default-features = false, features = ["std_rng"] }
rand_distr = { version = "0.4.3", default-features = false }
Expand Down
17 changes: 6 additions & 11 deletions dfdx-nn-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,17 +150,12 @@ pub trait SaveSafeTensors {
) -> Result<(), safetensors::SafeTensorError> {
let mut tensors = Vec::new();
self.write_safetensors("", &mut tensors);
let data = tensors
.iter()
.map(|(k, dtype, shape, data)| {
(
k.clone(),
safetensors::tensor::TensorView::new(dtype.clone(), shape.clone(), data)
.unwrap(),
)
})
.collect::<Vec<_>>();
let data = data.iter().map(|i| (i.0.clone(), &i.1)).collect::<Vec<_>>();
let data = tensors.iter().map(|(k, dtype, shape, data)| {
(
k.clone(),
safetensors::tensor::TensorView::new(dtype.clone(), shape.clone(), data).unwrap(),

Check failure on line 156 in dfdx-nn-core/src/lib.rs

View workflow job for this annotation

GitHub Actions / cargo-check

using `clone` on type `Dtype` which implements the `Copy` trait
)
});

safetensors::serialize_to_file(data, &None, path.as_ref())
}
Expand Down

0 comments on commit 761c2eb

Please sign in to comment.