Skip to content

Commit

Permalink
Disallow feature detection on generated code
Browse files Browse the repository at this point in the history
- Move feature detection to before generating code.
- Only emits the `SaveSafeTensors` and `LoadSafeTensors` derivations and the `#[serialize]` attr if the `safetensors` feature is enabled.
  • Loading branch information
swfsql committed Nov 8, 2023
1 parent 4476b5e commit b1a6e67
Show file tree
Hide file tree
Showing 20 changed files with 135 additions and 91 deletions.
87 changes: 62 additions & 25 deletions dfdx-derives/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,12 @@ pub fn custom_module(input: proc_macro::TokenStream) -> proc_macro::TokenStream
where_clause
.predicates
.push(parse_quote!(#ty: dfdx_core::nn_traits::BuildOnDevice<Elem, Dev>));
quote_spanned!(f.span()=> #[module] #[serialize] #vis #name: <#ty as dfdx_core::nn_traits::BuildOnDevice<Elem, Dev>>::Built,)
// Emit `#[serialize]` on the built field only when this proc-macro crate is compiled
// with its `safetensors` feature. The cfg key set by Cargo is `feature` (singular);
// the previous `features` spelling is never set, so the check was always false and
// the attribute was never generated.
let safetensors_serialize_attr = if cfg!(feature = "safetensors") {
    quote!(#[serialize])
} else {
    quote!()
};
quote_spanned!(f.span()=> #[module] #safetensors_serialize_attr #vis #name: <#ty as dfdx_core::nn_traits::BuildOnDevice<Elem, Dev>>::Built,)
} else {
quote_spanned!(f.span()=> #vis #name: #ty,)
}
Expand All @@ -126,7 +131,12 @@ pub fn custom_module(input: proc_macro::TokenStream) -> proc_macro::TokenStream
where_clause
.predicates
.push(parse_quote!(#ty: dfdx_core::nn_traits::BuildOnDevice<Elem, Dev>));
quote_spanned!(f.span()=> #[module] #[serialize] #vis <#ty as dfdx_core::nn_traits::BuildOnDevice<Elem, Dev>>::Built,)
// Emit `#[serialize]` on the built tuple field only when this proc-macro crate is
// compiled with its `safetensors` feature. The cfg key set by Cargo is `feature`
// (singular); the previous `features` spelling is never set, so the check was always
// false and the attribute was never generated.
let safetensors_serialize_attr = if cfg!(feature = "safetensors") {
    quote!(#[serialize])
} else {
    quote!()
};
quote_spanned!(f.span()=> #[module] #safetensors_serialize_attr #vis <#ty as dfdx_core::nn_traits::BuildOnDevice<Elem, Dev>>::Built,)
} else {
quote_spanned!(f.span()=> #vis #ty,)
}
Expand Down Expand Up @@ -162,8 +172,13 @@ pub fn custom_module(input: proc_macro::TokenStream) -> proc_macro::TokenStream
let (built_impl, _, built_where) = built_generics.split_for_impl();

let def = if has_fields_to_build {
let safetensors_derive = if cfg!(feature = "safetensors") {
quote!(dfdx_derives::SaveSafeTensors, dfdx_derives::LoadSafeTensors)
} else {
quote!()
};
quote! {
#[derive(Clone, Debug, dfdx_derives::ResetParams, dfdx_derives::UpdateParams, dfdx_derives::ZeroGrads, dfdx_derives::SaveSafeTensors, dfdx_derives::LoadSafeTensors)]
#[derive(Clone, Debug, dfdx_derives::ResetParams, dfdx_derives::UpdateParams, dfdx_derives::ZeroGrads, #safetensors_derive)]
pub struct #built_name #built_impl #built_where #fields
}
} else {
Expand All @@ -181,26 +196,32 @@ pub fn custom_module(input: proc_macro::TokenStream) -> proc_macro::TokenStream
let (build_impl, _, _) = build_generics.split_for_impl();
let (built_impl, built_ty, built_where) = built_generics.split_for_impl();

quote! {
#[cfg(feature = "safetensors")]
impl #built_impl dfdx_core::nn_traits::SaveSafeTensors for #builder_name #built_ty #built_where {
fn write_safetensors(
&self,
location: &str,
tensors: &mut Vec<(String, ::safetensors::Dtype, Vec<usize>, Vec<u8>)>,
) {}
}
let safetensors_impls = if cfg!(feature = "safetensors") {
quote! {
impl #built_impl dfdx_core::nn_traits::SaveSafeTensors for #builder_name #built_ty #built_where {
fn write_safetensors(
&self,
location: &str,
tensors: &mut Vec<(String, ::safetensors::Dtype, Vec<usize>, Vec<u8>)>,
) {}
}

#[cfg(feature = "safetensors")]
impl #built_impl dfdx_core::nn_traits::LoadSafeTensors for #builder_name #built_ty #built_where {
fn read_safetensors<'a>(
&mut self,
location: &str,
tensors: &::safetensors::SafeTensors<'a>,
) -> Result<(), ::safetensors::SafeTensorError> {
Ok(())
impl #built_impl dfdx_core::nn_traits::LoadSafeTensors for #builder_name #built_ty #built_where {
fn read_safetensors<'a>(
&mut self,
location: &str,
tensors: &::safetensors::SafeTensors<'a>,
) -> Result<(), ::safetensors::SafeTensorError> {
Ok(())
}
}
}
} else {
quote! {}
};

quote! {
#safetensors_impls

impl #build_impl dfdx_core::nn_traits::ResetParams<Elem, Dev> for #builder_name #built_ty #built_where {
fn try_reset_params(&mut self) -> Result<(), dfdx_core::tensor::Error> {
Expand Down Expand Up @@ -373,7 +394,12 @@ pub fn sequential(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
where_clause
.predicates
.push(parse_quote!(#ty: dfdx_core::nn_traits::BuildOnDevice<Elem, Dev>));
quote_spanned!(f.span()=> #[module] #[serialize] #vis #name: <#ty as dfdx_core::nn_traits::BuildOnDevice<Elem, Dev>>::Built,)
// Emit `#[serialize]` on the built field only when this proc-macro crate is compiled
// with its `safetensors` feature. The cfg key set by Cargo is `feature` (singular);
// the previous `features` spelling is never set, so the check was always false and
// the attribute was never generated.
let safetensors_serialize_attr = if cfg!(feature = "safetensors") {
    quote!(#[serialize])
} else {
    quote!()
};
quote_spanned!(f.span()=> #[module] #safetensors_serialize_attr #vis #name: <#ty as dfdx_core::nn_traits::BuildOnDevice<Elem, Dev>>::Built,)
});
quote! { #(#fields)* }
}
Expand All @@ -384,7 +410,12 @@ pub fn sequential(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
where_clause
.predicates
.push(parse_quote!(#ty: dfdx_core::nn_traits::BuildOnDevice<Elem, Dev>));
quote_spanned!(f.span()=> #[module] #[serialize] #vis <#ty as dfdx_core::nn_traits::BuildOnDevice<Elem, Dev>>::Built,)
// Emit `#[serialize]` on the built tuple field only when this proc-macro crate is
// compiled with its `safetensors` feature. The cfg key set by Cargo is `feature`
// (singular); the previous `features` spelling is never set, so the check was always
// false and the attribute was never generated.
let safetensors_serialize_attr = if cfg!(feature = "safetensors") {
    quote!(#[serialize])
} else {
    quote!()
};
quote_spanned!(f.span()=> #[module] #safetensors_serialize_attr #vis <#ty as dfdx_core::nn_traits::BuildOnDevice<Elem, Dev>>::Built,)
});
quote! { #(#fields)* }
}
Expand All @@ -397,8 +428,14 @@ pub fn sequential(input: proc_macro::TokenStream) -> proc_macro::TokenStream {

let (built_impl, _, built_where) = built_generics.split_for_impl();

let safetensors_derive = if cfg!(feature = "safetensors") {
quote!(dfdx_derives::SaveSafeTensors, dfdx_derives::LoadSafeTensors)
} else {
quote!()
};

quote! {
#[derive(Clone, Debug, dfdx_derives::ResetParams, dfdx_derives::UpdateParams, dfdx_derives::ZeroGrads, dfdx_derives::SaveSafeTensors, dfdx_derives::LoadSafeTensors)]
#[derive(Clone, Debug, dfdx_derives::ResetParams, dfdx_derives::UpdateParams, dfdx_derives::ZeroGrads, #safetensors_derive)]
pub struct #built_name #built_impl #built_where {
#fields
}
Expand Down Expand Up @@ -849,7 +886,7 @@ pub fn save_safetensors(input: proc_macro::TokenStream) -> proc_macro::TokenStre
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();

proc_macro::TokenStream::from(quote! {
#[cfg(feature = "safetensors")]
// note: SaveSafeTensors definition is already gated by the safetensors feature
impl #impl_generics dfdx_core::nn_traits::SaveSafeTensors for #name #ty_generics #where_clause {
fn write_safetensors(
&self,
Expand Down Expand Up @@ -911,7 +948,7 @@ pub fn load_safetensors(input: proc_macro::TokenStream) -> proc_macro::TokenStre
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();

proc_macro::TokenStream::from(quote! {
#[cfg(feature = "safetensors")]
// note: LoadSafeTensors definition is already gated by the safetensors feature
impl #impl_generics dfdx_core::nn_traits::LoadSafeTensors for #name #ty_generics #where_clause {
fn read_safetensors<'a>(
&mut self,
Expand Down
9 changes: 4 additions & 5 deletions dfdx/src/nn/layers/add_into.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::prelude::*;

/// Add inputs together into a single tensor. `T` should be a tuple
//// where every element of the tuple has the same output type
/// where every element of the tuple has the same output type
///
/// This provides a utility for networks where multiple inputs are needed
///
Expand All @@ -19,13 +19,12 @@ use crate::prelude::*;
/// let b: Tensor<Rank1<3>, f32, _> = dev.zeros();
/// let _: Tensor<Rank1<5>, f32, _> = model.forward((a, b));
/// ```
#[derive(
Debug, Default, Clone, ResetParams, ZeroGrads, UpdateParams, LoadSafeTensors, SaveSafeTensors,
)]
#[derive(Debug, Default, Clone, ResetParams, ZeroGrads, UpdateParams)]
#[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))]
#[repr(transparent)]
pub struct AddInto<T>(
#[module]
#[serialize]
#[cfg_attr(feature = "safetensors", serialize)]
pub T,
);

Expand Down
15 changes: 8 additions & 7 deletions dfdx/src/nn/layers/batch_norm1d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,29 +55,30 @@ impl<C: Dim, E: Dtype, D: Device<E>> BuildOnDevice<E, D> for BatchNorm1DConfig<C
}

/// See [BatchNorm1DConfig].
#[derive(Clone, Debug, UpdateParams, ZeroGrads, SaveSafeTensors, LoadSafeTensors)]
#[derive(Clone, Debug, UpdateParams, ZeroGrads)]
#[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))]
pub struct BatchNorm1D<C: Dim, Elem: Dtype, Dev: Device<Elem>> {
/// Scale for affine transform. Defaults to 1.0
#[param]
#[serialize]
#[cfg_attr(feature = "safetensors", serialize)]
pub scale: Tensor<(C,), Elem, Dev>,
/// Bias for affine transform. Defaults to 0.0
#[param]
#[serialize]
#[cfg_attr(feature = "safetensors", serialize)]
pub bias: Tensor<(C,), Elem, Dev>,
/// Spatial mean that is updated during training. Defaults to 0.0
#[serialize]
#[cfg_attr(feature = "safetensors", serialize)]
pub running_mean: Tensor<(C,), Elem, Dev>,
/// Spatial variance that is updated during training. Defaults to 1.0
#[serialize]
#[cfg_attr(feature = "safetensors", serialize)]
pub running_var: Tensor<(C,), Elem, Dev>,
/// Added to variance before taking sqrt for numerical stability. Defaults to 1e-5
#[serialize]
#[cfg_attr(feature = "safetensors", serialize)]
pub epsilon: f64,
/// Controls exponential moving average of running stats. Defaults to 0.1
///
/// `running_stat * (1.0 - momentum) + stat * momentum`.
#[serialize]
#[cfg_attr(feature = "safetensors", serialize)]
pub momentum: f64,
}

Expand Down
15 changes: 8 additions & 7 deletions dfdx/src/nn/layers/batch_norm2d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,21 +57,22 @@ impl<C: Dim, E: Dtype, D: Device<E>> crate::nn::BuildOnDevice<E, D> for BatchNor
}

/// See [BatchNorm2DConfig]
#[derive(Clone, Debug, UpdateParams, ZeroGrads, SaveSafeTensors, LoadSafeTensors)]
#[derive(Clone, Debug, UpdateParams, ZeroGrads)]
#[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))]
pub struct BatchNorm2D<C: Dim, Elem: Dtype, Dev: Device<Elem>> {
#[param]
#[serialize]
#[cfg_attr(feature = "safetensors", serialize)]
pub scale: Tensor<(C,), Elem, Dev>,
#[param]
#[serialize]
#[cfg_attr(feature = "safetensors", serialize)]
pub bias: Tensor<(C,), Elem, Dev>,
#[serialize]
#[cfg_attr(feature = "safetensors", serialize)]
pub running_mean: Tensor<(C,), Elem, Dev>,
#[serialize]
#[cfg_attr(feature = "safetensors", serialize)]
pub running_var: Tensor<(C,), Elem, Dev>,
#[serialize]
#[cfg_attr(feature = "safetensors", serialize)]
pub epsilon: f64,
#[serialize]
#[cfg_attr(feature = "safetensors", serialize)]
pub momentum: f64,
}

Expand Down
5 changes: 3 additions & 2 deletions dfdx/src/nn/layers/bias1d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,11 @@ impl<I: Dim, E: Dtype, D: Device<E>> BuildOnDevice<E, D> for Bias1DConfig<I> {
}

/// See [Bias1DConfig]
#[derive(Clone, Debug, UpdateParams, ZeroGrads, SaveSafeTensors, LoadSafeTensors)]
#[derive(Clone, Debug, UpdateParams, ZeroGrads)]
#[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))]
pub struct Bias1D<I: Dim, Elem: Dtype, Dev: Device<Elem>> {
#[param]
#[serialize]
#[cfg_attr(feature = "safetensors", serialize)]
pub bias: Tensor<(I,), Elem, Dev>,
}

Expand Down
5 changes: 3 additions & 2 deletions dfdx/src/nn/layers/bias2d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,11 @@ impl<C: Dim, E: Dtype, D: Device<E>> BuildOnDevice<E, D> for Bias2DConfig<C> {
}

/// See [Bias2DConfig]
#[derive(Clone, Debug, UpdateParams, ZeroGrads, SaveSafeTensors, LoadSafeTensors)]
#[derive(Clone, Debug, UpdateParams, ZeroGrads)]
#[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))]
pub struct Bias2D<C: Dim, Elem: Dtype, Dev: Device<Elem>> {
#[param]
#[serialize]
#[cfg_attr(feature = "safetensors", serialize)]
pub bias: Tensor<(C,), Elem, Dev>,
}

Expand Down
5 changes: 3 additions & 2 deletions dfdx/src/nn/layers/conv1d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ where
}

/// The module built with [Conv1DConfig]. See [Conv1DConfig] for usage.
#[derive(Debug, Clone, UpdateParams, ZeroGrads, SaveSafeTensors, LoadSafeTensors)]
#[derive(Debug, Clone, UpdateParams, ZeroGrads)]
#[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))]
pub struct Conv1D<InChan, OutChan, KernelSize, Stride, Padding, Dilation, Groups, Elem, Dev>
where
InChan: std::ops::Div<Groups>,
Expand All @@ -94,7 +95,7 @@ where
Dev: Device<Elem>,
{
#[param]
#[serialize]
#[cfg_attr(feature = "safetensors", serialize)]
#[allow(clippy::type_complexity)]
pub weight: Tensor<
(
Expand Down
5 changes: 3 additions & 2 deletions dfdx/src/nn/layers/conv2d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ where
}

/// The module built with [Conv2DConfig]. See [Conv2DConfig] for usage.
#[derive(Debug, Clone, UpdateParams, ZeroGrads, SaveSafeTensors, LoadSafeTensors)]
#[derive(Debug, Clone, UpdateParams, ZeroGrads)]
#[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))]
pub struct Conv2D<InChan, OutChan, KernelSize, Stride, Padding, Dilation, Groups, Elem, Dev>
where
InChan: std::ops::Div<Groups>,
Expand All @@ -115,7 +116,7 @@ where
Dev: Device<Elem>,
{
#[param]
#[serialize]
#[cfg_attr(feature = "safetensors", serialize)]
#[allow(clippy::type_complexity)]
pub weight: Tensor<
(
Expand Down
5 changes: 3 additions & 2 deletions dfdx/src/nn/layers/conv_trans2d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ where
}

/// See [ConvTrans2DConfig].
#[derive(Debug, Clone, UpdateParams, ZeroGrads, SaveSafeTensors, LoadSafeTensors)]
#[derive(Debug, Clone, UpdateParams, ZeroGrads)]
#[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))]
pub struct ConvTrans2D<InChan, OutChan, KernelSize, Stride, Padding, Dilation, Groups, Elem, Dev>
where
OutChan: std::ops::Div<Groups>,
Expand All @@ -93,7 +94,7 @@ where
Dev: Device<Elem>,
{
#[param]
#[serialize]
#[cfg_attr(feature = "safetensors", serialize)]
#[allow(clippy::type_complexity)]
pub weight: Tensor<
(
Expand Down
5 changes: 3 additions & 2 deletions dfdx/src/nn/layers/embedding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,11 @@ impl<V: Dim, M: Dim, E: Dtype, D: Device<E>> BuildOnDevice<E, D> for EmbeddingCo
}

/// See [EmbeddingConfig].
#[derive(Clone, Debug, UpdateParams, ZeroGrads, SaveSafeTensors, LoadSafeTensors)]
#[derive(Clone, Debug, UpdateParams, ZeroGrads)]
#[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))]
pub struct Embedding<Vocab: Dim, Model: Dim, Elem: Dtype, Dev: Device<Elem>> {
#[param]
#[serialize]
#[cfg_attr(feature = "safetensors", serialize)]
pub weight: Tensor<(Vocab, Model), Elem, Dev>,
}

Expand Down
9 changes: 4 additions & 5 deletions dfdx/src/nn/layers/generalized_add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,14 @@ use crate::prelude::*;
/// let y = model.forward(x);
/// assert_eq!(y.array(), [4.0, 1.0, 0.0, 2.0, 6.0]);
/// ```
#[derive(
Default, Clone, Debug, ResetParams, ZeroGrads, UpdateParams, LoadSafeTensors, SaveSafeTensors,
)]
#[derive(Default, Clone, Debug, ResetParams, ZeroGrads, UpdateParams)]
#[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))]
pub struct GeneralizedAdd<T, U> {
#[module]
#[serialize]
#[cfg_attr(feature = "safetensors", serialize)]
pub t: T,
#[module]
#[serialize]
#[cfg_attr(feature = "safetensors", serialize)]
pub u: U,
}

Expand Down
9 changes: 4 additions & 5 deletions dfdx/src/nn/layers/generalized_mul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,14 @@ use crate::prelude::*;
/// let y = model.forward(x);
/// assert_eq!(y.array(), [0.0, 0.0, 0.0, 1.0, 8.0]);
/// ```
#[derive(
Default, Clone, Debug, ResetParams, ZeroGrads, UpdateParams, LoadSafeTensors, SaveSafeTensors,
)]
#[derive(Default, Clone, Debug, ResetParams, ZeroGrads, UpdateParams)]
#[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))]
pub struct GeneralizedMul<T, U> {
#[module]
#[serialize]
#[cfg_attr(feature = "safetensors", serialize)]
pub t: T,
#[module]
#[serialize]
#[cfg_attr(feature = "safetensors", serialize)]
pub u: U,
}

Expand Down
Loading

0 comments on commit b1a6e67

Please sign in to comment.