
[Breaking] Rewrite of nn to enable runtime layer sizes, proc macro declarations, and more #854

Merged: 62 commits into main from nn-rewrite on Oct 25, 2023
Conversation

@coreylowman (Owner) commented on Aug 18, 2023

Summary

This PR is a rewrite of the nn layer to make two things possible:

  1. Create networks that have both compile time known shapes and runtime known shapes.
  2. Create networks that are structs instead of tuples. This makes error messages easier to read and fields easier to access.

Here's an example of both of these in action:

```rust
#[derive(Default, Clone, Sequential)]
#[built(Mlp)]
pub struct MlpConfig {
    pub l1: LinearConfig<Const<3>, usize>,
    pub act1: ReLU,
    pub l2: LinearConfig<usize, Const<10>>,
    pub act2: ReLU,
}
```

Here we define the MLP to have input and output sizes known at compile time, while the interior hidden dimension is only known at runtime. Since the struct has #[derive(Sequential)], the layers are executed in order of declaration.

Also notice the #[built(Mlp)] attribute, which indicates the name of the new type defined alongside this struct (the type that contains the actual modules).
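For intuition, here is a rough sketch of the kind of struct the derive could generate for #[built(Mlp)]; the field types are assumptions inferred from the config above, not the macro's actual output:

```rust
// Hypothetical sketch of the generated built type; the real macro output may differ.
// Each Config field becomes the corresponding built module, parameterized by dtype and device.
pub struct Mlp<E: Dtype, D: Device<E>> {
    pub l1: Linear<Const<3>, usize, E, D>,
    pub act1: ReLU,
    pub l2: Linear<usize, Const<10>, E, D>,
    pub act2: ReLU,
}
```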

Instantiating this object is also pretty straightforward:

```rust
// NOTE: if this was all compile time, we could just do `Default::default()`
let structure = MlpConfig {
    l1: LinearConfig::new(Const, 5),
    act1: Default::default(),
    l2: LinearConfig::new(5, Const),
    act2: Default::default(),
};
let module: Mlp<f32, Cpu> = dev.build_module_ext::<f32>(structure);
```

Note that you now have to instantiate the architecture as an object, instead of it being a type. This is to support runtime values.
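Once built, the module should be usable like any other layer. A minimal usage sketch, assuming the existing Module::forward API carries over and using the hidden size of 5 from above:

```rust
// Sketch: forward pass through the built Mlp (API assumed, not taken from this PR).
let x: Tensor<Rank1<3>, f32, Cpu> = dev.sample_normal();
let y: Tensor<Rank1<10>, f32, Cpu> = module.forward(x);
```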

Breaking changes

  1. dfdx has been renamed to dfdx-core and no longer includes the nn items. dfdx now contains the new nn items and re-exports everything from dfdx-core.
  2. dfdx::optim has been moved under dfdx::nn::optim (see the import sketch after this list)
  3. EMA functionality is removed (it can be added back in the future, but will require more proc macros)
  4. TensorCollection is removed
  5. Saving nn layers to npy is removed; only safetensors is supported now.
  6. The old builder structs are now structs suffixed with "Config" (e.g. LinearConfig instead of builders::Linear)
  7. to_device functionality is removed
  8. to_dtype functionality is removed
  9. UnbiasedLinear is renamed to MatMul
  10. GeneralizedResidual is renamed to GeneralizedAdd
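For items 1 and 2, a hedged sketch of what migrating imports might look like (exact paths beyond those named above are assumptions):

```rust
// Before: optimizers lived at the crate root.
// use dfdx::optim::Adam;

// After: nn items (including optimizers) live under dfdx::nn,
// and dfdx re-exports everything from dfdx-core.
use dfdx::nn::optim::Adam; // moved per breaking change 2
use dfdx::tensor::Cpu;     // assumed re-export from dfdx-core
```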

@coreylowman coreylowman changed the title [Breaking] [WIP] Rewrite of nn to enable runtime layer sizes, proc macro declarations, and more [Breaking] Rewrite of nn to enable runtime layer sizes, proc macro declarations, and more Oct 25, 2023
@coreylowman coreylowman merged commit 5e0c3dd into main Oct 25, 2023
8 checks passed
@coreylowman coreylowman deleted the nn-rewrite branch October 25, 2023 15:14