From a3bdee8aa00d771fd507abf063f2101399021df2 Mon Sep 17 00:00:00 2001 From: Ivy Zheng Date: Mon, 16 Sep 2024 16:10:26 -0700 Subject: [PATCH] Partially revert #4192 which sets back a bunch of previous merged pushes. PiperOrigin-RevId: 675327496 --- flax/core/meta.py | 14 +++++ flax/errors.py | 9 +++ flax/linen/spmd.py | 15 +++++ flax/nnx/bridge/variables.py | 30 ++++----- flax/nnx/bridge/wrappers.py | 53 ++++++++-------- flax/nnx/errors.py | 17 ------ flax/nnx/extract.py | 18 +++--- flax/nnx/object.py | 2 +- flax/nnx/spmd.py | 10 ++- flax/nnx/transforms/compilation.py | 1 + flax/nnx/transforms/iteration.py | 28 ++++++--- flax/nnx/variables.py | 98 +++++++++++------------------- tests/nnx/bridge/wrappers_test.py | 47 ++++++++++---- tests/nnx/module_test.py | 4 +- tests/nnx/rngs_test.py | 5 +- tests/nnx/spmd_test.py | 58 ++++++++++++++++++ tests/nnx/transforms_test.py | 23 ++++++- 17 files changed, 274 insertions(+), 158 deletions(-) delete mode 100644 flax/nnx/errors.py diff --git a/flax/core/meta.py b/flax/core/meta.py index 27686a40b5..531b463c7d 100644 --- a/flax/core/meta.py +++ b/flax/core/meta.py @@ -22,6 +22,7 @@ """ import abc +import dataclasses import functools from typing import Any, Generic, TypeVar from collections.abc import Callable @@ -287,6 +288,19 @@ def get_sharding(self, mesh: jax.sharding.Mesh) -> jax.sharding.Sharding: """Returns the ``NamedSharding`` for this partitioned value.""" return jax.sharding.NamedSharding(mesh, self.get_partition_spec()) + def to_nnx_metadata(self) -> dict[str, Any]: + """Return a dict of metadata that can translate into an `nnx.Variable`.""" + metadata = vars(self) + metadata['sharding'] = metadata.pop('names') + return metadata + + @classmethod + def from_nnx_metadata(cls, metadata: dict[str, Any]): + """Given a dict of `nnx.Variable` format metadata, create a `nn.Partitioned`.""" + metadata['names'] = metadata.pop('sharding') + fields = {x.name for x in dataclasses.fields(cls)} + return cls(**{k: v for k, v in metadata.items() if k in fields}) + def with_partitioning( fn: Callable[..., Any], diff --git a/flax/errors.py b/flax/errors.py index 7284c6e3fb..b2ecd1be69 100644 --- a/flax/errors.py +++ b/flax/errors.py @@ -64,6 +64,15 @@ def __reduce__(self): return (FlaxError, (str(self),)) +################################################# +# NNX errors # +################################################# + + +class TraceContextError(FlaxError): + pass + + ################################################# # lazy_init.py errors # ################################################# diff --git a/flax/linen/spmd.py b/flax/linen/spmd.py index 93afab7646..cd622bbdae 100644 --- a/flax/linen/spmd.py +++ b/flax/linen/spmd.py @@ -328,6 +328,21 @@ def unbox(self, apply_constraint=True) -> Any: else: return self.value + def to_nnx_metadata(self) -> dict[str, Any]: + """Return a dict of metadata that can translate into an `nnx.Variable`.""" + metadata = vars(self) + metadata['sharding'] = metadata.pop('names') + metadata['sharding_rules'] = metadata.pop('rules') + return metadata + + @classmethod + def from_nnx_metadata(cls, metadata: dict[str, Any]): + """Given a dict of `nnx.Variable` format metadata, create a `nn.LogicallyPartitioned`.""" + metadata['names'] = metadata.pop('sharding') + metadata['rules'] = metadata.pop('sharding_rules') + fields = {x.name for x in dataclasses.fields(cls)} + return cls(**{k: v for k, v in metadata.items() if k in fields}) + def with_logical_partitioning( fn: Callable[..., Any], diff --git a/flax/nnx/bridge/variables.py 
b/flax/nnx/bridge/variables.py index 9d8714274a..d73f645f3b 100644 --- a/flax/nnx/bridge/variables.py +++ b/flax/nnx/bridge/variables.py @@ -58,10 +58,10 @@ def variable_type_name(typ: tp.Type[variableslib.Variable[tp.Any]]) -> str: def register_variable_name_type_pair(name, typ, overwrite = False): - """Register a pair of variable type name (like Linen collections) and its NNX type.""" + """Register a pair of Linen collection name and its NNX type.""" if not overwrite and name in VariableTypeCache: raise ValueError(f'Name {name} already mapped to type {VariableTypeCache[name]}. ' - 'To overwrite, call with `overwrite=True`.') + 'To overwrite, call register_variable_name_type_pair() with `overwrite=True`.') VariableTypeCache[name] = typ @@ -85,8 +85,7 @@ def _variable_parents_count(t: type): class NNXMeta(struct.PyTreeNode, meta.AxisMetadata[A]): - """Default Flax metadata class for `nnx.VariableState`. - """ + """Default Flax metadata class for `nnx.VariableState`.""" var_type: type[variableslib.Variable[tp.Any]] = struct.field(pytree_node=False) value: Any = struct.field(pytree_node=True) @@ -110,10 +109,11 @@ def remove_axis(self, index: int, params: dict[Any, Any]) -> 'NNXMeta[A]': def to_linen_var(vs: variableslib.VariableState) -> meta.AxisMetadata: metadata = vs.get_metadata() if 'linen_meta_type' in metadata: - if metadata['linen_meta_type'] is not meta.Partitioned: - raise ValueError('Not supporting Linen metadata types other than nn.Partitioned') - return meta.Partitioned(vs.value, names=metadata['sharding'], mesh=metadata['mesh']) - return NNXMeta(vs.type, vs.value, vs.get_metadata()) + linen_type = metadata['linen_meta_type'] + if hasattr(linen_type, 'from_nnx_metadata'): + return linen_type.from_nnx_metadata({'value': vs.value, **metadata}) + return linen_type(vs.value, **metadata) + return NNXMeta(vs.type, vs.value, metadata) def get_col_name(keypath: tp.Sequence[Any]) -> str: @@ -124,15 +124,15 @@ def get_col_name(keypath: tp.Sequence[Any]) -> str: def to_nnx_var(col: str, x: meta.AxisMetadata | Any) -> variableslib.Variable: - """Convert a Linen variable to an NNX variable. 
- This process needs the collection name, - """ + """Convert a Linen variable to an NNX variable.""" vtype = variable_type(col) if isinstance(x, NNXMeta): assert vtype == x.var_type, f'Type stored in NNXMeta {x.var_type} != type inferred from collection name {vtype}' return x.var_type(x.value, **x.metadata) if isinstance(x, meta.AxisMetadata): - if isinstance(x, meta.Partitioned): - return vtype(x.value, sharding=x.names, mesh=x.mesh, linen_meta_type=meta.Partitioned) - raise ValueError('Not yet supporting metadata types other than nn.Partitioned and NNXMeta') - return vtype(x) + x_metadata = vars(x) + if hasattr(x, 'to_nnx_metadata'): + x_metadata = x.to_nnx_metadata() + assert hasattr(x, 'value') + return vtype(**x_metadata, linen_meta_type=type(x)) + return vtype(x) \ No newline at end of file diff --git a/flax/nnx/bridge/wrappers.py b/flax/nnx/bridge/wrappers.py index 20ac7a2601..d209d89819 100644 --- a/flax/nnx/bridge/wrappers.py +++ b/flax/nnx/bridge/wrappers.py @@ -74,7 +74,7 @@ def lazy_init(fn: Module | tp.Callable[..., tp.Any], *args, **kwargs): module = fn assert callable(fn) else: - if not (hasattr(fn, '__self__') and isinstance(fn.__self__, Module)): + if not hasattr(fn, '__self__') and isinstance(fn.__self__, Module): raise ValueError(f'{fn = } needs to be a method of an NNX Module.') module = fn.__self__ _set_initializing(module, True) @@ -124,6 +124,7 @@ def __init__( self.linen_collections: tuple[str, ...] = () def lazy_init(self, *args, **kwargs): + """A shortcut of calling `nnx.bridge.lazy_init()` upon this module.""" return lazy_init(self, *args, **kwargs) def __call__( @@ -224,28 +225,6 @@ class ToLinen(linen.Module): skip_rng: bool = False metadata_type: tp.Type = bv.NNXMeta - def update_variables(self, module): - """Store the NNX module's graph def and state inside Linen module variables.""" - gdef, state = nnx.split(module) - # Save the graph def. - if self.is_mutable_collection('nnx'): - self.put_variable('nnx', 'graphdef', gdef) - # Sort all the variable types. - types = set(jax.tree.leaves( - jax.tree.map(lambda x: x.type, state, - is_leaf=lambda x: isinstance(x, nnx.VariableState)))) - types = bv.sort_variable_types(types) - _, *state_by_types = nnx.split(module, *types) - # Each variable type goes to its own linen collection, and - # each attribute goes to its own linen variable - for typ, state in zip(types, state_by_types): - collection = bv.variable_type_name(typ) - if self.is_mutable_collection(collection): - for k, v in state.raw_mapping.items(): - v = jax.tree.map(bv.to_linen_var, v, - is_leaf=lambda x: isinstance(x, nnx.VariableState)) - self.put_variable(collection, k, v) - @linen.compact def __call__(self, *args, **kwargs): # init codepath @@ -255,7 +234,7 @@ def __call__(self, *args, **kwargs): module_kwargs |= dict(rngs=nnx.Rngs(**linen_rngs_dict(self))) module = self.nnx_class(*self.args, **module_kwargs) # TODO: add lazy_init here in case there's an `ToNNX` submodule under `module`. - self.update_variables(module) + self._update_variables(module) return module(*args, **kwargs) # apply codepath @@ -270,11 +249,33 @@ def __call__(self, *args, **kwargs): module = nnx.merge(gdef, nnx_state) nnx.reseed(module, **linen_rngs_dict(self)) # reseed with keys from linen apply call. 
out = module(*args, **kwargs) - self.update_variables(module) + self._update_variables(module) return out + def _update_variables(self, module): + """Store the NNX module's graph def and state inside Linen module variables.""" + gdef, state = nnx.split(module) + # Save the graph def. + if self.is_mutable_collection('nnx'): + self.put_variable('nnx', 'graphdef', gdef) + # Sort all the variable types. + types = set(jax.tree.leaves( + jax.tree.map(lambda x: x.type, state, + is_leaf=lambda x: isinstance(x, nnx.VariableState)))) + types = bv.sort_variable_types(types) + _, *state_by_types = nnx.split(module, *types) + # Each variable type goes to its own linen collection, and + # each attribute goes to its own linen variable + for typ, state in zip(types, state_by_types): + collection = bv.variable_type_name(typ) + if self.is_mutable_collection(collection): + for k, v in state.raw_mapping.items(): + v = jax.tree.map(bv.to_linen_var, v, + is_leaf=lambda x: isinstance(x, nnx.VariableState)) + self.put_variable(collection, k, v) + def to_linen(nnx_class: tp.Callable[..., Module], *args, name: str | None = None, **kwargs): - """Shortcut of `ToLinen` if user is not changing any of `ToLinen` default fields.""" + """Shortcut of `nnx.bridge.ToLinen` if user is not changing any of its default fields.""" return ToLinen(nnx_class, args=args, kwargs=kwargs, name=name) \ No newline at end of file diff --git a/flax/nnx/errors.py b/flax/nnx/errors.py deleted file mode 100644 index 41c7d4fab5..0000000000 --- a/flax/nnx/errors.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright 2024 The Flax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -class TraceContextError(Exception): - pass diff --git a/flax/nnx/extract.py b/flax/nnx/extract.py index 6ecf6f2405..845544c307 100644 --- a/flax/nnx/extract.py +++ b/flax/nnx/extract.py @@ -22,7 +22,7 @@ from flax import struct from flax.nnx.object import Object -from flax.typing import MISSING, PathParts +from flax.typing import Missing, PathParts from flax.nnx import graph @@ -59,7 +59,7 @@ def extract_graph_nodes( pytree: A, /, *, - prefix: tp.Any = MISSING, + prefix: tp.Any = Missing, validate_fn: tp.Callable[[KeyPath, Prefix, Leaf], None] | None = None, ) -> ( tuple[A, tuple[tp.Any, ...]] @@ -101,7 +101,7 @@ def extract_graph_nodes( pytree_out = jax.tree.unflatten(treedef, leaves) - if prefix is MISSING: + if prefix is Missing: return pytree_out, tuple(nodes) # type: ignore[bad-return-type] else: return pytree_out, tuple(nodes), tuple(node_prefixes) # type: ignore[bad-return-type] @@ -330,12 +330,13 @@ def to_tree( tree, /, *, - prefix: tp.Any = MISSING, + prefix: tp.Any = Missing, split_fn: tp.Callable[ [graph.SplitContext, KeyPath, Prefix, Leaf], tp.Any ] = default_split_fn, map_non_graph_nodes: bool = False, ctxtag: str | None = None, + check_aliasing: bool = True, ) -> tp.Any: leaf_prefixes = broadcast_prefix( prefix, @@ -351,9 +352,10 @@ def to_tree( with graph.split_context(ctxtag) as split_ctx: for (keypath, leaf), leaf_prefix in zip(leaf_keys, leaf_prefixes): if graph.is_graph_node(leaf): - check_consistent_aliasing( - leaf, leaf_prefix, node_prefixes=node_prefixes - ) + if check_aliasing: + check_consistent_aliasing( + leaf, leaf_prefix, node_prefixes=node_prefixes + ) tree_node = split_fn(split_ctx, keypath, leaf_prefix, leaf) leaves_out.append(tree_node) else: @@ -381,7 +383,7 @@ def from_tree( tree: tp.Any, /, *, - prefix: tp.Any = MISSING, + prefix: tp.Any = Missing, merge_fn: tp.Callable[ [graph.MergeContext, KeyPath, Prefix, Leaf], tp.Any ] = merge_tree_node, diff --git a/flax/nnx/object.py b/flax/nnx/object.py index 9e14155108..f2714ff7fd 100644 --- a/flax/nnx/object.py +++ b/flax/nnx/object.py @@ -25,13 +25,13 @@ import numpy as np from flax.nnx import ( - errors, reprlib, tracers, ) from flax.nnx import graph from flax.nnx.variables import Variable, VariableState from flax.typing import Key +from flax import errors G = tp.TypeVar('G', bound='Object') diff --git a/flax/nnx/spmd.py b/flax/nnx/spmd.py index e18003276b..9b20d32381 100644 --- a/flax/nnx/spmd.py +++ b/flax/nnx/spmd.py @@ -44,7 +44,7 @@ def _add_axis(x: tp.Any): sharding.insert(index, axis_name) x.sharding = tuple(sharding) # type: ignore - x.add_axis(axis_name, index) + x.add_axis(index, axis_name) return x return jax.tree.map( @@ -61,7 +61,7 @@ def _remove_axis(x: tp.Any): sharding = list(x.sharding) assert sharding.pop(index) == axis_name x.sharding = tuple(sharding) - x.remove_axis(axis_name, index) + x.remove_axis(index, axis_name) return x return jax.tree.map( @@ -89,9 +89,15 @@ def _maybe_replicate(x): else: return None + def from_rules(sharding, sharding_rules): + rules = {alias: on_mesh for (alias, on_mesh) in sharding_rules} + return (rules[s] if s in rules else s for s in sharding) + def f(x): if isinstance(x, (variables.VariableState, variables.Variable)): if hasattr(x, 'sharding') and x.sharding: + if hasattr(x, 'sharding_rules') and x.sharding_rules: + return x.replace(PartitionSpec(*from_rules(x.sharding, x.sharding_rules))) return x.replace(PartitionSpec(*x.sharding)) else: return x.replace(_maybe_replicate(x.value)) diff --git a/flax/nnx/transforms/compilation.py 
b/flax/nnx/transforms/compilation.py index d715898ce0..1f63654d63 100644 --- a/flax/nnx/transforms/compilation.py +++ b/flax/nnx/transforms/compilation.py @@ -324,6 +324,7 @@ def jit_wrapper(*args, **kwargs): (args, kwargs), prefix=(in_shardings, kwarg_shardings), split_fn=_jit_split_fn, + check_aliasing=in_shardings is not None, ctxtag='jit', ) pure_args_out, pure_kwargs_out, pure_out = jitted_fn( diff --git a/flax/nnx/transforms/iteration.py b/flax/nnx/transforms/iteration.py index 36c351f34f..c169a91fa1 100644 --- a/flax/nnx/transforms/iteration.py +++ b/flax/nnx/transforms/iteration.py @@ -107,17 +107,25 @@ def _update_variable_sharding_metadata( ): def _update_axes_fn(tree_node): if isinstance(tree_node, extract.TreeNode) and isinstance( - tree_node.metatata, StateAxes + tree_node.metatata, (StateAxes, int) ): - graphdef_states_out: list[extract.GraphDefState] = [] - for graphdef_state, axis in zip( + if isinstance(tree_node.metatata, int): + graph_def_state = tree_node.graphdef_states[0] + assert isinstance(graph_def_state, extract.GraphDefState) + graphdef_state = axis_fn( + graph_def_state, tree_node.metatata, transform_metadata + ) + return tree_node.replace(graphdef_states=(graphdef_state,)) + else: + graphdef_states_out: list[extract.GraphDefState] = [] + for graphdef_state, axis in zip( tree_node.graphdef_states, tree_node.metatata.axes - ): - assert isinstance(graphdef_state, extract.GraphDefState) - if isinstance(axis, int): - graphdef_state = axis_fn(graphdef_state, axis, transform_metadata) - graphdef_states_out.append(graphdef_state) - return tree_node.replace(graphdef_states=tuple(graphdef_states_out)) + ): + assert isinstance(graphdef_state, extract.GraphDefState) + if isinstance(axis, int): + graphdef_state = axis_fn(graphdef_state, axis, transform_metadata) + graphdef_states_out.append(graphdef_state) + return tree_node.replace(graphdef_states=tuple(graphdef_states_out)) return tree_node return jax.tree.map( @@ -130,7 +138,7 @@ def _vmap_split_fn(ctx: graph.SplitContext, path, prefix, x): return extract.TreeNode.from_split( *ctx.split(x, *prefix.filters), metadata=prefix ) - return extract.TreeNode.from_split(*ctx.split(x)) + return extract.TreeNode.from_split(*ctx.split(x), metadata=prefix) @dataclasses.dataclass(eq=False) diff --git a/flax/nnx/variables.py b/flax/nnx/variables.py index 76805477f5..ee6c8a003b 100644 --- a/flax/nnx/variables.py +++ b/flax/nnx/variables.py @@ -22,7 +22,7 @@ import jax -from flax import nnx +from flax import errors from flax.nnx import reprlib, tracers from flax.typing import Missing import jax.tree_util as jtu @@ -36,8 +36,8 @@ CreateValueHook = tp.Callable[['Variable[A]', A], A] AxisName = str AxisIndex = int -AddAxisHook = tp.Callable[[V, AxisName, AxisIndex], None] -RemoveAxisHook = tp.Callable[[V, AxisName, AxisIndex], None] +AddAxisHook = tp.Callable[[V, AxisIndex, AxisName | None], None] +RemoveAxisHook = tp.Callable[[V, AxisIndex, AxisName | None], None] VariableTypeCache: dict[str, tp.Type[Variable[tp.Any]]] = {} @@ -150,67 +150,43 @@ def __init__( **metadata: tp.Any, ): vars(self)['_trace_state'] = tracers.TraceState() - if set_value_hooks: - if callable(set_value_hooks): - set_value_hooks = (set_value_hooks,) - else: - set_value_hooks = tuple(set_value_hooks) + if callable(set_value_hooks): + set_value_hooks = (set_value_hooks,) else: - set_value_hooks = () - if get_value_hooks: - if callable(get_value_hooks): - get_value_hooks = (get_value_hooks,) - else: - get_value_hooks = tuple(get_value_hooks) + set_value_hooks = 
tuple(set_value_hooks) + + if callable(get_value_hooks): + get_value_hooks = (get_value_hooks,) else: - get_value_hooks = () + get_value_hooks = tuple(get_value_hooks) - if create_value_hooks: - if callable(create_value_hooks): - create_value_hooks = (create_value_hooks,) - else: - create_value_hooks = tuple(create_value_hooks) + if callable(create_value_hooks): + create_value_hooks = (create_value_hooks,) else: - create_value_hooks = () + create_value_hooks = tuple(create_value_hooks) - if add_axis_hooks: - if callable(add_axis_hooks): - add_axis_hooks = (add_axis_hooks,) - else: - add_axis_hooks = tuple(add_axis_hooks) + if callable(add_axis_hooks): + add_axis_hooks = (add_axis_hooks,) else: - add_axis_hooks = () + add_axis_hooks = tuple(add_axis_hooks) - if remove_axis_hooks: - if callable(remove_axis_hooks): - remove_axis_hooks = (remove_axis_hooks,) - else: - remove_axis_hooks = tuple(remove_axis_hooks) + if callable(remove_axis_hooks): + remove_axis_hooks = (remove_axis_hooks,) else: - remove_axis_hooks = () + remove_axis_hooks = tuple(remove_axis_hooks) if isinstance(value, VariableMetadata): value_metadata = dict(value.metadata) - if set_value_hooks and value.set_value_hooks: + if value.set_value_hooks: set_value_hooks = set_value_hooks + value.set_value_hooks - elif value.set_value_hooks: - set_value_hooks = value.set_value_hooks - if get_value_hooks and value.get_value_hooks: + if value.get_value_hooks: get_value_hooks = get_value_hooks + value.get_value_hooks - elif value.get_value_hooks: - get_value_hooks = value.get_value_hooks - if create_value_hooks and value.create_value_hooks: + if value.create_value_hooks: create_value_hooks = create_value_hooks + value.create_value_hooks - elif value.create_value_hooks: - create_value_hooks = value.create_value_hooks - if add_axis_hooks and value.add_axis_hooks: + if value.add_axis_hooks: add_axis_hooks = add_axis_hooks + value.add_axis_hooks - elif value.add_axis_hooks: - add_axis_hooks = value.add_axis_hooks - if remove_axis_hooks and value.remove_axis_hooks: + if value.remove_axis_hooks: remove_axis_hooks = remove_axis_hooks + value.remove_axis_hooks - elif value.remove_axis_hooks: - remove_axis_hooks = value.remove_axis_hooks metadata.update(value_metadata) value = tp.cast(A, value.raw_value) @@ -259,7 +235,7 @@ def __setattr__(self, name: str, value: Any) -> None: def _setattr(self, name: str, value: tp.Any): if not self._trace_state.is_valid(): - raise nnx.errors.TraceContextError( + raise errors.TraceContextError( f'Cannot mutate {type(self).__name__} from a different trace level' ) @@ -318,13 +294,13 @@ def create_value(self, value: A): value = hook(self, value) return value - def add_axis(self, axis_name: AxisName, axis_index: AxisIndex): + def add_axis(self, axis_index: AxisIndex, axis_name: AxisName | None): for hook in self.add_axis_hooks: - hook(self, axis_name, axis_index) + hook(self, axis_index, axis_name) - def remove_axis(self, axis_name: AxisName, axis_index: AxisIndex): + def remove_axis(self, axis_index: AxisIndex, axis_name: AxisName | None): for hook in self.remove_axis_hooks: - hook(self, axis_name, axis_index) + hook(self, axis_index, axis_name) def __eq__(self, other: object) -> bool: return type(self) is type(other) and vars(other) == vars(self) @@ -418,11 +394,11 @@ def on_set_value(self, value: A) -> A: ... def on_create_value(self, value: A) -> A: ... def on_add_axis( - self: V, axis_name: AxisName, axis_index: AxisIndex + self: V, axis_index: AxisIndex, axis_name: AxisName | None ) -> V: ... 
def on_remove_axis( - self: V, axis_name: AxisName, axis_index: AxisIndex + self: V, axis_index: AxisIndex, axis_name: AxisName | None ) -> V: ... def __jax_array__(self): @@ -870,17 +846,13 @@ def get_metadata(self) -> dict[str, tp.Any]: del metadata['value'] return metadata - def add_axis(self, axis_name: AxisName, axis_index: AxisIndex): - if not hasattr(self, 'add_axis_hooks'): - raise ValueError(f'No add_axis_hooks found for VariableState: {self}') + def add_axis(self, axis_index: AxisIndex, axis_name: AxisName | None): for hook in self.add_axis_hooks: - hook(self, axis_name, axis_index) + hook(self, axis_index, axis_name) - def remove_axis(self, axis_name: AxisName, axis_index: AxisIndex): - if not hasattr(self, 'remove_axis_hooks'): - raise ValueError(f'No remove_axis_hooks found for VariableState: {self}') + def remove_axis(self, axis_index: AxisIndex, axis_name: AxisName | None): for hook in self.remove_axis_hooks: - hook(self, axis_name, axis_index) + hook(self, axis_index, axis_name) def _variable_state_flatten(x: VariableState[tp.Any], *, with_keys: bool): diff --git a/tests/nnx/bridge/wrappers_test.py b/tests/nnx/bridge/wrappers_test.py index 72d42eb6d4..27f2927fd9 100644 --- a/tests/nnx/bridge/wrappers_test.py +++ b/tests/nnx/bridge/wrappers_test.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os +os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=4' from absl.testing import absltest import flax @@ -24,6 +26,12 @@ class TestCompatibility(absltest.TestCase): + def setUp(self): + super().setUp() + dim1 = max(jax.device_count() // 2, 1) + device_mesh = np.array(jax.devices()).reshape(dim1, jax.device_count() // dim1) + self.mesh = jax.sharding.Mesh(devices=device_mesh, axis_names=('in', 'out')) + def test_functional(self): # Functional API for NNX Modules functional = bridge.functional(nnx.Linear)(32, 64) @@ -135,21 +143,35 @@ def vmap_fn(inner, x): def test_linen_to_nnx_metadata(self): linen_module = nn.Dense( features=64, - kernel_init=nn.with_partitioning(nn.initializers.lecun_normal(), ('in', 'out'))) + kernel_init=nn.with_partitioning(nn.initializers.lecun_normal(), ('in', 'out')), + bias_init=nn.with_logical_partitioning(nn.initializers.zeros_init(), ('out-alias',), + rules=(('out-alias', 'out'),)), + ) x = jax.numpy.ones((1, 32)) linen_vars = linen_module.init(jax.random.key(0), x) - nnx_model = bridge.ToNNX(linen_module, rngs=nnx.Rngs(0)).lazy_init(x) - # nn.Partitioned metadata box is translated into a valid nnx.Variable / VariableState box. + + @nnx.jit + def create_sharded_nnx_module(x): + model = bridge.lazy_init(bridge.ToNNX(linen_module, rngs=nnx.Rngs(0)), x) + state = nnx.state(model) + sharded_state = nnx.with_sharding_constraint(state, nnx.get_partition_spec(state)) + nnx.update(model, sharded_state) + return model + with self.mesh: + nnx_model = create_sharded_nnx_module(x) + + # nn.Partitioned metadata boxes translated into valid nnx.Variable boxes. 
self.assertIsInstance(linen_vars['params']['kernel'], nn.Partitioned) + self.assertIsInstance(linen_vars['params']['bias'], nn.LogicallyPartitioned) self.assertIsInstance(nnx_model.params['kernel'], nnx.Variable) - np.testing.assert_array_equal(linen_vars['params']['kernel'].value, - nnx_model.params['kernel'].value) assert nnx_model.params['kernel'].sharding == ('in', 'out') - _, nnx_state = nnx.split(nnx_model) - self.assertIsInstance(nnx_state['params']['kernel'], nnx.VariableState) - np.testing.assert_array_equal(linen_vars['params']['kernel'].value, - nnx_state['params']['kernel'].value) - assert nnx_state['params']['kernel'].sharding == ('in', 'out') + assert nnx_model.params['kernel'].value.sharding.is_equivalent_to( + jax.sharding.NamedSharding(self.mesh, jax.sharding.PartitionSpec('in', 'out')), ndim=2) + + assert nnx_model.params['bias'].sharding == ('out-alias',) + assert nnx_model.params['bias'].sharding_rules == (('out-alias', 'out'),) + assert nnx_model.params['bias'].value.sharding.is_equivalent_to( + jax.sharding.NamedSharding(self.mesh, jax.sharding.PartitionSpec('out',)), ndim=1) ################## @@ -306,7 +328,9 @@ class LinenMiddle(nn.Module): @nn.compact def __call__(self, x): dot = bridge.to_linen(NNXInner, x.shape[-1], self.dout, self.dropout_rate, name='dot') - b = self.param('b', nn.initializers.lecun_normal(), (1, self.dout)) + logical_init = nn.with_logical_partitioning( + nn.initializers.lecun_normal(), ('out-alias',), rules=(('out-alias', 'out'))) + b = self.param('b', logical_init, (1, self.dout)) return dot(x) + b class NNXOuter(nnx.Module): @@ -335,6 +359,7 @@ def __call__(self, x): self.assertIsInstance(w, nnx.Param) np.testing.assert_allclose(model(x), x @ w + b) assert hasattr(w, 'sharding') and w.sharding == ('in', 'out') + assert hasattr(b, 'sharding') and b.sharding == ('out-alias', ) def test_linen_nnx_linen(self): # TODO: add when we can safely `lazy_init` the NNX module inside `ToLinen` without diff --git a/tests/nnx/module_test.py b/tests/nnx/module_test.py index d5aeae08cd..a3f7bf8c22 100644 --- a/tests/nnx/module_test.py +++ b/tests/nnx/module_test.py @@ -17,7 +17,7 @@ from typing import Any, TypeVar from absl.testing import absltest -from flax import nnx +from flax import nnx, errors import jax import jax.numpy as jnp import numpy as np @@ -39,7 +39,7 @@ def test_trace_level(self): @jax.jit def f(): with self.assertRaisesRegex( - nnx.errors.TraceContextError, + errors.TraceContextError, "Cannot mutate 'Dict' from different trace level", ): m.a = 2 diff --git a/tests/nnx/rngs_test.py b/tests/nnx/rngs_test.py index 0e42918264..eeb65ccaed 100644 --- a/tests/nnx/rngs_test.py +++ b/tests/nnx/rngs_test.py @@ -21,6 +21,7 @@ from absl.testing import absltest from flax import nnx +from flax import errors class TestRngs(absltest.TestCase): @@ -58,7 +59,7 @@ def test_rng_trace_level_constraints(self): @jax.jit def f(): with self.assertRaisesRegex( - nnx.errors.TraceContextError, + errors.TraceContextError, 'Cannot call RngStream from a different trace level', ): rngs.params() @@ -76,7 +77,7 @@ def h(): self.assertIsInstance(rngs1, nnx.Rngs) with self.assertRaisesRegex( - nnx.errors.TraceContextError, + errors.TraceContextError, 'Cannot call RngStream from a different trace level', ): rngs1.params() diff --git a/tests/nnx/spmd_test.py b/tests/nnx/spmd_test.py index 15808e0800..6a202e8135 100644 --- a/tests/nnx/spmd_test.py +++ b/tests/nnx/spmd_test.py @@ -100,6 +100,64 @@ def __call__(self, x): assert state_spec.opt_state[0].mu['w'].value == 
PartitionSpec('row', 'col') assert state_spec.opt_state[0].nu['w'].value == PartitionSpec('row', 'col') + def test_add_remove_axis_in_transform(self): + test = self + kadds, kremoves, badds, bremoves = [], [], [], [] + class MLP(nnx.Module): + + @nnx.split_rngs(splits=5) + @nnx.vmap( + in_axes=(0, 0), + transform_metadata={nnx.PARTITION_NAME: 'layers'}, + ) + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear( + 3, + 3, + kernel_init=nnx.with_metadata( + nnx.initializers.lecun_normal(), sharding=('din', 'dout'), + add_axis_hooks=lambda _, idx, name: kadds.append((idx, name)), + remove_axis_hooks=lambda _, idx, name: kremoves.append((idx, name)), + ), + bias_init=nnx.with_metadata( + nnx.initializers.zeros_init(), # no sharding annotation here! + add_axis_hooks=lambda _, idx, name: badds.append((idx, name)), + remove_axis_hooks=lambda _, idx, name: bremoves.append((idx, name)), + ), + rngs=rngs, + ) + + @nnx.scan( + in_axes=(0, nnx.Carry), + transform_metadata={nnx.PARTITION_NAME: 'layers'} + ) + def __call__(self, x: jax.Array): + x = self.linear(x) + # test sharding layer axes is not present inside scan + test.assertEqual(self.linear.kernel.shape, (3, 3)) + test.assertEqual(self.linear.kernel.sharding, ('din', 'dout')) + # at least a remove_axis was already called to remove the layer axis + test.assertEqual(kremoves[-1], (0, 'layers')) + test.assertEqual(bremoves[-1], (0, 'layers')) + return x, None + + m = MLP(rngs=nnx.Rngs(0)) + self.assertEqual(m.linear.kernel.shape, (5, 3, 3)) + self.assertEqual(m.linear.kernel.sharding, ('layers', 'din', 'dout')) + self.assertEqual(m.linear.bias.shape, (5, 3)) + # One add_axis called to add the `nnx.vmap` dimension + self.assertEqual(kadds, [(0, 'layers')]) + self.assertEqual(kremoves, []) + self.assertEqual(badds, [(0, 'layers')]) + self.assertEqual(bremoves, []) + + # One remove_axis and one add_axis called when in and out of `nnx.scan` + y = m(jnp.ones((5, 3))) + self.assertEqual(kadds, [(0, 'layers'), (0, 'layers')]) + self.assertEqual(kremoves, [(0, 'layers')]) + self.assertEqual(badds, [(0, 'layers'), (0, 'layers')]) + self.assertEqual(bremoves, [(0, 'layers')]) + if __name__ == '__main__': absltest.main() diff --git a/tests/nnx/transforms_test.py b/tests/nnx/transforms_test.py index be487628fe..824e7b6b0e 100644 --- a/tests/nnx/transforms_test.py +++ b/tests/nnx/transforms_test.py @@ -323,7 +323,7 @@ def f(m: Foo): def test_apply_shardings(self): n_devices = max(jax.local_device_count() // 2, 1) - devices = mesh_utils.create_device_mesh((n_devices, n_devices)) + devices = mesh_utils.create_device_mesh((n_devices, jax.local_device_count() // n_devices)) mesh = jax.sharding.Mesh(devices, ('a', 'b')) def sharding(*args): @@ -2235,6 +2235,27 @@ def forward(model, x): self.assertEqual(y.shape, (5, 4, 3)) + def test_metadata(self): + @nnx.vmap( + in_axes=(None,), + out_axes=0, + axis_size=5, + transform_metadata={nnx.spmd.PARTITION_NAME: 'c'}, + ) + def create_block(rngs: nnx.Rngs): + return nnx.Linear( + 16, + 32, + rngs=rngs, + kernel_init=nnx.with_partitioning( + nnx.initializers.lecun_normal(), ('a', 'b') + ), + ) + + m = create_block(nnx.Rngs(0)) + self.assertEqual(m.kernel.value.shape, (5, 16, 32)) + self.assertEqual(m.kernel.sharding, ('c', 'a', 'b')) + class TestPmap(absltest.TestCase):
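
Illustration (not part of the patch itself): a minimal sketch of the metadata round-trip that the `to_nnx_metadata` / `from_nnx_metadata` helpers added above enable, assuming a Flax build with this change applied; the variable names below are hypothetical and only the two new methods are exercised.

# Sketch of the Linen <-> NNX metadata translation introduced in flax/core/meta.py.
import jax.numpy as jnp
import flax.linen as nn

# Box a value the Linen way; to_nnx_metadata() renames 'names' to 'sharding',
# which is the key nnx.Variable expects.
boxed = nn.Partitioned(jnp.zeros((4, 8)), names=('in', 'out'))
md = boxed.to_nnx_metadata()
assert md['sharding'] == ('in', 'out')

# Rebuild the Linen metadata box from nnx-style metadata; 'sharding' maps back
# to 'names' and unknown keys are dropped via the dataclass-field filter.
restored = nn.Partitioned.from_nnx_metadata(
    {'value': jnp.zeros((4, 8)), 'sharding': ('in', 'out')})
assert restored.names == ('in', 'out')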