From c96e05dd8d0db98db417074e9b3afd22613ee305 Mon Sep 17 00:00:00 2001 From: Thiago Machado Date: Mon, 6 Nov 2023 10:28:25 -0300 Subject: [PATCH] add forward_mut to Sequential derive --- dfdx-derives/src/lib.rs | 49 +++++++++++++++++++++++++++++++---------- 1 file changed, 37 insertions(+), 12 deletions(-) diff --git a/dfdx-derives/src/lib.rs b/dfdx-derives/src/lib.rs index 60da4982..45e0be50 100644 --- a/dfdx-derives/src/lib.rs +++ b/dfdx-derives/src/lib.rs @@ -489,23 +489,44 @@ pub fn sequential(input: proc_macro::TokenStream) -> proc_macro::TokenStream { }; let impl_module = { - let src = match input.data { + let (src, src_mut) = match input.data { Data::Struct(ref data) => match data.fields { Fields::Named(ref fields) => { - let recurse = fields.named.iter().map(|f| { - let name = &f.ident; - quote_spanned! {f.span()=> self.#name.try_forward(x)? } - }); - quote! { #(let x = #recurse;)* } + let (recurse, recurse_mut) = fields + .named + .iter() + .map(|f| { + let name = &f.ident; + ( + quote_spanned! {f.span()=> self.#name.try_forward(x)? }, + quote_spanned! {f.span()=> self.#name.try_forward_mut(x)? }, + ) + }) + .unzip::<_, _, Vec<_>, Vec<_>>(); + ( + quote! { #(let x = #recurse;)* }, + quote! { #(let x = #recurse_mut;)* }, + ) } Fields::Unnamed(ref fields) => { - let recurse = fields.unnamed.iter().enumerate().map(|(i, f)| { - let index = Index::from(i); - quote_spanned! {f.span()=> self.#index.try_forward(x)? } - }); - quote! { #(let x = #recurse;)* } + let (recurse, recurse_mut) = fields + .unnamed + .iter() + .enumerate() + .map(|(i, f)| { + let index = Index::from(i); + ( + quote_spanned! {f.span()=> self.#index.try_forward(x)? }, + quote_spanned! {f.span()=> self.#index.try_forward_mut(x)? }, + ) + }) + .unzip::<_, _, Vec<_>, Vec<_>>(); + ( + quote! { #(let x = #recurse;)* }, + quote! { #(let x = #recurse;)* }, + ) } - Fields::Unit => quote! { let x = x; }, + Fields::Unit => (quote! { let x = x; }, quote! { let x = x; }), }, _ => unreachable!(), }; @@ -520,6 +541,10 @@ pub fn sequential(input: proc_macro::TokenStream) -> proc_macro::TokenStream { #src Ok(x) } + fn try_forward_mut(&mut self, x: Input) -> Result { + #src_mut + Ok(x) + } } } };