Skip to content

Commit

Permalink
[nnx] disallow Array leaves
Browse files Browse the repository at this point in the history
Numpy and JAX Array's are no longer consider state leaves. This makes the structure of the State completely determined by Variables, which apart from being more predictable it produces structural stability invariant to leaf type changes which let to issues such as #4142.

```python
class Foo(nnx.Module):
  def __init__(self):
    self.a = jnp.array(1) # no longer allowed, instead...
    self.b = nnx.Param(jnp.array(1)) # just use Variables
```

Also migrates all remaining tests from pytest to absl to ensure they are tested correctly internally.

PiperOrigin-RevId: 671372717
  • Loading branch information
Cristian Garcia authored and Flax Authors committed Sep 5, 2024
1 parent aded9ac commit 90715be
Show file tree
Hide file tree
Showing 10 changed files with 240 additions and 232 deletions.
14 changes: 10 additions & 4 deletions flax/nnx/nnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,18 @@
Leaf = tp.TypeVar('Leaf')
AuxData = tp.TypeVar('AuxData')

StateLeaf = tp.Union[VariableState[tp.Any], np.ndarray, jax.Array]
StateLeaf = VariableState[tp.Any]
NodeLeaf = VariableState[tp.Any]
GraphState = State[Key, StateLeaf]
GraphFlatState = FlatState[StateLeaf]


def is_state_leaf(x: tp.Any) -> tpe.TypeGuard[StateLeaf]:
return isinstance(x, (VariableState, np.ndarray, jax.Array))
return isinstance(x, VariableState)


def is_node_leaf(x: tp.Any) -> tpe.TypeGuard[StateLeaf]:
return isinstance(x, (Variable, np.ndarray, jax.Array))
def is_node_leaf(x: tp.Any) -> tpe.TypeGuard[NodeLeaf]:
return isinstance(x, Variable)


