Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add forward_mut in Sequential derive #884

Merged
merged 3 commits into from
Nov 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 37 additions & 12 deletions dfdx-derives/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
/// 2. [dfdx_core::nn_traits::ResetParams]
/// 3. [dfdx_core::nn_traits::UpdateParams]
/// 4. [dfdx_core::nn_traits::ZeroGrads]
/// 5. [dfdx_core::nn_traits::SaveSafeTensors]

Check warning on line 16 in dfdx-derives/src/lib.rs

View workflow job for this annotation

GitHub Actions / cargo-check

unresolved link to `dfdx_core::nn_traits::SaveSafeTensors`
/// 6. [dfdx_core::nn_traits::LoadSafeTensors]

Check warning on line 17 in dfdx-derives/src/lib.rs

View workflow job for this annotation

GitHub Actions / cargo-check

unresolved link to `dfdx_core::nn_traits::LoadSafeTensors`
///
/// If your struct contains sub module configs, then you must add the `#[module]` attribute to those items. Any field that is marked with `#[module]` will be expected to implement [dfdx_core::nn_traits::BuildOnDevice].
///
Expand Down Expand Up @@ -489,23 +489,44 @@
};

let impl_module = {
let src = match input.data {
let (src, src_mut) = match input.data {
Data::Struct(ref data) => match data.fields {
Fields::Named(ref fields) => {
let recurse = fields.named.iter().map(|f| {
let name = &f.ident;
quote_spanned! {f.span()=> self.#name.try_forward(x)? }
});
quote! { #(let x = #recurse;)* }
let (recurse, recurse_mut) = fields
.named
.iter()
.map(|f| {
let name = &f.ident;
(
quote_spanned! {f.span()=> self.#name.try_forward(x)? },
quote_spanned! {f.span()=> self.#name.try_forward_mut(x)? },
)
})
.unzip::<_, _, Vec<_>, Vec<_>>();
(
quote! { #(let x = #recurse;)* },
quote! { #(let x = #recurse_mut;)* },
)
}
Fields::Unnamed(ref fields) => {
let recurse = fields.unnamed.iter().enumerate().map(|(i, f)| {
let index = Index::from(i);
quote_spanned! {f.span()=> self.#index.try_forward(x)? }
});
quote! { #(let x = #recurse;)* }
let (recurse, recurse_mut) = fields
.unnamed
.iter()
.enumerate()
.map(|(i, f)| {
let index = Index::from(i);
(
quote_spanned! {f.span()=> self.#index.try_forward(x)? },
quote_spanned! {f.span()=> self.#index.try_forward_mut(x)? },
)
})
.unzip::<_, _, Vec<_>, Vec<_>>();
(
quote! { #(let x = #recurse;)* },
quote! { #(let x = #recurse_mut;)* },
)
}
Fields::Unit => quote! { let x = x; },
Fields::Unit => (quote! { let x = x; }, quote! { let x = x; }),
},
_ => unreachable!(),
};
Expand All @@ -520,6 +541,10 @@
#src
Ok(x)
}
fn try_forward_mut(&mut self, x: Input) -> Result<Self::Output, Error> {
#src_mut
Ok(x)
}
}
}
};
Expand Down
18 changes: 18 additions & 0 deletions dfdx/src/nn/layers/batch_norm2d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -311,4 +311,22 @@ mod tests {
let mut opt = crate::nn::optim::Sgd::new(&bn, Default::default());
opt.update(&mut bn, &g).expect("");
}

#[derive(Default, Clone, Sequential)]
struct Arch {
pub batch: BatchNorm2DConstConfig<3>,
}

#[test]
fn test_batchnorm2d_update_with_derive() {
let dev: TestDevice = Default::default();

let x1: Tensor<Rank3<3, 4, 5>, TestDtype, _> = dev.sample_normal();
let mut bn = dev.build_module::<TestDtype>(Arch::default());
let y = bn.forward_mut(x1.leaky_trace());
let g = y.square().mean().backward();

let mut opt = crate::nn::optim::Sgd::new(&bn, Default::default());
opt.update(&mut bn, &g).expect("");
}
}
Loading