From 68c47ccb9aef16be15b9dd178c755a40227f1146 Mon Sep 17 00:00:00 2001 From: Chengxi Ye Date: Thu, 9 Jun 2022 16:41:31 -0700 Subject: [PATCH] Update normalization.py rank --> number of dimensions --- flax/core/nn/normalization.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flax/core/nn/normalization.py b/flax/core/nn/normalization.py index b044d24924..01c820bfcd 100644 --- a/flax/core/nn/normalization.py +++ b/flax/core/nn/normalization.py @@ -20,8 +20,8 @@ import jax.numpy as jnp -def _absolute_dims(rank, dims): - return tuple(rank + dim if dim < 0 else dim for dim in dims) +def _absolute_dims(ndim, dims): + return tuple(ndim + dim if dim < 0 else dim for dim in dims) def batch_norm(scope: Scope,