From 20a958df92f6306f966c1b1015b83c9a84ad38e1 Mon Sep 17 00:00:00 2001 From: Thiago Machado Date: Mon, 6 Nov 2023 10:25:39 -0300 Subject: [PATCH 1/3] add test for try_forward_mut (fails) --- dfdx/src/nn/layers/batch_norm2d.rs | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/dfdx/src/nn/layers/batch_norm2d.rs b/dfdx/src/nn/layers/batch_norm2d.rs index c6f592d3..772923d7 100644 --- a/dfdx/src/nn/layers/batch_norm2d.rs +++ b/dfdx/src/nn/layers/batch_norm2d.rs @@ -311,4 +311,22 @@ mod tests { let mut opt = crate::nn::optim::Sgd::new(&bn, Default::default()); opt.update(&mut bn, &g).expect(""); } + + #[derive(Default, Clone, Sequential)] + struct Arch { + pub batch: BatchNorm2DConstConfig<3>, + } + + #[test] + fn test_batchnorm2d_update_with_derive() { + let dev: TestDevice = Default::default(); + + let x1: Tensor, TestDtype, _> = dev.sample_normal(); + let mut bn = dev.build_module::(Arch::default()); + let y = bn.forward_mut(x1.leaky_trace()); + let g = y.square().mean().backward(); + + let mut opt = crate::nn::optim::Sgd::new(&bn, Default::default()); + opt.update(&mut bn, &g).expect(""); + } } From c96e05dd8d0db98db417074e9b3afd22613ee305 Mon Sep 17 00:00:00 2001 From: Thiago Machado Date: Mon, 6 Nov 2023 10:28:25 -0300 Subject: [PATCH 2/3] add forward_mut to Sequential derive --- dfdx-derives/src/lib.rs | 49 +++++++++++++++++++++++++++++++---------- 1 file changed, 37 insertions(+), 12 deletions(-) diff --git a/dfdx-derives/src/lib.rs b/dfdx-derives/src/lib.rs index 60da4982..45e0be50 100644 --- a/dfdx-derives/src/lib.rs +++ b/dfdx-derives/src/lib.rs @@ -489,23 +489,44 @@ pub fn sequential(input: proc_macro::TokenStream) -> proc_macro::TokenStream { }; let impl_module = { - let src = match input.data { + let (src, src_mut) = match input.data { Data::Struct(ref data) => match data.fields { Fields::Named(ref fields) => { - let recurse = fields.named.iter().map(|f| { - let name = &f.ident; - quote_spanned! {f.span()=> self.#name.try_forward(x)? } - }); - quote! { #(let x = #recurse;)* } + let (recurse, recurse_mut) = fields + .named + .iter() + .map(|f| { + let name = &f.ident; + ( + quote_spanned! {f.span()=> self.#name.try_forward(x)? }, + quote_spanned! {f.span()=> self.#name.try_forward_mut(x)? }, + ) + }) + .unzip::<_, _, Vec<_>, Vec<_>>(); + ( + quote! { #(let x = #recurse;)* }, + quote! { #(let x = #recurse_mut;)* }, + ) } Fields::Unnamed(ref fields) => { - let recurse = fields.unnamed.iter().enumerate().map(|(i, f)| { - let index = Index::from(i); - quote_spanned! {f.span()=> self.#index.try_forward(x)? } - }); - quote! { #(let x = #recurse;)* } + let (recurse, recurse_mut) = fields + .unnamed + .iter() + .enumerate() + .map(|(i, f)| { + let index = Index::from(i); + ( + quote_spanned! {f.span()=> self.#index.try_forward(x)? }, + quote_spanned! {f.span()=> self.#index.try_forward_mut(x)? }, + ) + }) + .unzip::<_, _, Vec<_>, Vec<_>>(); + ( + quote! { #(let x = #recurse;)* }, + quote! { #(let x = #recurse;)* }, + ) } - Fields::Unit => quote! { let x = x; }, + Fields::Unit => (quote! { let x = x; }, quote! { let x = x; }), }, _ => unreachable!(), }; @@ -520,6 +541,10 @@ pub fn sequential(input: proc_macro::TokenStream) -> proc_macro::TokenStream { #src Ok(x) } + fn try_forward_mut(&mut self, x: Input) -> Result { + #src_mut + Ok(x) + } } } }; From b298a98eae178ee4a960bad52d196c536f4d6b00 Mon Sep 17 00:00:00 2001 From: Thiago Machado Date: Mon, 6 Nov 2023 10:42:38 -0300 Subject: [PATCH 3/3] also use the _mut for tuple structs --- dfdx-derives/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dfdx-derives/src/lib.rs b/dfdx-derives/src/lib.rs index 45e0be50..13c03c75 100644 --- a/dfdx-derives/src/lib.rs +++ b/dfdx-derives/src/lib.rs @@ -523,7 +523,7 @@ pub fn sequential(input: proc_macro::TokenStream) -> proc_macro::TokenStream { .unzip::<_, _, Vec<_>, Vec<_>>(); ( quote! { #(let x = #recurse;)* }, - quote! { #(let x = #recurse;)* }, + quote! { #(let x = #recurse_mut;)* }, ) } Fields::Unit => (quote! { let x = x; }, quote! { let x = x; }),