diff --git a/flax/core/meta.py b/flax/core/meta.py index 27686a40b5..ca30524145 100644 --- a/flax/core/meta.py +++ b/flax/core/meta.py @@ -58,13 +58,13 @@ class AxisMetadata(Generic[A], metaclass=abc.ABCMeta): def unbox(self) -> A: """Returns the content of the AxisMetadata box. - Note that unlike ``meta.unbox`` the unbox call should recursively unbox + Note that unlike ``meta.unbox`` the unbox call should not recursively unbox metadata. It should simply return value that it wraps directly even if that value itself is an instance of AxisMetadata. In practise, AxisMetadata subclasses should be registered as PyTree nodes to support passing instances to JAX and Flax APIs. The leaves returned for this - note should correspond to the value returned by unbox. + node should correspond to the value returned by unbox. Returns: The unboxed value.