Skip to content

Commit

Permalink
Merge pull request #4168 from google:nnx-improve-messaging
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 671719107
  • Loading branch information
Flax Authors committed Sep 6, 2024
2 parents e848a99 + 0be4e14 commit b494c3f
Show file tree
Hide file tree
Showing 12 changed files with 148 additions and 189 deletions.
8 changes: 0 additions & 8 deletions docs/api_reference/flax.nnx/experimental.rst

This file was deleted.

3 changes: 1 addition & 2 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,5 @@ Notable examples in Flax include:
developer_notes/index
philosophy
contributing
experimental
api_reference/index
NNX <nnx/index>
Flax NNX <nnx/index>
14 changes: 8 additions & 6 deletions docs/nnx/filters_guide.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
"source": [
"# Using Filters\n",
"\n",
"Filters are used extensively in NNX as a way to create `State` groups in APIs\n",
"such as `nnx.split`, `nnx.state`, and many of the NNX transforms. For example:"
"> **Attention**: This page relates to the new Flax NNX API.\n",
"\n",
"Filters are used extensively in Flax NNX as a way to create `State` groups in APIs\n",
"such as `nnx.split`, `nnx.state`, and many of the Flax NNX transforms. For example:"
]
},
{
Expand Down Expand Up @@ -116,8 +118,8 @@
"metadata": {},
"source": [
"Such function matches any value that is an instance of `Param` or any value that has a \n",
"`type` attribute that is a subclass of `Param`. Internally NNX uses `OfType` which defines \n",
"a callable of this form for a given type:"
"`type` attribute that is a subclass of `Param`. Internally Flax NNX uses `OfType` which\n",
"defines a callable of this form for a given type:"
]
},
{
Expand Down Expand Up @@ -149,11 +151,11 @@
"source": [
"## The Filter DSL\n",
"\n",
"To avoid users having to create these functions, NNX exposes a small DSL, formalized \n",
"To avoid users having to create these functions, Flax NNX exposes a small DSL, formalized \n",
"as the `nnx.filterlib.Filter` type, which lets users pass types, booleans, ellipsis, \n",
"tuples/lists, etc, and converts them to the appropriate predicate internally.\n",
"\n",
"Here is a list of all the callable Filters included in NNX and their DSL literals \n",
"Here is a list of all the callable Filters included in Flax NNX and their DSL literals\n",
"(when available):\n",
"\n",
"\n",
Expand Down
14 changes: 8 additions & 6 deletions docs/nnx/filters_guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@ jupytext:

# Using Filters

Filters are used extensively in NNX as a way to create `State` groups in APIs
such as `nnx.split`, `nnx.state`, and many of the NNX transforms. For example:
> **Attention**: This page relates to the new Flax NNX API.
Filters are used extensively in Flax NNX as a way to create `State` groups in APIs
such as `nnx.split`, `nnx.state`, and many of the Flax NNX transforms. For example:

```{code-cell} ipython3
from flax import nnx
Expand Down Expand Up @@ -63,8 +65,8 @@ print(f'{is_param((), nnx.VariableState(type=nnx.Param, value=0)) = }')
```

Such function matches any value that is an instance of `Param` or any value that has a
`type` attribute that is a subclass of `Param`. Internally NNX uses `OfType` which defines
a callable of this form for a given type:
`type` attribute that is a subclass of `Param`. Internally Flax NNX uses `OfType` which
defines a callable of this form for a given type:

```{code-cell} ipython3
is_param = nnx.OfType(nnx.Param)
Expand All @@ -75,11 +77,11 @@ print(f'{is_param((), nnx.VariableState(type=nnx.Param, value=0)) = }')

## The Filter DSL

To avoid users having to create these functions, NNX exposes a small DSL, formalized
To avoid users having to create these functions, Flax NNX exposes a small DSL, formalized
as the `nnx.filterlib.Filter` type, which lets users pass types, booleans, ellipsis,
tuples/lists, etc, and converts them to the appropriate predicate internally.

Here is a list of all the callable Filters included in NNX and their DSL literals
Here is a list of all the callable Filters included in Flax NNX and their DSL literals
(when available):


Expand Down
37 changes: 21 additions & 16 deletions docs/nnx/index.rst
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@

NNX
Flax NNX
========
.. div:: sd-text-left sd-font-italic

Expand All @@ -8,11 +8,15 @@ NNX

----

NNX is a new Flax API that is designed to make it easier to create, inspect, debug,
and analyze neural networks in JAX. It achieves this by adding first class support
Flax NNX is a new simplified API that is designed to make it easier to create, inspect,
debug, and analyze neural networks in JAX. It achieves this by adding first class support
for Python reference semantics, allowing users to express their models using regular
Python objects. NNX takes years of feedback from Linen and brings to Flax a simpler
and more user-friendly experience.
Python objects. Flax NNX is an evolution of the previous Flax Linen APIs, it takes years of
experience to bring a simpler and more user-friendly experience.

.. note::
Flax Linen is not going to be deprecated in the near future as most of our users still
rely on this API, however new users are encouraged to use Flax NNX.

Features
^^^^^^^^^
Expand All @@ -29,7 +33,7 @@ Features

.. div:: sd-font-normal

NNX supports the use of regular Python objects, providing an intuitive
Flax NNX supports the use of regular Python objects, providing an intuitive
and predictable development experience.

.. grid-item::
Expand All @@ -42,33 +46,34 @@ Features

.. div:: sd-font-normal

NNX relies on Python's object model, which results in simplicity for
Flax NNX relies on Python's object model, which results in simplicity for
the user and increases development speed.

.. grid-item::
:columns: 12 12 12 6

.. card:: Streamlined
.. card:: Expressive
:class-card: sd-border-0
:shadow: none
:class-title: sd-fs-5

.. div:: sd-font-normal

NNX integrates user feedback and hands-on experience with Linen
into a new simplified API.
Flax NNX allows fine-grained control of the model's state via
its `Filter <https://flax.readthedocs.io/en/latest/nnx/filters_guide.html>`__
system.

.. grid-item::
:columns: 12 12 12 6

.. card:: Compatible
.. card:: Familiar
:class-card: sd-border-0
:shadow: none
:class-title: sd-fs-5

.. div:: sd-font-normal

NNX makes it very easy to integrate objects with regular JAX code
Flax NNX makes it very easy to integrate objects with regular JAX code
via the `Functional API <nnx_basics.html#the-functional-api>`__.

Basic usage
Expand Down Expand Up @@ -114,7 +119,7 @@ Basic usage
Installation
^^^^^^^^^^^^

Install NNX via pip:
Install via pip:

.. code-block:: bash
Expand All @@ -137,7 +142,7 @@ Learn more
.. grid-item::
:columns: 6 6 6 4

.. card:: :material-regular:`rocket_launch;2em` NNX Basics
.. card:: :material-regular:`rocket_launch;2em` Flax NNX Basics
:class-card: sd-text-black sd-bg-light
:link: nnx_basics.html

Expand All @@ -151,14 +156,14 @@ Learn more
.. grid-item::
:columns: 6 6 6 4

.. card:: :material-regular:`sync_alt;2em` NNX vs JAX Transformations
.. card:: :material-regular:`sync_alt;2em` Flax vs JAX Transformations
:class-card: sd-text-black sd-bg-light
:link: transforms.html

.. grid-item::
:columns: 6 6 6 4

.. card:: :material-regular:`transform;2em` Haiku and Linen vs NNX
.. card:: :material-regular:`transform;2em` Haiku and Flax Linen vs Flax NNX
:class-card: sd-text-black sd-bg-light
:link: haiku_linen_vs_nnx.html

Expand Down
23 changes: 12 additions & 11 deletions docs/nnx/mnist_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
"\n",
"# MNIST Tutorial\n",
"\n",
"Welcome to NNX! This tutorial will guide you through building and training a simple convolutional \n",
"neural network (CNN) on the MNIST dataset using the NNX API. NNX is a Python neural network library\n",
"Welcome to Flax NNX! This tutorial will guide you through building and training a simple convolutional \n",
"neural network (CNN) on the MNIST dataset using the Flax NNX API. Flax NNX is a Python neural network library\n",
"built upon [JAX](https://github.com/google/jax) and currently offered as an experimental module within \n",
"[Flax](https://github.com/google/flax)."
]
Expand All @@ -21,9 +21,10 @@
"id": "1",
"metadata": {},
"source": [
"## 1. Install NNX\n",
"## 1. Install Flax\n",
"\n",
"Since NNX is under active development, we recommend using the latest version from the Flax GitHub repository:"
"If `flax` is not installed in your environment, you can install it from PyPI, uncomment and run the \n",
"following cell:"
]
},
{
Expand All @@ -37,7 +38,7 @@
},
"outputs": [],
"source": [
"# !pip install git+https://github.com/google/flax.git"
"# !pip install flax"
]
},
{
Expand Down Expand Up @@ -109,9 +110,9 @@
"id": "5",
"metadata": {},
"source": [
"## 3. Define the Network with NNX\n",
"## 3. Define the Network with Flax NNX\n",
"\n",
"Create a convolutional neural network with NNX by subclassing `nnx.Module`."
"Create a convolutional neural network with Flax NNX by subclassing `nnx.Module`."
]
},
{
Expand All @@ -134,7 +135,7 @@
}
],
"source": [
"from flax import nnx # NNX API\n",
"from flax import nnx # Flax NNX API\n",
"from functools import partial\n",
"\n",
"class CNN(nnx.Module):\n",
Expand Down Expand Up @@ -204,7 +205,7 @@
"source": [
"## 4. Create Optimizer and Metrics\n",
"\n",
"In NNX, we create an `Optimizer` object to manage the model's parameters and apply gradients during training. `Optimizer` receives the model's reference so it can update its parameters, and an `optax` optimizer to define the update rules. Additionally, we'll define a `MultiMetric` object to keep track of the `Accuracy` and the `Average` loss."
"In Flax NNX, we create an `Optimizer` object to manage the model's parameters and apply gradients during training. `Optimizer` receives the model's reference so it can update its parameters, and an `optax` optimizer to define the update rules. Additionally, we'll define a `MultiMetric` object to keep track of the `Accuracy` and the `Average` loss."
]
},
{
Expand Down Expand Up @@ -287,9 +288,9 @@
"The [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) decorator traces the `train_step` function for just-in-time compilation with \n",
"[XLA](https://www.tensorflow.org/xla), optimizing performance on \n",
"hardware accelerators. `nnx.jit` is similar to [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit),\n",
"except it can transforms functions that contain NNX objects as inputs and outputs.\n",
"except it can transforms functions that contain Flax NNX objects as inputs and outputs.\n",
"\n",
"**NOTE**: in the above code we performed serveral inplace updates to the model, optimizer, and metrics, and we did not explicitely return the state updates. This is because NNX transforms respect reference semantics for NNX objects, and will propagate the state updates of the objects passed as input arguments. This is a key feature of NNX that allows for a more concise and readable code."
"**NOTE**: in the above code we performed serveral inplace updates to the model, optimizer, and metrics, and we did not explicitely return the state updates. This is because Flax NNX transforms respect reference semantics for Flax NNX objects, and will propagate the state updates of the objects passed as input arguments. This is a key feature of Flax NNX that allows for a more concise and readable code."
]
},
{
Expand Down
23 changes: 12 additions & 11 deletions docs/nnx/mnist_tutorial.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,22 @@ jupytext:

# MNIST Tutorial

Welcome to NNX! This tutorial will guide you through building and training a simple convolutional
neural network (CNN) on the MNIST dataset using the NNX API. NNX is a Python neural network library
Welcome to Flax NNX! This tutorial will guide you through building and training a simple convolutional
neural network (CNN) on the MNIST dataset using the Flax NNX API. Flax NNX is a Python neural network library
built upon [JAX](https://github.com/google/jax) and currently offered as an experimental module within
[Flax](https://github.com/google/flax).

+++

## 1. Install NNX
## 1. Install Flax

Since NNX is under active development, we recommend using the latest version from the Flax GitHub repository:
If `flax` is not installed in your environment, you can install it from PyPI, uncomment and run the
following cell:

```{code-cell} ipython3
:tags: [skip-execution]
# !pip install git+https://github.com/google/flax.git
# !pip install flax
```

## 2. Load the MNIST Dataset
Expand Down Expand Up @@ -71,12 +72,12 @@ train_ds = train_ds.batch(batch_size, drop_remainder=True).take(train_steps).pre
test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(1)
```

## 3. Define the Network with NNX
## 3. Define the Network with Flax NNX

Create a convolutional neural network with NNX by subclassing `nnx.Module`.
Create a convolutional neural network with Flax NNX by subclassing `nnx.Module`.

```{code-cell} ipython3
from flax import nnx # NNX API
from flax import nnx # Flax NNX API
from functools import partial
class CNN(nnx.Module):
Expand Down Expand Up @@ -116,7 +117,7 @@ nnx.display(y)

## 4. Create Optimizer and Metrics

In NNX, we create an `Optimizer` object to manage the model's parameters and apply gradients during training. `Optimizer` receives the model's reference so it can update its parameters, and an `optax` optimizer to define the update rules. Additionally, we'll define a `MultiMetric` object to keep track of the `Accuracy` and the `Average` loss.
In Flax NNX, we create an `Optimizer` object to manage the model's parameters and apply gradients during training. `Optimizer` receives the model's reference so it can update its parameters, and an `optax` optimizer to define the update rules. Additionally, we'll define a `MultiMetric` object to keep track of the `Accuracy` and the `Average` loss.

```{code-cell} ipython3
import optax
Expand Down Expand Up @@ -162,9 +163,9 @@ def eval_step(model: CNN, metrics: nnx.MultiMetric, batch):
The [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) decorator traces the `train_step` function for just-in-time compilation with
[XLA](https://www.tensorflow.org/xla), optimizing performance on
hardware accelerators. `nnx.jit` is similar to [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit),
except it can transforms functions that contain NNX objects as inputs and outputs.
except it can transforms functions that contain Flax NNX objects as inputs and outputs.

**NOTE**: in the above code we performed serveral inplace updates to the model, optimizer, and metrics, and we did not explicitely return the state updates. This is because NNX transforms respect reference semantics for NNX objects, and will propagate the state updates of the objects passed as input arguments. This is a key feature of NNX that allows for a more concise and readable code.
**NOTE**: in the above code we performed serveral inplace updates to the model, optimizer, and metrics, and we did not explicitely return the state updates. This is because Flax NNX transforms respect reference semantics for Flax NNX objects, and will propagate the state updates of the objects passed as input arguments. This is a key feature of Flax NNX that allows for a more concise and readable code.

+++

Expand Down
Loading

0 comments on commit b494c3f

Please sign in to comment.