
[WIP] Adds #[derive(Sequential)] on nn builder structs #803

Closed

Conversation

coreylowman
Owner

Currently only tuples are supported for sequential networks. They are convenient and easy to write for small networks, but have a number of weaknesses when trying to scale up to larger networks.

  1. The types get extremely long, making debug/error messages extremely difficult to read.
  2. They are hard to work with when trying to access fields/subfields.
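
For reference, the existing tuple-based style looks roughly like this (a sketch; the MyMLP alias name is illustrative):

// Existing tuple-based API: the whole network is one nested tuple type.
// Error messages spell out the full tuple type, and sub-layers can only
// be reached positionally (model.0, model.1, ...), not by name.
type MyMLP = (Linear<5, 10>, ReLU, Linear<10, 5>, Tanh);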

This PR introduces a new API for specifying sequential networks using a derive attribute:

#[derive(Sequential)]
struct MyMLP {
   layer1: Linear<5, 10>,
   act1: ReLU,
   layer2: Linear<10, 5>,
   act2: Tanh,
}

Internally, this will generate a new type that looks like this:

struct MyMLPBuilt<E: Dtype, D: Device<E>> {
   layer1: Linear<5, 10, E, D>,
   act1: ReLU,
   layer2: Linear<10, 5, E, D>,
   act2: Tanh,
}

and add all the necessary TensorCollection/Module/ModuleMut implementations.

This should be much easier to work with!
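
As a rough usage sketch (assuming a dfdx-style build_module API; the exact method and field names here are assumptions, not something this PR specifies):

// Hypothetical usage sketch, assuming the derive wires MyMLP up as a
// buildable module via dfdx's DeviceBuildExt API. Names may differ.
let dev: Cpu = Default::default();
let model = dev.build_module::<MyMLP, f32>();
let x: Tensor<Rank1<5>, f32, Cpu> = dev.sample_normal();
let y = model.forward(x);
// Sub-layers are plain named fields instead of tuple indices:
let weights = &model.layer1.weight;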

coreylowman marked this pull request as draft on July 3, 2023 at 14:26
@coreylowman
Owner Author

Noting that this effort has been moved into a separate repo https://github.com/coreylowman/slimnn. It's turned into a much bigger rewrite for the nn layer that will add:

  1. #[derive(Sequential)], as described in the PR description above
  2. Simplified traits for the nn layer (& removal of TensorCollection)
  3. Derive macros for the other new traits (ZeroGrads/ResetParams/UpdateParams/etc.; see the sketch after this list)
  4. Dynamically sized layers (Dynamic dimensions in neural network layers? #755)

I'm doing this work separately just for ease of implementation, and to see what it's like to implement a nn layer outside of the repo. Once that separate implementation is fairly complete, it will be merged into dfdx.
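
As a purely illustrative sketch of item 3 (whether the derives compose exactly like this is an assumption):

// Hypothetical: combining the planned derive macros on one struct.
#[derive(Sequential, ZeroGrads, ResetParams, UpdateParams)]
struct MyMLP {
    layer1: Linear<5, 10>,
    act1: ReLU,
    layer2: Linear<10, 5>,
    act2: Tanh,
}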
