Optimizing XLA Compilation when using a tuple of functions as an argument in a scan function #24283
Unanswered
christophedessers asked this question in Q&A
Replies: 0 comments
Hello,
I’m encountering a significant increase in XLA compilation time when passing a tuple of functions as an argument to a `scan` function that uses a switch to select which function to apply. Specifically, the issue arises when I scale the system by duplicating these functions. If I understand correctly, creating N functions leads JAX to compile N functions. However, in my case I’m only working with two distinct functions (`fct1`, `fct2`), both of which are partials with a single value as a "curried" parameter. Therefore, although I replicate these functions many times, I hoped that JAX would see that there are actually only two unique functions, not N.

Here is the simplified code:
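A minimal sketch of the setup described above. The bodies of `fct1` and `fct2`, the curried values, and the helper names `make_branches` and `run` are assumptions made for illustration only:

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax


# Hypothetical bodies: the post only states that each function is a
# partial with a single "curried" value.
def fct1(scale, x):
    return x * scale


def fct2(offset, x):
    return x + offset


def make_branches(n_replicas):
    # n_replicas copies of the same two partials, as one flat tuple.
    pair = (partial(fct1, 2.0), partial(fct2, 1.0))
    return pair * n_replicas


def run(branches, indices, x0):
    def step(carry, idx):
        # Every entry of the tuple is handed to lax.switch as its own branch.
        new_carry = lax.switch(idx, branches, carry)
        return new_carry, new_carry

    return lax.scan(step, x0, indices)


branches = make_branches(3)  # 6 branches, but only 2 distinct behaviours
indices = jnp.array([0, 1, 2, 3], dtype=jnp.int32)
final, history = jax.jit(partial(run, branches))(indices, jnp.float32(1.0))
```

Here the tuple is closed over with `partial` before `jax.jit` is applied, so the functions are static with respect to tracing while `indices` and `x0` are traced arrays.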
Here’s when the problem arises: as the number of function duplicates increases, the XLA compilation time for the `scan` function grows rapidly (for example, with 5000 replicas, compilation takes about 2 minutes). I suspect JAX doesn't recognize that there are only 2 unique functions.

My questions:
Here is the output of the `xla_dump_to` flag, containing the pre-optimization HLO files:
dump_Simple_test.zip
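To isolate the compilation time mentioned above from tracing and execution, one option is JAX's ahead-of-time API (`lower` followed by `compile`). The sketch below is only an assumption about how the measurement could be set up: `time_compile` is a hypothetical helper, and `jnp.multiply` / `jnp.add` stand in for `fct1` / `fct2`.

```python
import time
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax


def time_compile(n_replicas):
    # Hypothetical stand-ins for fct1 and fct2, replicated n_replicas times.
    branches = (partial(jnp.multiply, 2.0), partial(jnp.add, 1.0)) * n_replicas

    def run(indices, x0):
        def step(carry, idx):
            out = lax.switch(idx, branches, carry)
            return out, out

        return lax.scan(step, x0, indices)

    indices = jnp.zeros(8, dtype=jnp.int32)
    # Tracing and lowering happen here, before the timer starts.
    lowered = jax.jit(run).lower(indices, jnp.float32(1.0))
    start = time.perf_counter()
    compiled = lowered.compile()  # XLA compilation only
    elapsed = time.perf_counter() - start
    return compiled, elapsed


for n in (1, 10, 50):
    _, seconds = time_compile(n)
    print(f"{2 * n} branches: compiled in {seconds:.3f} s")
```

The absolute numbers depend on the backend and JAX version; the point is only to see how the time scales with the number of branches.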
I’m new to JAX, so I might be missing some key optimization strategies. Any insights or suggestions would be greatly appreciated!
Thank you for your time.