Skip to content

Commit

Permalink
Update normalization.py
Browse files Browse the repository at this point in the history
rank --> number of dimensions
  • Loading branch information
yechengxi committed Jun 9, 2022
1 parent 19fc095 commit 68c47cc
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions flax/core/nn/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
import jax.numpy as jnp


def _absolute_dims(rank, dims):
return tuple(rank + dim if dim < 0 else dim for dim in dims)
def _absolute_dims(ndim, dims):
return tuple(ndim + dim if dim < 0 else dim for dim in dims)


def batch_norm(scope: Scope,
Expand Down

0 comments on commit 68c47cc

Please sign in to comment.