diff --git a/dfdx-nn-derives/src/lib.rs b/dfdx-nn-derives/src/lib.rs index b30e4833..effe281b 100644 --- a/dfdx-nn-derives/src/lib.rs +++ b/dfdx-nn-derives/src/lib.rs @@ -8,6 +8,81 @@ macro_rules! has_attr { }; } +/// Allows you to implement [dfdx_nn::Module], while automatically implementing the following: +/// 1. [dfdx_nn::BuildOnDevice] +/// 2. [dfdx_nn::ResetParams] +/// 3. [dfdx_nn::UpdateParams] +/// 4. [dfdx_nn::ZeroGrads] +/// 5. [dfdx_nn::SaveSafeTensors] +/// 6. [dfdx_nn::LoadSafeTensors] +/// +/// If your struct contains sub module configs, then you must add the `#[module]` attribute to those items. Any field that is marked with `#[module]` will be expected to implement [dfdx_nn::BuildOnDevice]. +/// +/// You can control the name of the built struct with the `#[built()]` attribute on the struct. +/// +/// # Using CustomModule on unit structs +/// +/// Here we have a unit struct that just calls a method on Tensor in its forward pass: +/// +/// ```rust +/// # use dfdx_nn::*; +/// # use dfdx::prelude::*; +/// #[derive(Default, Debug, Clone, Copy, CustomModule)] +/// pub struct Abs; +/// impl<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>> Module<Tensor<S, E, D, T>> for Abs { +/// type Output = Tensor<S, E, D, T>; +/// type Error = D::Err; +/// fn try_forward(&self, x: Tensor<S, E, D, T>) -> Result<Self::Output, Self::Error> { +/// x.try_abs() +/// } +/// } +/// ``` +/// +/// # Using CustomModule on structs with non-parameter fields +/// +/// ```rust +/// # use dfdx_nn::*; +/// # use dfdx::prelude::*; +/// #[derive(Default, Debug, Clone, Copy, CustomModule)] +/// pub struct Reshape<S: Shape>(pub S); +/// +/// impl<Src: Shape, Dst: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>> Module<Tensor<Src, E, D, T>> +/// for Reshape<Dst> +/// { +/// type Output = Tensor<Dst, E, D, T>; +/// type Error = D::Err; +/// fn try_forward(&self, x: Tensor<Src, E, D, T>) -> Result<Self::Output, Self::Error> { +/// x.try_reshape_like(&self.0) +/// } +/// } +/// ``` +/// +/// # Using CustomModule on structs with sub modules +/// +/// Here there are a couple of things to note: +/// 1. We must use the `#[built()]` attribute to control the name of the actual built struct +/// 2. We must use that type name when implementing `Module` +/// 3. 
We must annotate the sub module with `#[module]` +/// +/// ```rust +/// # use dfdx_nn::*; +/// # use dfdx::prelude::*; +/// #[derive(Debug, Clone, CustomModule)] +/// #[built(ResidualMatMul)] +/// pub struct ResidualMatMulConfig<I: Dim, O: Dim> { #[module] pub matmul: MatMulConfig<I, O> } +/// +/// impl<X, I: Dim, O: Dim, E: Dtype, D: Device<E>> Module<X> for ResidualMatMul<I, O, E, D> +/// where +/// MatMul<I, O, E, D>: Module<X, Output = X, Error = D::Err>, +/// X: TryAdd, +/// { +/// type Output = X; +/// type Error = D::Err; +/// fn try_forward(&self, x: X) -> Result<Self::Output, Self::Error> { +/// self.matmul.try_forward(x.with_empty_tape())?.try_add(x) +/// } +/// } +/// ``` #[proc_macro_derive(CustomModule, attributes(module, built))] pub fn custom_module(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let input = parse_macro_input!(input as DeriveInput);