From a832f51acf89ceafa161c87c6de8bb7e6cb56296 Mon Sep 17 00:00:00 2001 From: Thiago Machado Date: Wed, 31 Jan 2024 01:01:29 -0500 Subject: [PATCH] Update safetensors module and naming - Make the safetensors module private. - It is no longer exported through the prelude, avoiding a naming clash with the external safetensors crate. - Change how and when the period is inserted. - This should make the naming closer to how the fields are accessed in the code. --- dfdx-core/src/nn_traits/tuples.rs | 4 ++-- dfdx-core/src/nn_traits/vecs.rs | 4 ++-- dfdx-core/src/tensor/mod.rs | 2 +- dfdx-derives/src/lib.rs | 20 ++++++++++++++++---- 4 files changed, 21 insertions(+), 9 deletions(-) diff --git a/dfdx-core/src/nn_traits/tuples.rs b/dfdx-core/src/nn_traits/tuples.rs index 97e8c7de..205c0419 100644 --- a/dfdx-core/src/nn_traits/tuples.rs +++ b/dfdx-core/src/nn_traits/tuples.rs @@ -25,7 +25,7 @@ macro_rules! tuple_impls { location: &str, tensors: &mut Vec<(String, safetensors::Dtype, Vec, Vec)>, ) { - $(self.$idx.write_safetensors(&format!("{location}{}.", $idx), tensors);)+ + $(self.$idx.write_safetensors(&format!("{location}.{}", $idx), tensors);)+ } } @@ -36,7 +36,7 @@ macro_rules! 
tuple_impls { location: &str, tensors: &safetensors::SafeTensors, ) -> Result<(), safetensors::SafeTensorError> { - $(self.$idx.read_safetensors(&format!("{location}{}.", $idx), tensors)?;)+ + $(self.$idx.read_safetensors(&format!("{location}.{}", $idx), tensors)?;)+ Ok(()) } } diff --git a/dfdx-core/src/nn_traits/vecs.rs b/dfdx-core/src/nn_traits/vecs.rs index 803a07d8..593b1a55 100644 --- a/dfdx-core/src/nn_traits/vecs.rs +++ b/dfdx-core/src/nn_traits/vecs.rs @@ -66,7 +66,7 @@ impl crate::nn_traits::SaveSafeTensors for tensors: &mut Vec<(String, safetensors::Dtype, Vec, Vec)>, ) { for (i, t) in self.iter().enumerate() { - t.write_safetensors(&format!("{location}{i}."), tensors); + t.write_safetensors(&format!("{location}.{i}"), tensors); } } } @@ -79,7 +79,7 @@ impl crate::nn_traits::LoadSafeTensors for tensors: &safetensors::SafeTensors, ) -> Result<(), safetensors::SafeTensorError> { for (i, t) in self.iter_mut().enumerate() { - t.read_safetensors(&format!("{location}{i}."), tensors)?; + t.read_safetensors(&format!("{location}.{i}"), tensors)?; } Ok(()) } diff --git a/dfdx-core/src/tensor/mod.rs b/dfdx-core/src/tensor/mod.rs index acc4074a..0163480a 100644 --- a/dfdx-core/src/tensor/mod.rs +++ b/dfdx-core/src/tensor/mod.rs @@ -151,7 +151,7 @@ pub(crate) mod webgpu; pub use numpy::NumpyDtype; mod error; #[cfg(feature = "safetensors")] -pub mod safetensors; +mod safetensors; mod tensorlike; mod unique_id; diff --git a/dfdx-derives/src/lib.rs b/dfdx-derives/src/lib.rs index 4eca0d82..7af885f9 100644 --- a/dfdx-derives/src/lib.rs +++ b/dfdx-derives/src/lib.rs @@ -850,7 +850,10 @@ pub fn save_safetensors(input: proc_macro::TokenStream) -> proc_macro::TokenStre where_clause .predicates .push(parse_quote!(#ty: ::dfdx::nn_traits::SaveSafeTensors)); - quote_spanned!(f.span()=>self.#name.write_safetensors(&format!("{location}{}", #name_str), tensors);) + quote_spanned!(f.span()=>self.#name.write_safetensors( + &format!("{location}{}{}", if location.is_empty() { "" } else 
{ "." }, #name_str), + tensors + );) } else { Default::default() } @@ -866,7 +869,10 @@ pub fn save_safetensors(input: proc_macro::TokenStream) -> proc_macro::TokenStre where_clause .predicates .push(parse_quote!(#ty: ::dfdx::nn_traits::SaveSafeTensors)); - quote_spanned!(f.span()=>self.#index.write_safetensors(&format!("{location}{}", #index), tensors);) + quote_spanned!(f.span()=>self.#index.write_safetensors( + &format!("{location}{}{}", if location.is_empty() { "" } else { "." }, #index), + tensors + );) } else { Default::default() } @@ -913,7 +919,10 @@ pub fn load_safetensors(input: proc_macro::TokenStream) -> proc_macro::TokenStre where_clause .predicates .push(parse_quote!(#ty: ::dfdx::nn_traits::LoadSafeTensors)); - quote_spanned!(f.span()=>self.#name.read_safetensors(&format!("{location}{}", #name_str), tensors)?;) + quote_spanned!(f.span()=>self.#name.read_safetensors( + &format!("{location}{}{}", if location.is_empty() { "" } else { "." }, #name_str), + tensors + )?;) } else { Default::default() } @@ -928,7 +937,10 @@ pub fn load_safetensors(input: proc_macro::TokenStream) -> proc_macro::TokenStre where_clause .predicates .push(parse_quote!(#ty: ::dfdx::nn_traits::LoadSafeTensors)); - quote_spanned!(f.span()=>self.#index.read_safetensors(&format!("{location}{}", #index), tensors)?;) + quote_spanned!(f.span()=>self.#index.read_safetensors( + &format!("{location}{}{}", if location.is_empty() { "" } else { "." }, #index), + tensors + )?;) } else { Default::default() }