diff --git a/dfdx-derives/src/lib.rs b/dfdx-derives/src/lib.rs index 60da4982..13c03c75 100644 --- a/dfdx-derives/src/lib.rs +++ b/dfdx-derives/src/lib.rs @@ -489,23 +489,44 @@ pub fn sequential(input: proc_macro::TokenStream) -> proc_macro::TokenStream { }; let impl_module = { - let src = match input.data { + let (src, src_mut) = match input.data { Data::Struct(ref data) => match data.fields { Fields::Named(ref fields) => { - let recurse = fields.named.iter().map(|f| { - let name = &f.ident; - quote_spanned! {f.span()=> self.#name.try_forward(x)? } - }); - quote! { #(let x = #recurse;)* } + let (recurse, recurse_mut) = fields + .named + .iter() + .map(|f| { + let name = &f.ident; + ( + quote_spanned! {f.span()=> self.#name.try_forward(x)? }, + quote_spanned! {f.span()=> self.#name.try_forward_mut(x)? }, + ) + }) + .unzip::<_, _, Vec<_>, Vec<_>>(); + ( + quote! { #(let x = #recurse;)* }, + quote! { #(let x = #recurse_mut;)* }, + ) } Fields::Unnamed(ref fields) => { - let recurse = fields.unnamed.iter().enumerate().map(|(i, f)| { - let index = Index::from(i); - quote_spanned! {f.span()=> self.#index.try_forward(x)? } - }); - quote! { #(let x = #recurse;)* } + let (recurse, recurse_mut) = fields + .unnamed + .iter() + .enumerate() + .map(|(i, f)| { + let index = Index::from(i); + ( + quote_spanned! {f.span()=> self.#index.try_forward(x)? }, + quote_spanned! {f.span()=> self.#index.try_forward_mut(x)? }, + ) + }) + .unzip::<_, _, Vec<_>, Vec<_>>(); + ( + quote! { #(let x = #recurse;)* }, + quote! { #(let x = #recurse_mut;)* }, + ) } - Fields::Unit => quote! { let x = x; }, + Fields::Unit => (quote! { let x = x; }, quote! { let x = x; }), }, _ => unreachable!(), }; @@ -520,6 +541,10 @@ pub fn sequential(input: proc_macro::TokenStream) -> proc_macro::TokenStream { #src Ok(x) } + fn try_forward_mut(&mut self, x: Input) -> Result { + #src_mut + Ok(x) + } } } }; diff --git a/dfdx/src/nn/layers/batch_norm2d.rs b/dfdx/src/nn/layers/batch_norm2d.rs index c6f592d3..772923d7 100644 --- a/dfdx/src/nn/layers/batch_norm2d.rs +++ b/dfdx/src/nn/layers/batch_norm2d.rs @@ -311,4 +311,22 @@ mod tests { let mut opt = crate::nn::optim::Sgd::new(&bn, Default::default()); opt.update(&mut bn, &g).expect(""); } + + #[derive(Default, Clone, Sequential)] + struct Arch { + pub batch: BatchNorm2DConstConfig<3>, + } + + #[test] + fn test_batchnorm2d_update_with_derive() { + let dev: TestDevice = Default::default(); + + let x1: Tensor, TestDtype, _> = dev.sample_normal(); + let mut bn = dev.build_module::(Arch::default()); + let y = bn.forward_mut(x1.leaky_trace()); + let g = y.square().mean().backward(); + + let mut opt = crate::nn::optim::Sgd::new(&bn, Default::default()); + opt.update(&mut bn, &g).expect(""); + } }