
How to use shape-dependent variables without @nn.compact in JAX/Flax? #928

Answered by marcvanzee
marcvanzee asked this question in Q&A

When using `setup()` in a non-compact Module in Flax, one does not have direct access to the shape of the input, because `setup()` takes no arguments and never sees the inputs passed to `__call__`. To use shape-dependent variables in such modules, we either pass the necessary shape information explicitly as construction args, or we isolate any shape-inferred variables in a submodule that we construct in `setup()`.

We do this in the Flax VAE example and the WMT example; in both, search for `setup(` to see how it is done.

Answer selected by marcvanzee

This discussion was converted from issue #881 on January 22, 2021 13:32.