Skip to content

Commit

Permalink
Documneting CustomModule
Browse files Browse the repository at this point in the history
  • Loading branch information
coreylowman committed Aug 30, 2023
1 parent 275afae commit 1e61b9a
Showing 1 changed file with 75 additions and 0 deletions.
75 changes: 75 additions & 0 deletions dfdx-nn-derives/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,81 @@ macro_rules! has_attr {
};
}

/// Allows you to implement [dfdx_nn::Module], while automatically implementing the following:
/// 1. [dfdx_nn::BuildOnDevice]
/// 2. [dfdx_nn::ResetParams]
/// 3. [dfdx_nn::UpdateParams]
/// 4. [dfdx_nn::ZeroGrads]
/// 5. [dfdx_nn::SaveSafeTensors]
/// 6. [dfdx_nn::LoadSafeTensors]
///
/// If your struct contains sub module configs, then you must add the `#[module]` attribute to those items. Any field that is marked with `#[module]` will be expected to implement [dfdx_nn::BuildOnDevice].
///
/// You can control the name of the built struct with the `#[built(<type name>)]` attribute on the struct.
///
/// # Using CustomModule on unit structs
///
/// Here we have a unit struct that just calls a method on Tensor in the forward:
///
/// ```rust
/// # use dfdx_nn::*;
/// # use dfdx::prelude::*;
/// #[derive(Default, Debug, Clone, Copy, CustomModule)]
/// pub struct Abs;
/// impl<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>> Module<Tensor<S, E, D, T>> for Abs {
/// type Output = Tensor<S, E, D, T>;
/// type Error = D::Err;
/// fn try_forward(&self, x: Tensor<S, E, D, T>) -> Result<Self::Output, Self::Error> {
/// x.try_abs()
/// }
/// }
/// ```
///
/// # Using CustomModule on structs with non-parameter fields
///
/// ```rust
/// # use dfdx_nn::*;
/// # use dfdx::prelude::*;
/// #[derive(Default, Debug, Clone, Copy, CustomModule)]
/// pub struct Reshape<S: Shape>(pub S);
///
/// impl<Src: Shape, Dst: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>> Module<Tensor<Src, E, D, T>>
/// for Reshape<Dst>
/// {
/// type Output = Tensor<Dst, E, D, T>;
/// type Error = D::Err;
/// fn try_forward(&self, x: Tensor<Src, E, D, T>) -> Result<Self::Output, Self::Error> {
/// x.try_reshape_like(&self.0)
/// }
/// }
/// ```
///
/// # Using CustomModule on structs with sub modules
///
/// Here there are a couple things to note:
/// 1. We must use the `#[built(<type name>)]` to control the name of the actual built struct
/// 2. We must use that type name when implementing `Module`
/// 3. We must annotate the sub module with `#[module]`
///
/// ```rust
/// # use dfdx_nn::*;
/// # use dfdx::prelude::*;
/// #[derive(Debug, Clone, CustomModule)]
/// #[built(ResidualMatMul)]
/// pub struct ResidualMatMulConfig<I: Dim, O: Dim>(#[module] pub matmul: MatMulConfig<I, O>);
///
/// impl<X: WithEmptyTape, I: Dim, O: Dim, E: Dtype, D: Device<E>> Module<X> for ResidualMatMul<I, O, E, D>
/// where
/// MatMul<I, O, E, D>: Module<X, Output = X>,
/// X: TryAdd<X, Output = X>,
/// {
/// type Output = X;
/// type Error = D::Err;
/// fn try_forward(&self, x: X) -> Result<Self::Output, Self::Error> {
/// self.matmul.try_forward(x.with_empty_tape())?.try_add(x)
/// }
/// }
/// ```
#[proc_macro_derive(CustomModule, attributes(module, built))]
pub fn custom_module(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let input = parse_macro_input!(input as DeriveInput);
Expand Down

0 comments on commit 1e61b9a

Please sign in to comment.