why are loops? #13970
@krahnikblis I think you're conflating device-compatible code with Python code that calls device-compatible (e.g. JIT'd) code. It's a common idiom to use control flow primitives like …

Another common source of confusion: under JAX's tracing functionality, the important values for JAX are all wrapped in …

I think your confusion stems from a misunderstanding of which values are tracked to create the code that runs on device, and which things are executed when JAX traces a program.

Here's my recommendation, based on a cursory understanding of what you want to do: don't write the entire training loop in device-compatible code. Figure out the largest device-compatible fragment, write that in JAX-compatible code, then JIT it and call it from your training loop. Try and place any …
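A minimal sketch of that recommendation, assuming a toy linear-regression problem and plain gradient descent (the data, `loss_fn`, `train_step`, and `LR` are illustrative, not from the original answer). Only `train_step` is JIT-compiled; the outer `for` loop is ordinary host-side Python that calls the compiled step repeatedly.

```python
import jax
import jax.numpy as jnp

LR = 0.1  # illustrative learning rate

def loss_fn(w, x, y):
    # Mean squared error of a linear model.
    return jnp.mean((x @ w - y) ** 2)

@jax.jit
def train_step(w, x, y):
    # The device-compatible fragment: one gradient-descent update.
    loss, grad = jax.value_and_grad(loss_fn)(w, x, y)
    return w - LR * grad, loss

# Toy data standing in for a real dataset.
x = jax.random.normal(jax.random.PRNGKey(0), (128, 3))
true_w = jnp.array([1.0, -2.0, 0.5])
y = x @ true_w

w = jnp.zeros(3)
losses = []
for step in range(200):  # ordinary host-side Python loop
    w, loss = train_step(w, x, y)
    losses.append(float(loss))  # float() pulls the value back to the host
```

Because `train_step` is traced and compiled once and then reused, the Python loop only adds per-call dispatch overhead; it is not unrolled into the compiled program.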
I'm struggling a bit with the ideology and implementation of JAX/Flax/related libraries, especially around loops. JAX says "loops are bad" and "scan when you can", and my experience concurs: really long compile times (if it compiles at all; things will compile after up to 25 minutes, but after 5 minutes on TPU I cancel/restart/rewrite) and potentially slow run times.
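One common cause of long compile times in general (not necessarily the cause here) can be sketched as follows, assuming the loop sits inside a traced or jitted function: a Python `for` loop is unrolled during tracing, so the traced program (jaxpr) grows with the iteration count, while `lax.fori_loop` keeps a single copy of the body. The function names are hypothetical; `jax.make_jaxpr` is used only to count the traced operations.

```python
import jax
import jax.numpy as jnp
from jax import lax

def python_loop(x, n):
    # Under tracing, this Python loop is unrolled: n copies of the body.
    for _ in range(n):
        x = jnp.sin(x) + 1.0
    return x

def fori(x, n):
    # lax.fori_loop keeps one copy of the body regardless of n.
    return lax.fori_loop(0, n, lambda i, carry: jnp.sin(carry) + 1.0, x)

def num_eqns(f, n):
    # Number of top-level operations in the traced program (jaxpr).
    jaxpr = jax.make_jaxpr(lambda x: f(x, n))(jnp.ones(3))
    return len(jaxpr.jaxpr.eqns)

for n in (10, 100):
    print(n, num_eqns(python_loop, n), num_eqns(fori, n))
```

The unrolled version's operation count grows roughly linearly with `n`, and XLA compile time tends to grow with program size; the `fori_loop` version stays the same size.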
BUT, if we should all "think in JAX" and use the scan and fori_loop functions that are optimized for XLA, then why does Google itself use loops in the source code and in examples?
E.g., just about every training example I've seen in the documentation (e.g., Flax, JAX, Optax, Haiku) or in these discussion/issue forums uses an actual loop, not a scan/fori_loop implementation. I can't even find an example of a training function that uses scan or fori_loop, to know how to do it "right".
E.g., in tree_map, which I thought was meant to vectorize a function over a tree, the source code itself uses a loop over the items in the dict (literally "for xs in zip(*all_leaves)").
E.g., even the while_loop source code, which touts itself as different from (and better than) a Python loop, contains many Python loops!
What gives?
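For the request about a training function written with `scan`, here is a minimal sketch, assuming a toy linear-regression model, minibatches pre-stacked along a leading axis, and plain gradient descent (all names and values are illustrative). `lax.scan` carries the parameters through one update per minibatch inside a single compiled program.

```python
import jax
import jax.numpy as jnp
from jax import lax

LR = 0.1  # illustrative learning rate

def loss_fn(w, x, y):
    # Mean squared error of a linear model.
    return jnp.mean((x @ w - y) ** 2)

def step(w, batch):
    # scan body: (carry, per-step input) -> (new carry, per-step output)
    x, y = batch
    loss, grad = jax.value_and_grad(loss_fn)(w, x, y)
    return w - LR * grad, loss

@jax.jit
def train_epoch(w, xs, ys):
    # One compiled program that runs an update for every minibatch in xs/ys.
    return lax.scan(step, w, (xs, ys))

# Toy data: 50 minibatches of 32 examples with 3 features each.
xs = jax.random.normal(jax.random.PRNGKey(0), (50, 32, 3))
true_w = jnp.array([1.0, -2.0, 0.5])
ys = xs @ true_w  # shape (50, 32)

w, losses = train_epoch(jnp.zeros(3), xs, ys)  # losses has shape (50,)
```

An outer Python loop over epochs that calls `train_epoch` is still perfectly reasonable; the scan only moves the inner per-batch loop onto the device. Stacking all batches up front assumes the data fits in device memory.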
I'm trying to figure out how best to convert ugly PyTorch concepts to pure JAX concepts; training programs are giving me a lot of grief because of their complexity. I thought the code wasn't compiling/running well because of the loops, but after rewriting everything with fori_loop, and then again with scan, nothing changed: the thing still won't compile and run quickly. The only other thing I wonder about is that I'm using a lot of tree_map calls in the layers of a model (a meta-model that does a bunch of tree-mapped operations on the large params of a main model). If those are really being done as sequential loops and not as a truly vectorized "tree map", as the function name would imply, then perhaps that's where I need to focus... but then how do you vectorize pytree ops if the vectorization function itself doesn't vectorize?
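A sketch of both points, with hypothetical names (`scale`, `layer_update`, `trace_count`) and toy shapes. Under `jax.jit`, the Python loop inside `jax.tree_util.tree_map` runs only while JAX traces the function, emitting one set of operations per leaf into a single compiled program; a Python side-effect counter makes this visible. If many leaves share a shape (an assumption about the meta-model's structure), one alternative is to stack them and use `jax.vmap`, which produces a single batched copy of the operations.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter, used only to observe tracing

def scale(leaf):
    global trace_count
    trace_count += 1  # Python side effect: runs during tracing, not on device
    return leaf * 2.0

@jax.jit
def scale_params(params):
    # tree_map's Python loop over leaves runs once, at trace time.
    return jax.tree_util.tree_map(scale, params)

params = {"w": jnp.ones((4, 4)), "b": jnp.zeros(4), "extra": {"g": jnp.ones(2)}}
scale_params(params)
scale_params(params)  # same structure and shapes: cached, no retrace
print(trace_count)    # 3: one call per leaf, during the single trace

def layer_update(w):
    # Hypothetical per-layer transform a meta-model might apply.
    return w - 0.01 * jnp.tanh(w @ w.T) @ w

keys = jax.random.split(jax.random.PRNGKey(0), 8)
layers = [jax.random.normal(k, (16, 16)) for k in keys]  # 8 same-shaped leaves

# Per-leaf: tree_map emits 8 copies of the ops into one compiled program.
per_leaf = jax.jit(lambda ls: jax.tree_util.tree_map(layer_update, ls))(layers)

# Batched: stack into one (8, 16, 16) array and vmap a single copy of the ops.
batched = jax.jit(jax.vmap(layer_update))(jnp.stack(layers))
```

Either way there is no per-leaf Python overhead at run time once the function is compiled; the vmap form mainly keeps the traced program small when there are many same-shaped leaves.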
I guess I'll eventually figure it out, but to help guide my problem-solving (and my sanity) I'd really like to know: if loops are bad, AND the solution is to use special JAX alternatives, why is Google still using loops everywhere, including within those special alternatives?
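A sketch (hypothetical names, toy computation) of why Python loops inside JAX's own library code need not be a contradiction: Python code that runs while JAX traces a function only builds the program. Below, `lax.while_loop` calls the Python `body` function a small, fixed number of times while tracing, even though the compiled loop runs 1000 iterations on the device. Presumably the Python loops in the `while_loop` source are of this trace-time kind (e.g. iterating over pytree leaves or arguments), not over loop iterations.

```python
import jax
import jax.numpy as jnp
from jax import lax

body_calls = 0  # hypothetical counter: counts Python calls of `body`

def cond(state):
    i, _ = state
    return i < 1000

def body(state):
    global body_calls
    body_calls += 1  # runs while JAX traces `body`, not once per iteration
    i, x = state
    return i + 1, x + 1.0

@jax.jit
def run(x):
    # Strongly typed int32 counter so the carry's types stay fixed.
    return lax.while_loop(cond, body, (jnp.int32(0), x))

i, x = run(jnp.float32(0.0))
print(int(i), float(x), body_calls)  # 1000 iterations, but body_calls stays small
```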