From 86cec344193debc98b9f8cb13f915cb76f12f10c Mon Sep 17 00:00:00 2001 From: Thiago Machado Date: Sun, 5 Nov 2023 22:09:02 -0300 Subject: [PATCH 1/7] - Add `#[input_wrapper]`. - Add the heck dep to convert from CamelCase into snake_case. - Add layers. - `Id`, which just forwards the input. - `On`, applies some Module into an input wrapper field. - Contains a test demonstrating its usage. - `Add`, which calls `try_add` for the inputs. --- dfdx-derives/Cargo.toml | 1 + dfdx-derives/src/lib.rs | 454 ++++++++++++++++++++++++++++++++++ dfdx/src/nn/layers/id.rs | 19 ++ dfdx/src/nn/layers/mod.rs | 7 + dfdx/src/nn/layers/on.rs | 132 ++++++++++ dfdx/src/nn/layers/ops/add.rs | 20 ++ dfdx/src/nn/layers/ops/mod.rs | 3 + 7 files changed, 636 insertions(+) create mode 100644 dfdx/src/nn/layers/id.rs create mode 100644 dfdx/src/nn/layers/on.rs create mode 100644 dfdx/src/nn/layers/ops/add.rs create mode 100644 dfdx/src/nn/layers/ops/mod.rs diff --git a/dfdx-derives/Cargo.toml b/dfdx-derives/Cargo.toml index 9941edbd..c142e70f 100644 --- a/dfdx-derives/Cargo.toml +++ b/dfdx-derives/Cargo.toml @@ -13,6 +13,7 @@ proc-macro2 = "1" quote = "1" syn = { version = "2", features = ["extra-traits"] } dfdx-core = { path = "../dfdx-core" } +heck = "0.4.1" [features] nightly = ["dfdx-core/nightly"] diff --git a/dfdx-derives/src/lib.rs b/dfdx-derives/src/lib.rs index 4eca0d82..442fb690 100644 --- a/dfdx-derives/src/lib.rs +++ b/dfdx-derives/src/lib.rs @@ -957,3 +957,457 @@ pub fn load_safetensors(input: proc_macro::TokenStream) -> proc_macro::TokenStre } }) } + +/// Generates a module containing helpful structs and implementations for an input wrapper. 
+/// +/// ## Example +/// +/// The following definition: +/// ```ignore +/// #[input_wrapper] +/// pub struct MyWrapper { +/// pub a: A, +/// pub b: B, +/// } +/// ``` +/// +/// Generates the following module: +/// ```ignore +/// pub mod my_wrapper { +/// // structs for the fields +/// pub struct a; +/// pub struct b; +/// // note: if MyWrapper was a tuple-struct, +/// // the fields would be named _0, _1 and so on +/// +/// // structs to help in tuple conversions (Module impls omitted) +/// pub struct FromTuple; +/// pub struct IntoTuple; +/// +/// // access for the `a` field +/// impl, A, B> Module> for On +/// { +/// type Output = MyWrapper<>::Output, B>; +/// fn try_forward(&self, x: MyWrapper) -> Result {/* (...) */} +/// fn try_forward_mut(&mut self, x: MyWrapper) -> Result {/* (...) */} +/// } +/// +/// // access for the `b` field +/// impl, A, B> Module> for On +/// { +/// type Output = MyWrapper>::Output>; +/// fn try_forward(&self, x: MyWrapper) -> Result {/* ... */} +/// fn try_forward_mut(&mut self, x: MyWrapper) -> Result {/* ... */} +/// } +/// } +/// ``` +/// To better visualize the generated code and items, it's recommended to expand it with Rust-Analyzer, +/// or to generate the project's documentation. +/// +/// Those helpers can then be used as modules: +/// ```ignore +/// #[derive(Default, Clone, Sequential)] +/// pub struct Arch { +/// // (...) +/// +/// // assuming Input is of type (X, Y), converts the input into MyWrapper +/// pub input_to_wrapper: my_wrapper::FromTuple, +/// +/// // apply module T on the field `a`, while also mapping the input into: +/// // MyWrapper<>::Output, Y> +/// pub t: On, +/// +/// // converts the input into a tuple: +/// // (>::Output, Y) +/// pub input_to_tuple: my_wrapper::IntoTuple, +/// +/// // (...) 
+/// } +/// ``` +#[proc_macro_attribute] +pub fn input_wrapper( + _attr: proc_macro::TokenStream, + input: proc_macro::TokenStream, +) -> proc_macro::TokenStream { + let wrapper = parse_macro_input!(input as DeriveInput); + + // - TODO: any bounds on the struct definition probably should be copied into other impls. + // - NOTE: check on how to deal with the `Self` (as this won't refer to the struct on the On Module impl). + + // eg. MyWrapper + let wrapper_ident = wrapper.ident.clone(); + let wrapper_vis = wrapper.vis.clone(); + let wrapper_lowercase = format!("{}", heck::AsSnakeCase(wrapper_ident.to_string())); + // eg. my_wrapper + // TODO: allow renaming + let wrapper_lowercase_ident = syn::Ident::new(&wrapper_lowercase, wrapper_ident.span()); + + // get wrapper field info + // eg. [(pub, Some(my_field), MyFieldType, field span)] + let mut wrapper_fields = vec![]; + match &wrapper.data { + Data::Struct(ref obj) => match obj.fields { + Fields::Named(ref fields) => { + let fields = fields.named.iter().map(|f| { + let ty = &f.ty; + assert_ne!( + quote!(#ty).to_string(), + "M", + "A generic type named `M` is not allowed because this is used internally" + ); + (&f.vis, &f.ident, &f.ty, f.span()) + }); + wrapper_fields.extend(fields) + } + Fields::Unnamed(ref fields) => { + let fields = fields.unnamed.iter().map(|f: &syn::Field| { + let ty = &f.ty; + assert_ne!( + quote!(#ty).to_string(), + "M", + "A generic type named `M` is not allowed because this is used internally" + ); + (&f.vis, &None, &f.ty, f.span()) + }); + wrapper_fields.extend(fields) + } + // no fields + Fields::Unit => {} + }, + Data::Enum(_) => unimplemented!("Input wrapper cannot be derived for enums."), + Data::Union(_) => unimplemented!("Input wrapper cannot be derived for unions."), + }; + + // wrapper fields as structs + let mut wrapper_field_structs_quote = vec![]; + let mut are_fields_named = false; + for (i, (_vis, field, _ty, span)) in wrapper_fields.iter().enumerate() { + let (doc, field) = if 
let Some(field) = field { + are_fields_named = true; + let doc = format!( + "Indicates the [`{}::{}`] field. \nThis field is the `{}` value (`0`-based index).", + wrapper_ident, + field, + i + ); + (doc, field.clone()) + } else { + let doc = format!( + "Indicates the `{}`-th value from [`{}`] (0-based index).", + i, wrapper_ident + ); + let field = syn::Ident::new(&format!("_{}", i), *span); + (doc, field) + }; + wrapper_field_structs_quote.push(quote! { + #[doc = #doc] + #[allow(non_camel_case_types)] + #[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] + pub struct #field; + }); + } + + let imports = if are_fields_named { + quote! { + use super::#wrapper_ident; + } + } else { + quote! { + use super::#wrapper_ident; + // TODO: import tuple stuff + use crate::prelude; + } + }; + + let wrapper_generics = wrapper.generics.clone(); + let wrapper_generics_params = wrapper_generics.params.iter().collect::>(); + // eg. MyWrapper -> [A, B] + let wrapper_generics_param_idents = { + wrapper_generics_params + .iter() + .map(|p| { + use syn::GenericParam::*; + match p { + Lifetime(l) => &l.lifetime.ident, + Type(t) => &t.ident, + Const(c) => &c.ident, + } + }) + .collect::>() + }; + + // eg. MyWrapper -> [A, B] + let wrapper_generic_names = wrapper_generics_param_idents.iter().collect::>(); + + // eg. MyWrapper { field1: A, field2: bool} -> [A, bool] + let field_ty_names = wrapper_fields + .iter() + .map(|(_, _, ty, _)| ty) + .collect::>(); + + // create structs to represent tuple conversions + let tuple_conversion_structs = { + let field_ty_names = field_ty_names + .iter() + .map(|ty| quote! 
{#ty}.to_string()) + .collect::>() + .join(", "); + let wrapper_generic_names = wrapper_generic_names + .iter() + .map(|ident| ident.to_string()) + .collect::>() + .join(", "); + let doc1 = format!( + "Indicates a conversion from a ({}) tuple into a `{}<{}>`.", + &field_ty_names, wrapper_ident, &wrapper_generic_names + ); + let doc2 = format!( + "Indicates a conversion from a `{}<{}>` into a ({}) tuple.", + wrapper_ident, &wrapper_generic_names, &field_ty_names + ); + quote! { + #[doc = #doc1] + #[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, crate::prelude::CustomModule)] + pub struct FromTuple; + #[doc = #doc2] + #[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, crate::prelude::CustomModule)] + pub struct IntoTuple; + } + }; + + // impl From<> conversions + let tuple_conversions = { + let doc1 = format!("Conversion of a tuple into a [`{}`].", wrapper_ident,); + let doc2 = format!("Conversion of a [`{}`] into a tuple.", wrapper_ident,); + + let mut field_from_tuple = vec![]; + let mut field_to_tuple = vec![]; + for (i, (_, ident, _, _span)) in wrapper_fields.iter().enumerate() { + let i = syn::Index::from(i); + if let Some(ident) = ident { + field_from_tuple.push(quote! {#ident: x.#i}); + field_to_tuple.push(quote! {x.#ident}); + } else { + field_from_tuple.push(quote! {x.#i}); + field_to_tuple.push(quote! {x.#i}); + }; + } + + let (from_tuple, to_tuple) = if are_fields_named { + ( + quote! { + #wrapper_ident { + #(#field_from_tuple), * + } + }, + quote! { (#(#field_to_tuple), *) }, + ) + } else { + ( + quote! { + #wrapper_ident ( + #(#field_from_tuple), * + ) + }, + quote! {(#(#field_to_tuple), *)}, + ) + }; + + quote! 
{ + #[doc = #doc1] + impl<#(#wrapper_generic_names), *> From<(#(#field_ty_names), *)> for #wrapper_ident<#(#wrapper_generic_names), *> { + fn from(x: (#(#field_ty_names), *)) -> Self { + #from_tuple + } + } + #[doc = #doc2] + impl<#(#wrapper_generic_names), *> From<#wrapper_ident<#(#wrapper_generic_names), *>> for (#(#field_ty_names), *) { + fn from(x: #wrapper_ident<#(#wrapper_generic_names), *>) -> Self { + #to_tuple + } + } + } + }; + + // impl Module for conversions into and from tuples + let module_conversions = { + let doc1 = format!("Module to convert a tuple into a [`{}`].", wrapper_ident,); + let doc2 = format!("Module to convert a [`{}`] into a tuple.", wrapper_ident,); + quote! { + #[doc = #doc1] + impl<#(#wrapper_generic_names), *> crate::prelude::Module<(#(#field_ty_names), *)> for FromTuple { + type Output = #wrapper_ident<#(#wrapper_generic_names), *>; + fn try_forward(&self, x: (#(#field_ty_names), *)) -> Result { + Ok(x.into()) + } + } + #[doc = #doc2] + impl<#(#wrapper_generic_names), *> crate::prelude::Module<#wrapper_ident<#(#wrapper_generic_names), *>> for IntoTuple { + type Output = (#(#field_ty_names), *); + fn try_forward(&self, x: #wrapper_ident<#(#wrapper_generic_names), *>) -> Result { + Ok(x.into()) + } + } + } + }; + + // assertion + for generic_ident in wrapper_generic_names.iter() { + let count = wrapper_fields + .iter() + .map(|(_, _, field_ty, _)| is_ident_container(generic_ident, field_ty)) + .filter(|contains| *contains) + .count(); + if count > 1 { + panic!("the generic {generic_ident} should be used in at most one field"); + } + } + + // field access modules + let mut field_access_modules = vec![]; + for (i, (_vis, ident, ty, span)) in wrapper_fields.iter().enumerate() { + let (doc, on_acccess, forward) = if let Some(ident) = ident { + let doc = format!( + "Module that access [`{}::{}`] and then applies Module `M` on it.", + wrapper_ident, ident + ); + let on_access = ident.clone(); + let forward = 
syn::Ident::new(&format!("x{i}"), ident.span()); + (doc, on_access, forward) + } else { + let doc = format!( + "Module that access the `{}`-th value from [`{}`] and then applies Module `M` on it.", + i, + wrapper_ident, + ); + let on_access = syn::Ident::new(&format!("_{}", i), *span); + let forward = syn::Ident::new(&format!("x{i}"), *span); + (doc, on_access, forward) + }; + + let mut contains_ident = false; + let output_generics = wrapper_generic_names.iter().map(|ty_ident| { + // + if is_ident_container(ty_ident, ty) { + if contains_ident { + panic!( + "the field {ident:?} at index {i} should contain at most one generic type" + ); + } + contains_ident = true; + quote!(>::Output) + } else { + quote!(#ty_ident) + } + }); + + let mut field_extraction_idents = vec![]; + let mut field_extraction = vec![]; + let mut field_construction = vec![]; + for (i, (_, _ident, _, span)) in wrapper_fields.iter().enumerate() { + let ii = syn::Index::from(i); + if let Some(_ident) = _ident { + let xident = syn::Ident::new(&format!("x{i}"), _ident.span()); + field_extraction_idents.push(xident.clone()); + field_extraction.push(quote! {let #xident = x.#_ident;}); + field_construction.push(quote! {#_ident: #xident,}); + } else { + let xident = syn::Ident::new(&format!("x{i}"), *span); + field_extraction_idents.push(xident.clone()); + field_extraction.push(quote! {let #xident = x.#ii;}); + field_construction.push(quote! {#xident,}); + }; + } + let field_replacement = if are_fields_named { + quote! { + #wrapper_ident { + #(#field_construction)* + } + } + } else { + quote! { + #wrapper_ident ( + #(#field_construction)* + ) + } + }; + + let field_access_module = quote! 
{ + #[doc = #doc] + impl, #(#wrapper_generic_names), *> crate::prelude::Module<#wrapper_ident<#(#wrapper_generic_names), *>> for crate::prelude::On<#on_acccess, M> { + type Output = #wrapper_ident<#(#output_generics), *>; + fn try_forward(&self, x: #wrapper_ident<#(#wrapper_generic_names), *>) -> Result { + #(#field_extraction)* + let #forward = self.t.try_forward(#forward)?; + let x = #field_replacement; + Ok(x) + } + fn try_forward_mut(&mut self, x: #wrapper_ident<#(#wrapper_generic_names), *>) -> Result { + #(#field_extraction)* + let #forward = self.t.try_forward_mut(#forward)?; + let x = #field_replacement; + Ok(x) + } + } + }; + field_access_modules.push(field_access_module); + } + + // all of the generated content + let _mod = quote! { + #wrapper_vis mod #wrapper_lowercase_ident { + #imports + + #(#wrapper_field_structs_quote)* + + #tuple_conversion_structs + + #tuple_conversions + + #module_conversions + + #(#field_access_modules)* + } + }; + let output = quote!( + #wrapper + + #[doc = "Automatically generated by `input_wrapper`. The containing items are visible on your project's documentation."] + #_mod + ); + proc_macro::TokenStream::from(output) +} + +/// Checks whether `ty` contains any ident that matches the `ident`. 
+fn is_ident_container(ident: &syn::Ident, ty: &syn::Type) -> bool { + match ty { + syn::Type::Array(_) => todo!("input_wrapper is_ident_container for array"), + syn::Type::BareFn(_) => todo!("input_wrapper is_ident_container for bare fn"), + syn::Type::Group(_) => todo!("input_wrapper is_ident_container for group"), + syn::Type::ImplTrait(_) => todo!("input_wrapper is_ident_container for impl trait"), + syn::Type::Infer(_) => todo!("input_wrapper is_ident_container for infer"), + syn::Type::Macro(_) => todo!("input_wrapper is_ident_container for macro"), + syn::Type::Never(_) => todo!("input_wrapper is_ident_container for never"), + syn::Type::Paren(_) => todo!("input_wrapper is_ident_container for paren"), + syn::Type::Path(ty) => { + let mut is = false; + if let Some(qself) = &ty.qself { + is |= is_ident_container(ident, &qself.ty); + } + if let Some(segment) = &ty.path.segments.last() { + is |= &segment.ident == ident; + } + is + } + syn::Type::Ptr(_) => todo!("input_wrapper is_ident_container for ptr"), + syn::Type::Reference(_) => todo!("input_wrapper is_ident_container for reference"), + syn::Type::Slice(_) => todo!("input_wrapper is_ident_container for slice"), + syn::Type::TraitObject(_) => todo!("input_wrapper is_ident_container for trait object"), + syn::Type::Tuple(_) => todo!("input_wrapper is_ident_container for tuple"), + syn::Type::Verbatim(_) => todo!("input_wrapper is_ident_container for verbatim"), + other => unimplemented!( + "input_wrapper is_ident_container not implemented for {}", + quote!(#other).to_string() + ), + } +} diff --git a/dfdx/src/nn/layers/id.rs b/dfdx/src/nn/layers/id.rs new file mode 100644 index 00000000..e950e551 --- /dev/null +++ b/dfdx/src/nn/layers/id.rs @@ -0,0 +1,19 @@ +use crate::prelude::*; + +/// Forwards the input to the output. 
+#[derive(Default, Debug, Clone, Copy, CustomModule)] +pub struct Id; + +impl, T: Tape> Module> for Id { + type Output = Tensor; + fn try_forward(&self, x: Tensor) -> Result { + Ok(x) + } +} + +pub type Id1 = (Id,); +pub type Id2 = (Id, Id); +pub type Id3 = (Id, Id, Id); +pub type Id4 = (Id, Id, Id, Id); +pub type Id5 = (Id, Id, Id, Id, Id); +pub type Id6 = (Id, Id, Id, Id, Id, Id); diff --git a/dfdx/src/nn/layers/mod.rs b/dfdx/src/nn/layers/mod.rs index 828b1e97..a3f50da3 100644 --- a/dfdx/src/nn/layers/mod.rs +++ b/dfdx/src/nn/layers/mod.rs @@ -1,3 +1,5 @@ +pub mod ops; + mod abs; mod add_into; mod batch_norm1d; @@ -19,6 +21,8 @@ mod flatten2d; mod gelu; mod generalized_add; mod generalized_mul; +pub mod id; +mod input_into; mod layer_norm1d; mod leaky_relu; mod linear; @@ -26,6 +30,7 @@ mod ln; mod log_softmax; mod matmul; mod multi_head_attention; +mod on; #[cfg(feature = "nightly")] mod pool_2d_avg; #[cfg(feature = "nightly")] @@ -72,6 +77,7 @@ pub use flatten2d::Flatten2D; pub use gelu::{AccurateGeLU, FastGeLU}; pub use generalized_add::GeneralizedAdd; pub use generalized_mul::GeneralizedMul; +pub use id::Id; pub use layer_norm1d::{LayerNorm1D, LayerNorm1DConfig, LayerNorm1DConstConfig}; pub use leaky_relu::LeakyReLU; pub use linear::{Linear, LinearConfig, LinearConstConfig}; @@ -79,6 +85,7 @@ pub use ln::Ln; pub use log_softmax::LogSoftmax; pub use matmul::{MatMul, MatMulConfig, MatMulConstConfig}; pub use multi_head_attention::{MultiHeadAttention, MultiHeadAttentionConfig}; +pub use on::On; #[cfg(feature = "nightly")] pub use pool_2d_avg::{AvgPool2D, AvgPool2DConst}; #[cfg(feature = "nightly")] diff --git a/dfdx/src/nn/layers/on.rs b/dfdx/src/nn/layers/on.rs new file mode 100644 index 00000000..eedec78c --- /dev/null +++ b/dfdx/src/nn/layers/on.rs @@ -0,0 +1,132 @@ +use crate::prelude::*; +use std::marker::PhantomData; + +// TODO: try making a Call module, whih allows calling an arbitrary method on the input. 
+ +/// Access the input that is stored in a wrapper structure. +#[derive( + Debug, Default, Clone, ResetParams, ZeroGrads, UpdateParams, LoadSafeTensors, SaveSafeTensors, +)] +#[repr(transparent)] +pub struct On { + #[module] + #[serialize] + pub t: T, + + pub _n: PhantomData, +} + +impl, N: Clone + std::fmt::Debug, T: BuildOnDevice> BuildOnDevice + for On +{ + type Built = On; + fn try_build_on_device(&self, device: &D) -> Result { + let t = self.t.try_build_on_device(device)?; + Ok(On { t, _n: PhantomData }) + } +} + +// TODO: define On access for standard tuples, +// so that it's possible to access them with something like: +// On +pub mod tuple {} + +// cargo 'test' '--package' 'dfdx' '--lib' '--' 'nn::layers::on::tests' '--nocapture' +// test based on nn/layers/residual_add.rs +#[cfg(test)] +mod tests { + use super::*; + use crate::tests::*; + + #[input_wrapper] + pub struct MyWrapper { + pub a: A, + pub b: B, + } + + + #[input_wrapper] + pub struct Split1 { + pub forward: Forward, + pub skip: Skip, + } + + #[derive(Default, Clone, Sequential)] + pub struct ResidualAdd1 { + // input is Input + pub split: SplitInto<(Id, Id)>, + + // input is (Input, Input) + pub input_to_wrapper: split1::FromTuple, + + // input is Split1 { Input, Input } + pub t: On, + + // input is Split1 { T::Output, Input } + pub input_to_tuple: split1::IntoTuple, + + // input is (T::Output, Input) + pub add: ops::Add, + // input is T::Output = Input + } + + #[test] + fn test_residual_add_backward() { + let dev: TestDevice = Default::default(); + + let model = dev.build_module::(>>::default()); + + let x: Tensor, f32, _> = dev.sample_normal(); + let x = x.to_dtype::(); + let y = model.forward(x.leaky_trace()); + + #[rustfmt::skip] + assert_close_to_literal!(y, [[0.25372928, -2.4258814],[1.7892148, -2.6242268],[1.5131638, 0.23407778],[3.4201493, 1.597525]]); + + let g = y.mean().backward(); + assert_close_to_literal!(g.get(&model.t.t.weight), [[0.475242, -0.075136]; 2]); + 
assert_close_to_literal!(g.get(&model.t.t.bias), [0.5; 2]); + assert_close_to_literal!(g.get(&x), [[0.18806472, 0.21419683]; 4]); + } + + #[input_wrapper] + pub struct Split2(Forward, Skip); + + #[derive(Default, Clone, Sequential)] + pub struct ResidualAdd2 { + // input is Input + pub split: SplitInto<(Id, Id)>, + + // input is (Input, Input) + pub input_to_wrapper: split2::FromTuple, + + // input is Split2 ( Input, Input ) + pub t: On, + + // input is Split2 ( T::Output, Input ) + pub input_to_tuple: split2::IntoTuple, + + // input is (T::Output, Input) + pub add: ops::Add, + // input is T::Output = Input + } + + #[test] + fn test_residual_add_backward2() { + let dev: TestDevice = Default::default(); + + let model = dev.build_module::(>>::default()); + + let x: Tensor, f32, _> = dev.sample_normal(); + let x = x.to_dtype::(); + let y = model.forward(x.leaky_trace()); + + #[rustfmt::skip] + assert_close_to_literal!(y, [[0.25372928, -2.4258814],[1.7892148, -2.6242268],[1.5131638, 0.23407778],[3.4201493, 1.597525]]); + + let g = y.mean().backward(); + assert_close_to_literal!(g.get(&model.t.t.weight), [[0.475242, -0.075136]; 2]); + assert_close_to_literal!(g.get(&model.t.t.bias), [0.5; 2]); + assert_close_to_literal!(g.get(&x), [[0.18806472, 0.21419683]; 4]); + } +} diff --git a/dfdx/src/nn/layers/ops/add.rs b/dfdx/src/nn/layers/ops/add.rs new file mode 100644 index 00000000..fcf40142 --- /dev/null +++ b/dfdx/src/nn/layers/ops/add.rs @@ -0,0 +1,20 @@ +use crate::prelude::*; + +/// Calls [crate::tensor_ops::add()] +#[derive(Default, Debug, Clone, Copy, crate::nn::CustomModule)] +pub struct Add; + +// TODO: macro for more tuples +// TODO: lower the requirement, as long as one of the values can be broadcast into the other one +// TODO: check if this works for constants + +impl Module<(Input, Input)> for Add +where + Input: TryAdd, +{ + type Output = ::Output; + + fn try_forward(&self, x: (Input, Input)) -> Result { + x.0.try_add(x.1) + } +} diff --git 
a/dfdx/src/nn/layers/ops/mod.rs b/dfdx/src/nn/layers/ops/mod.rs new file mode 100644 index 00000000..6ecd1204 --- /dev/null +++ b/dfdx/src/nn/layers/ops/mod.rs @@ -0,0 +1,3 @@ +mod add; + +pub use add::Add; From 2939d1025d845dcb2ff3aaeb43c0c0a5f8fcf557 Mon Sep 17 00:00:00 2001 From: Thiago Machado Date: Sun, 5 Nov 2023 22:28:36 -0300 Subject: [PATCH 2/7] rm missing mod --- dfdx/src/nn/layers/mod.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/dfdx/src/nn/layers/mod.rs b/dfdx/src/nn/layers/mod.rs index a3f50da3..e0dd49c5 100644 --- a/dfdx/src/nn/layers/mod.rs +++ b/dfdx/src/nn/layers/mod.rs @@ -22,7 +22,6 @@ mod gelu; mod generalized_add; mod generalized_mul; pub mod id; -mod input_into; mod layer_norm1d; mod leaky_relu; mod linear; From d04636ddfb4d6942db469b8d842bb7f741ed3701 Mon Sep 17 00:00:00 2001 From: Thiago Machado Date: Sun, 5 Nov 2023 22:50:29 -0300 Subject: [PATCH 3/7] rm newline; update test dtype --- dfdx/src/nn/layers/on.rs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/dfdx/src/nn/layers/on.rs b/dfdx/src/nn/layers/on.rs index eedec78c..522d1e49 100644 --- a/dfdx/src/nn/layers/on.rs +++ b/dfdx/src/nn/layers/on.rs @@ -44,7 +44,6 @@ mod tests { pub b: B, } - #[input_wrapper] pub struct Split1 { pub forward: Forward, @@ -74,7 +73,8 @@ mod tests { fn test_residual_add_backward() { let dev: TestDevice = Default::default(); - let model = dev.build_module::(>>::default()); + let model = + dev.build_module::(>>::default()); let x: Tensor, f32, _> = dev.sample_normal(); let x = x.to_dtype::(); @@ -115,7 +115,8 @@ mod tests { fn test_residual_add_backward2() { let dev: TestDevice = Default::default(); - let model = dev.build_module::(>>::default()); + let model = + dev.build_module::(>>::default()); let x: Tensor, f32, _> = dev.sample_normal(); let x = x.to_dtype::(); From 9db1466b13fc04b253dcb822ed8bc54530962cde Mon Sep 17 00:00:00 2001 From: Thiago Machado Date: Sun, 5 Nov 2023 23:14:00 -0300 Subject: [PATCH 4/7] fix test for f64 
--- dfdx/src/nn/layers/on.rs | 38 ++++++++++++++++++++++++++++---------- 1 file changed, 28 insertions(+), 10 deletions(-) diff --git a/dfdx/src/nn/layers/on.rs b/dfdx/src/nn/layers/on.rs index 522d1e49..d899f6fc 100644 --- a/dfdx/src/nn/layers/on.rs +++ b/dfdx/src/nn/layers/on.rs @@ -38,12 +38,6 @@ mod tests { use super::*; use crate::tests::*; - #[input_wrapper] - pub struct MyWrapper { - pub a: A, - pub b: B, - } - #[input_wrapper] pub struct Split1 { pub forward: Forward, @@ -73,8 +67,20 @@ mod tests { fn test_residual_add_backward() { let dev: TestDevice = Default::default(); - let model = - dev.build_module::(>>::default()); + let model = dev.build_module::(>>::default()); + let model = DeviceResidualAdd1::, TestDtype, TestDevice> { + t: On { + t: Linear { + weight: model.t.t.weight.to_dtype::(), + bias: model.t.t.bias.to_dtype::(), + }, + _n: Default::default(), + }, + add: Default::default(), + input_to_tuple: Default::default(), + input_to_wrapper: Default::default(), + split: Default::default(), + }; let x: Tensor, f32, _> = dev.sample_normal(); let x = x.to_dtype::(); @@ -115,8 +121,20 @@ mod tests { fn test_residual_add_backward2() { let dev: TestDevice = Default::default(); - let model = - dev.build_module::(>>::default()); + let model = dev.build_module::(>>::default()); + let model = DeviceResidualAdd2::, TestDtype, TestDevice> { + t: On { + t: Linear { + weight: model.t.t.weight.to_dtype::(), + bias: model.t.t.bias.to_dtype::(), + }, + _n: Default::default(), + }, + add: Default::default(), + input_to_tuple: Default::default(), + input_to_wrapper: Default::default(), + split: Default::default(), + }; let x: Tensor, f32, _> = dev.sample_normal(); let x = x.to_dtype::(); From 28b504903f51d1f1ef4ff9d3d204bfceb2b9e733 Mon Sep 17 00:00:00 2001 From: Thiago Machado Date: Mon, 6 Nov 2023 10:03:00 -0300 Subject: [PATCH 5/7] generalize Id; test renaming --- dfdx/src/nn/layers/id.rs | 6 +++--- dfdx/src/nn/layers/on.rs | 14 +++++++------- 2 files changed, 10 
insertions(+), 10 deletions(-) diff --git a/dfdx/src/nn/layers/id.rs b/dfdx/src/nn/layers/id.rs index e950e551..b209b693 100644 --- a/dfdx/src/nn/layers/id.rs +++ b/dfdx/src/nn/layers/id.rs @@ -4,9 +4,9 @@ use crate::prelude::*; #[derive(Default, Debug, Clone, Copy, CustomModule)] pub struct Id; -impl, T: Tape> Module> for Id { - type Output = Tensor; - fn try_forward(&self, x: Tensor) -> Result { +impl Module for Id { + type Output = Input; + fn try_forward(&self, x: Input) -> Result { Ok(x) } } diff --git a/dfdx/src/nn/layers/on.rs b/dfdx/src/nn/layers/on.rs index d899f6fc..9d0465cb 100644 --- a/dfdx/src/nn/layers/on.rs +++ b/dfdx/src/nn/layers/on.rs @@ -3,7 +3,7 @@ use std::marker::PhantomData; // TODO: try making a Call module, whih allows calling an arbitrary method on the input. -/// Access the input that is stored in a wrapper structure. +/// Applies module `T` into an input field from a wrapper. #[derive( Debug, Default, Clone, ResetParams, ZeroGrads, UpdateParams, LoadSafeTensors, SaveSafeTensors, )] @@ -52,10 +52,10 @@ mod tests { // input is (Input, Input) pub input_to_wrapper: split1::FromTuple, - // input is Split1 { Input, Input } + // input is Split1 pub t: On, - // input is Split1 { T::Output, Input } + // input is Split1 pub input_to_tuple: split1::IntoTuple, // input is (T::Output, Input) @@ -64,7 +64,7 @@ mod tests { } #[test] - fn test_residual_add_backward() { + fn test_input_wrapper_struct() { let dev: TestDevice = Default::default(); let model = dev.build_module::(>>::default()); @@ -106,10 +106,10 @@ mod tests { // input is (Input, Input) pub input_to_wrapper: split2::FromTuple, - // input is Split2 ( Input, Input ) + // input is Split2 pub t: On, - // input is Split2 ( T::Output, Input ) + // input is Split2 pub input_to_tuple: split2::IntoTuple, // input is (T::Output, Input) @@ -118,7 +118,7 @@ mod tests { } #[test] - fn test_residual_add_backward2() { + fn test_input_wrapper_tuple_struct() { let dev: TestDevice = Default::default(); let 
model = dev.build_module::(>>::default()); From 82c314b553c57c4a5927627e341138d9f18c23d5 Mon Sep 17 00:00:00 2001 From: Thiago Machado Date: Mon, 4 Dec 2023 17:57:24 -0500 Subject: [PATCH 6/7] updates to current main --- dfdx-derives/src/lib.rs | 20 ++++++++++---------- dfdx/src/nn/layers/on.rs | 7 +++---- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/dfdx-derives/src/lib.rs b/dfdx-derives/src/lib.rs index 442fb690..c0782fdb 100644 --- a/dfdx-derives/src/lib.rs +++ b/dfdx-derives/src/lib.rs @@ -1167,10 +1167,10 @@ pub fn input_wrapper( ); quote! { #[doc = #doc1] - #[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, crate::prelude::CustomModule)] + #[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, ::dfdx::CustomModule)] pub struct FromTuple; #[doc = #doc2] - #[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, crate::prelude::CustomModule)] + #[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, ::dfdx::prelude::CustomModule)] pub struct IntoTuple; } }; @@ -1235,16 +1235,16 @@ pub fn input_wrapper( let doc2 = format!("Module to convert a [`{}`] into a tuple.", wrapper_ident,); quote! 
{ #[doc = #doc1] - impl<#(#wrapper_generic_names), *> crate::prelude::Module<(#(#field_ty_names), *)> for FromTuple { + impl<#(#wrapper_generic_names), *> ::dfdx::prelude::Module<(#(#field_ty_names), *)> for FromTuple { type Output = #wrapper_ident<#(#wrapper_generic_names), *>; - fn try_forward(&self, x: (#(#field_ty_names), *)) -> Result { + fn try_forward(&self, x: (#(#field_ty_names), *)) -> Result { Ok(x.into()) } } #[doc = #doc2] - impl<#(#wrapper_generic_names), *> crate::prelude::Module<#wrapper_ident<#(#wrapper_generic_names), *>> for IntoTuple { + impl<#(#wrapper_generic_names), *> ::dfdx::prelude::Module<#wrapper_ident<#(#wrapper_generic_names), *>> for IntoTuple { type Output = (#(#field_ty_names), *); - fn try_forward(&self, x: #wrapper_ident<#(#wrapper_generic_names), *>) -> Result { + fn try_forward(&self, x: #wrapper_ident<#(#wrapper_generic_names), *>) -> Result { Ok(x.into()) } } @@ -1295,7 +1295,7 @@ pub fn input_wrapper( ); } contains_ident = true; - quote!(>::Output) + quote!(>::Output) } else { quote!(#ty_ident) } @@ -1334,15 +1334,15 @@ pub fn input_wrapper( let field_access_module = quote! 
{ #[doc = #doc] - impl, #(#wrapper_generic_names), *> crate::prelude::Module<#wrapper_ident<#(#wrapper_generic_names), *>> for crate::prelude::On<#on_acccess, M> { + impl, #(#wrapper_generic_names), *> ::dfdx::prelude::Module<#wrapper_ident<#(#wrapper_generic_names), *>> for ::dfdx::prelude::On<#on_acccess, M> { type Output = #wrapper_ident<#(#output_generics), *>; - fn try_forward(&self, x: #wrapper_ident<#(#wrapper_generic_names), *>) -> Result { + fn try_forward(&self, x: #wrapper_ident<#(#wrapper_generic_names), *>) -> Result { #(#field_extraction)* let #forward = self.t.try_forward(#forward)?; let x = #field_replacement; Ok(x) } - fn try_forward_mut(&mut self, x: #wrapper_ident<#(#wrapper_generic_names), *>) -> Result { + fn try_forward_mut(&mut self, x: #wrapper_ident<#(#wrapper_generic_names), *>) -> Result { #(#field_extraction)* let #forward = self.t.try_forward_mut(#forward)?; let x = #field_replacement; diff --git a/dfdx/src/nn/layers/on.rs b/dfdx/src/nn/layers/on.rs index 9d0465cb..6396d70c 100644 --- a/dfdx/src/nn/layers/on.rs +++ b/dfdx/src/nn/layers/on.rs @@ -4,13 +4,12 @@ use std::marker::PhantomData; // TODO: try making a Call module, whih allows calling an arbitrary method on the input. /// Applies module `T` into an input field from a wrapper. 
-#[derive( - Debug, Default, Clone, ResetParams, ZeroGrads, UpdateParams, LoadSafeTensors, SaveSafeTensors, -)] +#[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))] +#[derive(Debug, Default, Clone, ResetParams, ZeroGrads, UpdateParams)] #[repr(transparent)] pub struct On { #[module] - #[serialize] + #[cfg_attr(feature = "safetensors", serialize)] pub t: T, pub _n: PhantomData, From c96938a8ce7088fbc74ecd4c45e83ba7e47be1c2 Mon Sep 17 00:00:00 2001 From: Thiago Machado Date: Tue, 12 Dec 2023 08:19:41 -0500 Subject: [PATCH 7/7] rename local variables for input_wrapper --- dfdx-derives/src/lib.rs | 48 ++++++++++++++++++++--------------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/dfdx-derives/src/lib.rs b/dfdx-derives/src/lib.rs index c0782fdb..6d0edd11 100644 --- a/dfdx-derives/src/lib.rs +++ b/dfdx-derives/src/lib.rs @@ -1109,11 +1109,11 @@ pub fn input_wrapper( let imports = if are_fields_named { quote! { - use super::#wrapper_ident; + use super::*; } } else { quote! { - use super::#wrapper_ident; + use super::*; // TODO: import tuple stuff use crate::prelude; } @@ -1185,11 +1185,11 @@ pub fn input_wrapper( for (i, (_, ident, _, _span)) in wrapper_fields.iter().enumerate() { let i = syn::Index::from(i); if let Some(ident) = ident { - field_from_tuple.push(quote! {#ident: x.#i}); - field_to_tuple.push(quote! {x.#ident}); + field_from_tuple.push(quote! {#ident: __x.#i}); + field_to_tuple.push(quote! {__x.#ident}); } else { - field_from_tuple.push(quote! {x.#i}); - field_to_tuple.push(quote! {x.#i}); + field_from_tuple.push(quote! {__x.#i}); + field_to_tuple.push(quote! {__x.#i}); }; } @@ -1216,13 +1216,13 @@ pub fn input_wrapper( quote! 
{ #[doc = #doc1] impl<#(#wrapper_generic_names), *> From<(#(#field_ty_names), *)> for #wrapper_ident<#(#wrapper_generic_names), *> { - fn from(x: (#(#field_ty_names), *)) -> Self { + fn from(__x: (#(#field_ty_names), *)) -> Self { #from_tuple } } #[doc = #doc2] impl<#(#wrapper_generic_names), *> From<#wrapper_ident<#(#wrapper_generic_names), *>> for (#(#field_ty_names), *) { - fn from(x: #wrapper_ident<#(#wrapper_generic_names), *>) -> Self { + fn from(__x: #wrapper_ident<#(#wrapper_generic_names), *>) -> Self { #to_tuple } } @@ -1237,15 +1237,15 @@ pub fn input_wrapper( #[doc = #doc1] impl<#(#wrapper_generic_names), *> ::dfdx::prelude::Module<(#(#field_ty_names), *)> for FromTuple { type Output = #wrapper_ident<#(#wrapper_generic_names), *>; - fn try_forward(&self, x: (#(#field_ty_names), *)) -> Result { - Ok(x.into()) + fn try_forward(&self, __x: (#(#field_ty_names), *)) -> Result { + Ok(__x.into()) } } #[doc = #doc2] impl<#(#wrapper_generic_names), *> ::dfdx::prelude::Module<#wrapper_ident<#(#wrapper_generic_names), *>> for IntoTuple { type Output = (#(#field_ty_names), *); - fn try_forward(&self, x: #wrapper_ident<#(#wrapper_generic_names), *>) -> Result { - Ok(x.into()) + fn try_forward(&self, __x: #wrapper_ident<#(#wrapper_generic_names), *>) -> Result { + Ok(__x.into()) } } } @@ -1272,7 +1272,7 @@ pub fn input_wrapper( wrapper_ident, ident ); let on_access = ident.clone(); - let forward = syn::Ident::new(&format!("x{i}"), ident.span()); + let forward = syn::Ident::new(&format!("__x{i}"), ident.span()); (doc, on_access, forward) } else { let doc = format!( @@ -1281,7 +1281,7 @@ pub fn input_wrapper( wrapper_ident, ); let on_access = syn::Ident::new(&format!("_{}", i), *span); - let forward = syn::Ident::new(&format!("x{i}"), *span); + let forward = syn::Ident::new(&format!("__x{i}"), *span); (doc, on_access, forward) }; @@ -1307,14 +1307,14 @@ pub fn input_wrapper( for (i, (_, _ident, _, span)) in wrapper_fields.iter().enumerate() { let ii = 
syn::Index::from(i); if let Some(_ident) = _ident { - let xident = syn::Ident::new(&format!("x{i}"), _ident.span()); + let xident = syn::Ident::new(&format!("__x{i}"), _ident.span()); field_extraction_idents.push(xident.clone()); - field_extraction.push(quote! {let #xident = x.#_ident;}); + field_extraction.push(quote! {let #xident = __x.#_ident;}); field_construction.push(quote! {#_ident: #xident,}); } else { - let xident = syn::Ident::new(&format!("x{i}"), *span); + let xident = syn::Ident::new(&format!("__x{i}"), *span); field_extraction_idents.push(xident.clone()); - field_extraction.push(quote! {let #xident = x.#ii;}); + field_extraction.push(quote! {let #xident = __x.#ii;}); field_construction.push(quote! {#xident,}); }; } @@ -1336,17 +1336,17 @@ pub fn input_wrapper( #[doc = #doc] impl, #(#wrapper_generic_names), *> ::dfdx::prelude::Module<#wrapper_ident<#(#wrapper_generic_names), *>> for ::dfdx::prelude::On<#on_acccess, M> { type Output = #wrapper_ident<#(#output_generics), *>; - fn try_forward(&self, x: #wrapper_ident<#(#wrapper_generic_names), *>) -> Result { + fn try_forward(&self, __x: #wrapper_ident<#(#wrapper_generic_names), *>) -> Result { #(#field_extraction)* let #forward = self.t.try_forward(#forward)?; - let x = #field_replacement; - Ok(x) + let __x = #field_replacement; + Ok(__x) } - fn try_forward_mut(&mut self, x: #wrapper_ident<#(#wrapper_generic_names), *>) -> Result { + fn try_forward_mut(&mut self, __x: #wrapper_ident<#(#wrapper_generic_names), *>) -> Result { #(#field_extraction)* let #forward = self.t.try_forward_mut(#forward)?; - let x = #field_replacement; - Ok(x) + let __x = #field_replacement; + Ok(__x) } } };