class _HashById(tp.Hashable, tp.Generic[A]):
Expand Down Expand Up @@ -416,6 +417,11 @@ def _graph_flatten(
flat_state[(*path, key)] = value
leaves.append((key, None))
else:
if isinstance(value, (jax.Array, np.ndarray)):
path_str = '/'.join(map(str, (*path, key)))
raise ValueError(
f'Arrays leaves are not supported, at {path_str!r}: {value}'
)
static_fields.append((key, value))

nodedef = NodeDef.create(
Expand Down
6 changes: 3 additions & 3 deletions flax/nnx/nnx/training/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def __init__(
self.step = OptState(jnp.array(0, dtype=jnp.uint32))
self.model = model
self.tx = tx
self.opt_state = tx.init(nnx.state(model, wrt))
self.opt_state = OptState(tx.init(nnx.state(model, wrt)))
self.wrt = wrt

def split(self, *filters: filterlib.Filter):
Expand Down Expand Up @@ -198,10 +198,10 @@ def update(self, grads):
"""
state = nnx.state(self.model, self.wrt)

updates, new_opt_state = self.tx.update(grads, self.opt_state, state)
updates, new_opt_state = self.tx.update(grads, self.opt_state.value, state)
new_params = optax.apply_updates(state, updates)
assert isinstance(new_params, nnx.State)

self.step.value += 1
nnx.update(self.model, new_params)
self.opt_state = new_opt_state
self.opt_state.value = new_opt_state
3 changes: 3 additions & 0 deletions flax/nnx/tests/deprecated_transforms_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,3 +368,6 @@ def __call__(self, x: jax.Array) -> jax.Array:
y = module(x)

assert y.shape == (1, 5, 3)

if __name__ == '__main__':
absltest.main()
21 changes: 0 additions & 21 deletions flax/nnx/tests/experimental_test.py

This file was deleted.

6 changes: 4 additions & 2 deletions flax/nnx/tests/filters_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.


from absl.testing import absltest

from flax import nnx
Expand All @@ -30,4 +29,7 @@ def __init__(self, rngs):
head_state = nnx.state(model, nnx.PathContains('head'))

self.assertIn('head', head_state)
self.assertNotIn('backbone', head_state)
self.assertNotIn('backbone', head_state)

if __name__ == '__main__':
absltest.main()
48 changes: 30 additions & 18 deletions flax/nnx/tests/graph_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from collections.abc import Callable
import dataclasses
from functools import partial
from threading import Thread
from typing import Any

from absl.testing import absltest
from absl.testing import absltest, parameterized
from flax import linen, nnx, struct
import jax
import jax.numpy as jnp
import pytest


class StatefulLinear(nnx.Module):
Expand Down Expand Up @@ -77,7 +77,7 @@ def test_unflatten_empty(self):

graphdef, state = nnx.split(g)

with pytest.raises(ValueError, match='Expected key'):
with self.assertRaisesRegex(ValueError, 'Expected key'):
nnx.graph.unflatten(graphdef, nnx.State({}))

def test_update_dynamic(self):
Expand Down Expand Up @@ -109,8 +109,8 @@ def test_update_static_inconsistent_types(self):
g = [a, 3, a, nnx.Param(4)]
g2 = [a, a, 3, nnx.Param(4)]

with pytest.raises(
ValueError, match='Trying to update a node with a different type'
with self.assertRaisesRegex(
ValueError, 'Trying to update a node with a different type'
):
nnx.graph.graph_update_static(g, g2)

Expand All @@ -130,7 +130,7 @@ def test_update_static_add_shared_error(self):
g = nnx.List([a, 3, a, nnx.Param(4)])
g2 = nnx.List([a, 3, a, nnx.Param(4), a])

with pytest.raises(ValueError, match='Trying to add a new node at path'):
with self.assertRaisesRegex(ValueError, 'Trying to add a new node at path'):
nnx.graph.graph_update_static(g, g2)

def test_module_list(self):
Expand Down Expand Up @@ -428,10 +428,10 @@ def f_pure(graphdef: nnx.graph.GraphDef[Foo], state):
def test_call_jit_update(self):
class Counter(nnx.Module):
def __init__(self):
self.count = jnp.zeros(())
self.count = nnx.Param(jnp.zeros(()))

def inc(self):
self.count += 1
self.count.value += 1
return 1

graph_state = nnx.split(Counter())
Expand All @@ -447,7 +447,7 @@ def update(graph_state: nnx.PureState[Counter]):

counter = nnx.merge(*graph_state)

self.assertEqual(counter.count, 2)
self.assertEqual(counter.count.value, 2)

def test_stateful_linear(self):
linear = StatefulLinear(3, 2, nnx.Rngs(0))
Expand Down Expand Up @@ -714,7 +714,7 @@ def test_to_tree_consistent_prefix(self):
pure_tree = nnx.to_tree(impure_tree, prefix=prefix)

prefix = (0, None, 1)
with pytest.raises(ValueError, match='Inconsistent aliasing detected'):
with self.assertRaisesRegex(ValueError, 'Inconsistent aliasing detected'):
nnx.to_tree(impure_tree, prefix=prefix)

def test_simple_vmap(self):
Expand Down Expand Up @@ -798,12 +798,24 @@ class SimplePyTreeModule(nnx.Module, experimental_pytree=True):
pass


@pytest.mark.parametrize(['x'], [(SimpleModule(),), (SimplePyTreeModule(),)])
def test_threading(x: nnx.Module):
class MyThread(Thread):
def run(self) -> None:
nnx.graph.split(x)
class TestThreading(parameterized.TestCase):

thread = MyThread()
thread.start()
thread.join()
@parameterized.parameters(
(SimpleModule,),
(SimplePyTreeModule,),
)
def test_threading(self, module_fn: Callable[[], nnx.Module]):
x = module_fn()

class MyThread(Thread):

def run(self) -> None:
nnx.graph.split(x)

thread = MyThread()
thread.start()
thread.join()


if __name__ == '__main__':
absltest.main()
49 changes: 19 additions & 30 deletions flax/nnx/tests/module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses
from copy import deepcopy
import dataclasses
from typing import Any, TypeVar

from absl.testing import absltest
from flax import nnx
import jax
import jax.numpy as jnp
import numpy as np
import pytest

from flax import nnx

A = TypeVar('A')


class TestModule:
class TestModule(absltest.TestCase):
def test_has_module_state(self):
class Foo(nnx.Module): ...

Expand All @@ -39,9 +38,9 @@ def test_trace_level(self):

@jax.jit
def f():
with pytest.raises(
nnx.errors.TraceContextError,
match="Cannot mutate 'Dict' from different trace level",
with self.assertRaisesRegex(
nnx.errors.TraceContextError,
"Cannot mutate 'Dict' from different trace level",
):
m.a = 2

Expand Down Expand Up @@ -265,7 +264,7 @@ def __call__(self, x):

m = Foo()

with pytest.raises(ValueError, match='to be a Variable, got'):
with self.assertRaisesRegex(ValueError, 'to be a Variable, got'):
m(2)

def test_sow_wrong_collection(self):
Expand All @@ -280,7 +279,7 @@ def __call__(self, x):

m = Foo()

with pytest.raises(ValueError, match='to be of type'):
with self.assertRaisesRegex(ValueError, 'to be of type'):
m(2)

def test_update_static_state_submodules(self):
Expand Down Expand Up @@ -466,9 +465,12 @@ def __init__(self, din, dout, *, rngs: nnx.Rngs):

block = Block(2, 5, rngs=nnx.Rngs(0))

with pytest.raises(
ValueError,
match="Could not find at least one instance of the following attributes: {'unknown'}",
with self.assertRaisesRegex(
ValueError,
(
'Could not find at least one instance of the following attributes:'
" {'unknown'}"
),
):
block.set_attributes(
deterministic=True, use_running_average=True, unknown=True
Expand Down Expand Up @@ -662,26 +664,10 @@ def __init__(self, *, rngs: nnx.Rngs):
assert modules[1][0] == 'linear'
assert isinstance(modules[1][1], nnx.Linear)

def test_array_in_module(self):
class Foo(nnx.Module):
def __init__(self):
self.a = jnp.array(1.0)

foo = Foo()

graphdef, state = nnx.split(foo)

assert isinstance(state, nnx.State)
assert isinstance(state.a, jax.Array)

foo2 = nnx.merge(graphdef, state)

assert isinstance(foo2.a, jax.Array)

def test_state_in_module(self):
class Foo(nnx.Module):
def __init__(self):
self.a = nnx.State({'b': jnp.array(1.0)})
self.a = nnx.State({'b': nnx.Param(jnp.array(1.0))})

foo = Foo()

Expand All @@ -693,3 +679,6 @@ def __init__(self):
foo2 = nnx.merge(graphdef, state)

assert isinstance(foo2.a, nnx.State)

if __name__ == '__main__':
absltest.main()
Loading

0 comments on commit 90715be

Please sign in to comment.