nnx.Dropout in combination with pmap returns an error #4190

Closed Answered by cgarciae
maxxxzdn asked this question in Q&A

Hey! I'm guessing you want to replicate the weights but have a different RNG on each device. To do this, use the nnx.split_rngs decorator to split the RNGs before entering pmap, and use StateAxes to specify the parallelization axes for the substates of your Module. In this case, map RngState to 0 and the rest (...) to None:

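import jax.numpy as jnp
from flax import nnx

# Map the RNG state along axis 0 (one key per device); replicate everything else.
state_axes = nnx.StateAxes({nnx.RngState: 0, ...: None})

# `splits` must match the size of the mapped axis (x.shape[0] below).
@nnx.split_rngs(splits=1)
@nnx.pmap(in_axes=(state_axes, 0))
def forward(model, x):
  return model(x)

# `model` is the asker's nnx.Module containing nnx.Dropout.
out = forward(model, jnp.ones((1, 16, 2)))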

For more info, check out the Filters guide.
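To illustrate the underlying idea without NNX, here is a plain-JAX sketch: the parameters are broadcast to every device (`in_axes=None`) while each device receives its own dropout key along axis 0. This is roughly the effect of splitting the RngState and mapping it to axis 0. It uses `jax.pmap` directly, not the NNX API, and the toy layer `dense_dropout`, its parameter shapes, and the dropout rate are hypothetical choices for illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy layer: a dense projection followed by inverted dropout.
def dense_dropout(params, key, x, rate=0.5):
    y = x @ params["w"] + params["b"]
    keep = jax.random.bernoulli(key, 1.0 - rate, y.shape)
    # Scale kept units so the expected activation is unchanged.
    return jnp.where(keep, y / (1.0 - rate), 0.0)

n_devices = jax.local_device_count()

# Shared (replicated) parameters.
params = {
    "w": jax.random.normal(jax.random.PRNGKey(0), (2, 4)),
    "b": jnp.zeros((4,)),
}

# One dropout key per device, analogous to splitting the RNG state.
dropout_keys = jax.random.split(jax.random.PRNGKey(1), n_devices)

# params broadcast (None); keys and batch mapped over axis 0.
p_dense_dropout = jax.pmap(dense_dropout, in_axes=(None, 0, 0))

x = jnp.ones((n_devices, 16, 2))
out = p_dense_dropout(params, dropout_keys, x)
print(out.shape)  # leading axis is n_devices; per-device output is (16, 4)
```

Because each device gets a different key, the dropout masks generally differ across the leading axis, while every device sees the same `params`.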

Answer selected by maxxxzdn