
Add trimmed Linen to NNX guide #4209

Merged 1 commit into google:main on Sep 21, 2024

Conversation

@IvyZX (Collaborator) commented Sep 19, 2024

Extracted the Linen/NNX part from the original Haiku/Linen/NNX guide, removed a bunch of Haiku-targeted examples, and added a few structured explanations and usage examples targeting existing Linen users, including:

  • Explain the fundamentals - stateful vs stateless, lazy vs eager
  • RNGs
  • JAX-style transforms, like scan over layers

The original Haiku/Linen/NNX guide will later be tailored to a Haiku/Flax NNX guide.


* **Lazy vs. eager**: Linen modules only allocate space to create variables when they actually see their input, whereas NNX module instances create their variables the moment they are instantiated, without seeing a sample input.

* This means Linen can ``@nn.compact`` decorator to define a model with only one method, wheras NNX modules must have both ``__init__`` and ``__call__`` defined. This also means that the input shape must be explicitly passed during module creation because the parameter shapes cannot be inferred from the input.
@cgarciae (Collaborator) commented Sep 19, 2024:

Some rewording to improve accuracy.

Suggested change
* This means Linen can ``@nn.compact`` decorator to define a model with only one method, wheras NNX modules must have both ``__init__`` and ``__call__`` defined. This also means that the input shape must be explicitly passed during module creation because the parameter shapes cannot be inferred from the input.
* Linen can use the ``@nn.compact`` decorator to define the model in a single method and use shape inference from the input sample, whereas NNX modules generally require additional shape information to create all parameters during ``__init__`` and separately define the computation in ``__call__``.
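
To make the contrast concrete, here is a minimal sketch (the module names and layer sizes are hypothetical, not taken from the guide):

import jax
import jax.numpy as jnp
from flax import linen as nn
from flax import nnx

class LinenMLP(nn.Module):
    dout: int

    @nn.compact  # single method; parameter shapes inferred from the input sample
    def __call__(self, x):
        return nn.Dense(self.dout)(x)

class NNXMLP(nnx.Module):
    def __init__(self, din, dout, rngs: nnx.Rngs):  # input size given explicitly
        self.linear = nnx.Linear(din, dout, rngs=rngs)

    def __call__(self, x):
        return self.linear(x)

# Linen: variables only materialize once `.init` sees a sample input.
variables = LinenMLP(dout=4).init(jax.random.key(0), jnp.ones((1, 8)))
# NNX: parameters exist as soon as the module is constructed.
model = NNXMLP(din=8, dout=4, rngs=nnx.Rngs(0))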


* Linen uses ``@jax.jit`` to compile the training step, whereas NNX uses ``@nnx.jit``. ``jax.jit`` only accepts pure stateless arguments, but ``nnx.jit`` allows the arguments to be stateful NNX modules. This greatly reduces the number of lines needed for a train step.

* Similarly, Linen uses ``jax.grad()`` to return a raw dictionary of gradients, wheras NNX can use ``nnx.grad`` to return an NNX model state of gradients. In NNX, here you need to use the split/merge API

Suggested change
* Similarly, Linen uses ``jax.grad()`` to return a raw dictionary of gradients, wheras NNX can use ``nnx.grad`` to return an NNX model state of gradients. In NNX, here you need to use the split/merge API
* Similarly, Linen uses ``jax.grad()`` to return a raw dictionary of gradients, whereas NNX can use ``nnx.grad`` to return the gradients of Modules as NNX ``State`` dictionaries. To use regular ``jax.grad`` with NNX you need to use the split/merge API
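
As a rough sketch of how ``nnx.jit`` and ``nnx.grad`` combine in a train step (the toy model, data shapes, and plain-SGD learning rate here are hypothetical):

import jax
import optax
from flax import nnx

model = nnx.Linear(8, 4, rngs=nnx.Rngs(0))  # hypothetical toy model

@nnx.jit  # the stateful module is passed in directly; no params pytree threading
def train_step(model, inputs, labels):
    def loss_fn(model):
        logits = model(inputs)
        return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

    grads = nnx.grad(loss_fn)(model)      # gradients come back as an NNX State
    params = nnx.state(model, nnx.Param)  # matching State of the current parameters
    new_params = jax.tree.map(lambda p, g: p - 1e-3 * g, params, grads)
    nnx.update(model, new_params)         # write the SGD update back into the model

Because both transforms accept the module itself, the gradient computation and the update stay a few lines long.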



* If you are already using Optax optimizer classes like ``optax.adam(1e-3)`` and use its ``update()`` method to update your model params (instead of the raw ``jax.tree.map`` computation like here), check out `nnx.Optimizer example <https://flax-nnx.readthedocs.io/en/latest/nnx_basics.html#transforms>`__ for a much more concise way of training and updating your model.

Suggested change
* If you are already using Optax optimizer classes like ``optax.adam(1e-3)`` and use its ``update()`` method to update your model params (instead of the raw ``jax.tree.map`` computation like here), check out `nnx.Optimizer example <https://flax-nnx.readthedocs.io/en/latest/nnx_basics.html#transforms>`__ for a much more concise way of training and updating your model.
* If you are already using Optax optimizers like ``optax.adamw`` (instead of the raw ``jax.tree.map`` computation shown here), check out `nnx.Optimizer example <https://flax-nnx.readthedocs.io/en/latest/nnx_basics.html#transforms>`__ for a much more concise way of training and updating your model.
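
For reference, a sketch of that more concise pattern, assuming the ``nnx.Optimizer`` API described in the linked example (the model, data, and hyperparameters are hypothetical):

import optax
from flax import nnx

model = nnx.Linear(8, 4, rngs=nnx.Rngs(0))           # hypothetical toy model
optimizer = nnx.Optimizer(model, optax.adamw(1e-3))  # holds the Optax state for the model

@nnx.jit
def train_step(optimizer, inputs, labels):
    def loss_fn(model):
        logits = model(inputs)
        return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

    loss, grads = nnx.value_and_grad(loss_fn)(optimizer.model)
    optimizer.update(grads)  # applies the Optax transform and updates the model in place
    return loss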

Comment on lines 161 to 167
def loss_fn(model):
    logits = model(

        inputs,  # <== inputs

    )
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

Could be more appealing to show it like this to contrast how much simpler NNX is in this case.

Suggested change
def loss_fn(model):
    logits = model(

        inputs,  # <== inputs

    )
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

def loss_fn(model):
    logits = model(inputs)
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

Comment on lines 195 to 197
* ``nn.BatchNorm`` creates ``batch_stats`` -> ``nnx.BatchNorm`` creates ``nnx.BatchStats``.

* ``linen.Module.sow()`` creates ``intermediates`` -> ``nnx.Module.sow()`` creates ``nnx.Intermediates``. You can also simply get the intermediates by assigning it to a module attribute, like ``self.sowed = nnx.Intermediates(x)``.
@cgarciae (Collaborator) commented Sep 19, 2024:

Suggested change
* ``nn.BatchNorm`` creates ``batch_stats`` -> ``nnx.BatchNorm`` creates ``nnx.BatchStats``.
* ``linen.Module.sow()`` creates ``intermediates`` -> ``nnx.Module.sow()`` creates ``nnx.Intermediates``. You can also simply get the intermediates by assigning it to a module attribute, like ``self.sowed = nnx.Intermediates(x)``.
* ``nn.Dense`` creates ``params`` -> ``nnx.Linear`` creates ``nnx.Param``.
* ``nn.BatchNorm`` creates ``batch_stats`` -> ``nnx.BatchNorm`` creates ``nnx.BatchStat``.
* ``linen.Module.sow()`` creates ``intermediates`` -> ``nnx.Module.sow()`` creates ``nnx.Intermediate``. You can also simply get the intermediates by assigning it to a module attribute, like ``self.sowed = nnx.Intermediate(x)``.

BTW: in Linen this is also true; you can simply add an intermediate via ``self.variable('intermediates', 'sowed', lambda: x)``.
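
A small sketch pulling these variable types together (the module and feature sizes are hypothetical):

import jax.numpy as jnp
from flax import nnx

class Classifier(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        self.linear = nnx.Linear(4, 4, rngs=rngs)  # holds nnx.Param variables
        self.bn = nnx.BatchNorm(4, rngs=rngs)      # holds nnx.BatchStat variables

    def __call__(self, x):
        y = self.bn(self.linear(x))
        self.sow(nnx.Intermediate, 'hidden', y)    # or: self.hidden = nnx.Intermediate(y)
        return y

model = Classifier(rngs=nnx.Rngs(0))
model(jnp.ones((2, 4)))
intermediates = nnx.state(model, nnx.Intermediate)  # collect the sown values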

Comment on lines +210 to +216

@nn.compact

Add a new line to align the ``__call__``s.

Suggested change
@nn.compact
@nn.compact

Comment on lines 214 to 221
x = self.batchnorm(x, use_running_average=not training)
x = jax.nn.relu(x)
@cgarciae (Collaborator) commented Sep 19, 2024:

Increment count

Suggested change
x = self.batchnorm(x, use_running_average=not training)
x = jax.nn.relu(x)
x = self.batchnorm(x, use_running_average=not training)
self.count.value += 1
x = jax.nn.relu(x)
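
In isolation, the mutable-counter pattern suggested above looks roughly like this (the ``Count`` variable type and the surrounding module are hypothetical stand-ins for the guide's code):

import jax
import jax.numpy as jnp
from flax import nnx

class Count(nnx.Variable):  # custom variable type, so the counter can be filtered separately
    pass

class Block(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        self.linear = nnx.Linear(4, 4, rngs=rngs)
        self.count = Count(jnp.array(0))

    def __call__(self, x):
        self.count.value += 1  # plain in-place mutation; no state threading needed
        return jax.nn.relu(self.linear(x))

block = Block(rngs=nnx.Rngs(0))
block(jnp.ones((1, 4)))
assert block.count.value == 1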

Using Multiple Methods
======================

In this section we will take a look at how to use multiple methods in all three

Suggested change
In this section we will take a look at how to use multiple methods in all three
In this section we will take a look at how to use multiple methods in both

Comment on lines 454 to 455
In NNX, we define a scan function ``scan_fn`` that will use the ``RNNCell`` defined
in ``__init__`` to scan over the sequence.
@cgarciae (Collaborator) commented Sep 19, 2024:

Add some explanation for nnx.scan

Suggested change
In NNX, we define a scan function ``scan_fn`` that will use the ``RNNCell`` defined
in ``__init__`` to scan over the sequence.
In NNX, we define a scan function ``scan_fn`` that will use the ``RNNCell`` defined
in ``__init__`` to scan over the sequence, and explicitly set ``in_axes=(nnx.Carry, None, 1)``.
Here ``nnx.Carry`` means that the ``carry`` argument will be the carry, ``None`` means that ``cell``
will be broadcast to all steps, and ``1`` means ``x`` will be scanned across axis 1.
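
A compact sketch of that pattern (the cell definition and dimensions are hypothetical stand-ins for the guide's ``RNNCell``):

import jax
import jax.numpy as jnp
from flax import nnx

class RNNCell(nnx.Module):  # minimal stand-in cell
    def __init__(self, input_size, hidden_size, rngs: nnx.Rngs):
        self.linear = nnx.Linear(input_size + hidden_size, hidden_size, rngs=rngs)

    def __call__(self, carry, x):
        carry = jax.nn.relu(self.linear(jnp.concatenate([carry, x], axis=-1)))
        return carry, carry  # new carry, and the per-step output

def scan_fn(carry, cell, x):
    return cell(carry, x)

cell = RNNCell(input_size=3, hidden_size=8, rngs=nnx.Rngs(0))
x = jnp.ones((4, 10, 3))     # (batch, time, features); scanned over axis 1
carry = jnp.zeros((4, 8))
carry, ys = nnx.scan(
    scan_fn, in_axes=(nnx.Carry, None, 1), out_axes=(nnx.Carry, 1)
)(carry, cell, x)            # ys has shape (4, 10, 8)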

Scan over Layers
================

In general, lifted transforms of Linen and NNX should look the same. However, NNX lifted transforms is designed to be closer to their lower level JAX counterparts, and thus we throw away some assumptions in certain Linen lifted transforms. This scan-over-layers use case will be a good example to showcase it.

Suggested change
In general, lifted transforms of Linen and NNX should look the same. However, NNX lifted transforms is designed to be closer to their lower level JAX counterparts, and thus we throw away some assumptions in certain Linen lifted transforms. This scan-over-layers use case will be a good example to showcase it.
In general, lifted transforms of Linen and NNX should look the same. However, NNX lifted transforms are designed to be closer to their lower level JAX counterparts, and thus we throw away some assumptions in certain Linen lifted transforms. This scan-over-layers use case will be a good example to showcase it.


In Linen, we apply ``nn.scan`` to the module ``Block`` to create a larger module ``ScanBlock`` that contains 5 ``Block`` layers. It will automatically create a large parameter of shape ``(5, 64, 64)`` at initialization time, and at call time iterate over every ``(64, 64)`` slice for a total of 5 times, like a ``jax.lax.scan`` would.

But if you think closely, there actually wans't any ``jax.lax.scan`` operation at initialization time. What happened there is more like a ``jax.vmap`` operation - you are given a ``Block`` that accepts ``(in_dim, out_dim)``, and you "vmap" it over ``num_layers`` of times to create a larger array.

Suggested change
But if you think closely, there actually wans't any ``jax.lax.scan`` operation at initialization time. What happened there is more like a ``jax.vmap`` operation - you are given a ``Block`` that accepts ``(in_dim, out_dim)``, and you "vmap" it over ``num_layers`` of times to create a larger array.
But if you look closely, there actually isn't any need for a ``jax.lax.scan`` operation at initialization time. What happens there is more like a ``jax.vmap`` operation - you are given a ``Block`` that accepts ``(in_dim, out_dim)``, and you "vmap" it ``num_layers`` times to create a larger array.



This means that in NNX, since model initialization and running code are completely decoupled, we need to use ``nnx.vmap`` to initialize the underlying blocks, and then use ``nnx.scan`` to run the model input through them.

Suggested change
This means that in NNX, since model initialization and running code are completely decoupled, we need to use ``nnx.vmap`` to initialize the underlying blocks, and then use ``nnx.scan`` to run the model input through them.
In NNX we take advantage of the fact that model initialization and running code are completely decoupled, and instead use ``nnx.vmap`` to initialize the underlying blocks, and ``nnx.scan`` to run the model input through them.
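
A sketch of this decoupled pattern, assuming a hypothetical ``Block`` with 64 features and 5 layers:

import jax
import jax.numpy as jnp
from flax import nnx

class Block(nnx.Module):
    def __init__(self, dim: int, rngs: nnx.Rngs):
        self.linear = nnx.Linear(dim, dim, rngs=rngs)

    def __call__(self, x):
        return jax.nn.relu(self.linear(x))

@nnx.split_rngs(splits=5)            # give each of the 5 blocks its own RNG key
@nnx.vmap(in_axes=(0,), out_axes=0)  # "vmap" the constructor to stack the parameters
def create_blocks(rngs: nnx.Rngs):
    return Block(64, rngs=rngs)      # stacked params, e.g. a kernel of shape (5, 64, 64)

@nnx.scan(in_axes=(nnx.Carry, 0), out_axes=nnx.Carry)  # scan the input through the stack
def forward(x, blocks):
    x = blocks(x)
    return x

blocks = create_blocks(nnx.Rngs(0))
y = forward(jnp.ones((1, 64)), blocks)

The ``nnx.split_rngs`` line is explained in the bullet below.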


There are a few other details to explain in this example:

* **What is that `nnx.split_rngs` decorator?** This is because ``jax.vmap`` and ``jax.lax.scan`` requires a list of RNG keys if each of its internal operations needs its own key. So for the 5 layers inside ``MLP``, it will split and provide 5 different RNG keys from its arguments before going into the JAX transform.
@cgarciae (Collaborator) commented Sep 19, 2024:

Added explanation of why nnx.split_rngs is necessary.

Suggested change
* **What is that `nnx.split_rngs` decorator?** This is because ``jax.vmap`` and ``jax.lax.scan`` requires a list of RNG keys if each of its internal operations needs its own key. So for the 5 layers inside ``MLP``, it will split and provide 5 different RNG keys from its arguments before going into the JAX transform.
* **What is that ``nnx.split_rngs`` decorator?** NNX transforms are completely agnostic of RNG state; this makes them behave more like JAX transforms, but diverge from the Linen transforms, which do handle RNG state. To regain this functionality, the ``nnx.split_rngs`` decorator allows you to split the ``Rngs`` before passing them to the decorated function and 'lower' them afterwards so they can be used outside.
* This is needed because ``jax.vmap`` and ``jax.lax.scan`` require a list of RNG keys if each of their internal operations needs its own key. So for the 5 layers inside ``MLP``, it will split and provide 5 different RNG keys from its arguments before going into the JAX transform.

@IvyZX (Collaborator, Author) commented Sep 19, 2024:

All comments adopted - thanks @cgarciae!

@copybara-service bot merged commit 268366f into google:main on Sep 21, 2024
17 checks passed