diff --git a/flax/core/nn/normalization.py b/flax/core/nn/normalization.py index b044d24924..01c820bfcd 100644 --- a/flax/core/nn/normalization.py +++ b/flax/core/nn/normalization.py @@ -20,8 +20,8 @@ import jax.numpy as jnp -def _absolute_dims(rank, dims): - return tuple(rank + dim if dim < 0 else dim for dim in dims) +def _absolute_dims(ndim, dims): + return tuple(ndim + dim if dim < 0 else dim for dim in dims) def batch_norm(scope: Scope,