Get index in flax linen scan #3135

Answered by jheek
jakubMitura14 asked this question in Q&A

I think you have a few options:

  1. Unroll (part of) the loop.
  2. Use a switch statement (`jax.lax.switch`).
  3. Use `jit` with `static_argnums` for the iteration count and a normal Python loop to do the iteration.
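A minimal sketch of option 2, assuming the per-step behavior can be written as a fixed list of branches. It uses plain `jax.lax.scan` instead of `flax.linen.scan` (which is built on top of it) so it runs without Flax; `branch_add`, `branch_double`, `branch_square`, `step` and `run` are hypothetical names. The index is obtained by scanning over `jnp.arange(n)`, so each step receives its index as input. Because that index is a tracer, a Python `if` on it would fail, and `jax.lax.switch` picks the branch instead:

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical per-step behaviors: step i applies BRANCHES[i].
def branch_add(x):
    return x + 1.0

def branch_double(x):
    return x * 2.0

def branch_square(x):
    return x * x

BRANCHES = [branch_add, branch_double, branch_square]

def step(carry, i):
    # `i` is a traced int32 scalar here, so `if i == 0:` would raise an error.
    # lax.switch selects the branch from a traced index (out-of-range indices are clamped).
    new_carry = lax.switch(i, BRANCHES, carry)
    return new_carry, new_carry

@jax.jit
def run(x0):
    # Fix the carry dtype so every branch returns the same type.
    x0 = jnp.asarray(x0, dtype=jnp.float32)
    # Scan over the indices 0..len(BRANCHES)-1 so each step sees its own index.
    idx = jnp.arange(len(BRANCHES))
    final, history = lax.scan(step, x0, idx)
    return final, history
```

Here `run(3.0)` applies +1, then ×2, then squaring, so the final carry is 64.0 and the stacked history is [4.0, 8.0, 64.0].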

Which option is fastest depends on the hardware and on the specific ops and shapes you are using. On GPUs, the first and last options will almost certainly be the fastest.
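The first and last options can be sketched together, assuming the number of steps is known when the function is traced. Marking `num_steps` as static via `static_argnums` makes the loop index a plain Python `int`, so the Python loop is fully unrolled during tracing and ordinary `if` statements on the index work. `apply_step` and `run_unrolled` are hypothetical names:

```python
from functools import partial

import jax

def apply_step(x, i):
    # `i` is a plain Python int here, so normal Python control flow is allowed.
    if i % 2 == 0:
        return x + 1.0
    return x * 2.0

@partial(jax.jit, static_argnums=1)
def run_unrolled(x, num_steps):
    # num_steps is static: this loop is unrolled at trace time, and each
    # distinct num_steps value triggers a separate compilation.
    for i in range(num_steps):
        x = apply_step(x, i)
    return x
```

With this sketch, `run_unrolled(1.0, 4)` computes ((1 + 1) × 2 + 1) × 2 = 10.0. The trade-off is that every new `num_steps` value recompiles, and a long fully unrolled loop increases compile time, which is where unrolling only part of the loop can help.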

Replies: 1 comment, 4 replies

4 replies
@jakubMitura14
@jheek
Answer selected by jakubMitura14
@jakubMitura14
@guillaumepourcel
Category: Q&A
Labels: None yet
3 participants