Skip to content

Commit

Permalink
Update safetensors module and naming
Browse files Browse the repository at this point in the history
- Makes the safetensors module private.
  - Doesn't get exported on the preamble, avoiding a naming clash with the safetensors external crate.
- Change how and when the period is inserted.
  - This should make it closer to how the fields are accessed in the code.
  • Loading branch information
swfsql committed Feb 1, 2024
1 parent 4722a99 commit a832f51
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 9 deletions.
4 changes: 2 additions & 2 deletions dfdx-core/src/nn_traits/tuples.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ macro_rules! tuple_impls {
location: &str,
tensors: &mut Vec<(String, safetensors::Dtype, Vec<usize>, Vec<u8>)>,
) {
$(self.$idx.write_safetensors(&format!("{location}{}.", $idx), tensors);)+
$(self.$idx.write_safetensors(&format!("{location}.{}", $idx), tensors);)+
}
}

Expand All @@ -36,7 +36,7 @@ macro_rules! tuple_impls {
location: &str,
tensors: &safetensors::SafeTensors,
) -> Result<(), safetensors::SafeTensorError> {
$(self.$idx.read_safetensors(&format!("{location}{}.", $idx), tensors)?;)+
$(self.$idx.read_safetensors(&format!("{location}.{}", $idx), tensors)?;)+
Ok(())
}
}
Expand Down
4 changes: 2 additions & 2 deletions dfdx-core/src/nn_traits/vecs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ impl<T: crate::nn_traits::SaveSafeTensors> crate::nn_traits::SaveSafeTensors for
tensors: &mut Vec<(String, safetensors::Dtype, Vec<usize>, Vec<u8>)>,
) {
for (i, t) in self.iter().enumerate() {
t.write_safetensors(&format!("{location}{i}."), tensors);
t.write_safetensors(&format!("{location}.{i}"), tensors);
}
}
}
Expand All @@ -79,7 +79,7 @@ impl<T: crate::nn_traits::LoadSafeTensors> crate::nn_traits::LoadSafeTensors for
tensors: &safetensors::SafeTensors,
) -> Result<(), safetensors::SafeTensorError> {
for (i, t) in self.iter_mut().enumerate() {
t.read_safetensors(&format!("{location}{i}."), tensors)?;
t.read_safetensors(&format!("{location}.{i}"), tensors)?;
}
Ok(())
}
Expand Down
2 changes: 1 addition & 1 deletion dfdx-core/src/tensor/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ pub(crate) mod webgpu;
pub use numpy::NumpyDtype;
mod error;
#[cfg(feature = "safetensors")]
pub mod safetensors;
mod safetensors;
mod tensorlike;
mod unique_id;

Expand Down
20 changes: 16 additions & 4 deletions dfdx-derives/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -850,7 +850,10 @@ pub fn save_safetensors(input: proc_macro::TokenStream) -> proc_macro::TokenStre
where_clause
.predicates
.push(parse_quote!(#ty: ::dfdx::nn_traits::SaveSafeTensors));
quote_spanned!(f.span()=>self.#name.write_safetensors(&format!("{location}{}", #name_str), tensors);)
quote_spanned!(f.span()=>self.#name.write_safetensors(
&format!("{location}{}{}", if location.is_empty() { "" } else { "." }, #name_str),
tensors
);)
} else {
Default::default()
}
Expand All @@ -866,7 +869,10 @@ pub fn save_safetensors(input: proc_macro::TokenStream) -> proc_macro::TokenStre
where_clause
.predicates
.push(parse_quote!(#ty: ::dfdx::nn_traits::SaveSafeTensors));
quote_spanned!(f.span()=>self.#index.write_safetensors(&format!("{location}{}", #index), tensors);)
quote_spanned!(f.span()=>self.#index.write_safetensors(
&format!("{location}{}{}", if location.is_empty() { "" } else { "." }, #index),
tensors
);)
} else {
Default::default()
}
Expand Down Expand Up @@ -913,7 +919,10 @@ pub fn load_safetensors(input: proc_macro::TokenStream) -> proc_macro::TokenStre
where_clause
.predicates
.push(parse_quote!(#ty: ::dfdx::nn_traits::LoadSafeTensors));
quote_spanned!(f.span()=>self.#name.read_safetensors(&format!("{location}{}", #name_str), tensors)?;)
quote_spanned!(f.span()=>self.#name.read_safetensors(
&format!("{location}{}{}", if location.is_empty() { "" } else { "." }, #name_str),
tensors
)?;)
} else {
Default::default()
}
Expand All @@ -928,7 +937,10 @@ pub fn load_safetensors(input: proc_macro::TokenStream) -> proc_macro::TokenStre
where_clause
.predicates
.push(parse_quote!(#ty: ::dfdx::nn_traits::LoadSafeTensors));
quote_spanned!(f.span()=>self.#index.read_safetensors(&format!("{location}{}", #index), tensors)?;)
quote_spanned!(f.span()=>self.#index.read_safetensors(
&format!("{location}{}{}", if location.is_empty() { "" } else { "." }, #index),
tensors
)?;)
} else {
Default::default()
}
Expand Down

0 comments on commit a832f51

Please sign in to comment.