Skip to content

Commit

Permalink
Merge pull request #2182 from yechengxi:patch-1
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 454138229
  • Loading branch information
Flax Authors committed Jun 10, 2022
2 parents c182c7d + 68c47cc commit 500d349
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions flax/core/nn/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
import jax.numpy as jnp


def _absolute_dims(rank, dims):
return tuple(rank + dim if dim < 0 else dim for dim in dims)
def _absolute_dims(ndim, dims):
return tuple(ndim + dim if dim < 0 else dim for dim in dims)


def batch_norm(scope: Scope,
Expand Down

0 comments on commit 500d349

Please sign in to comment.