Jitted function is fast to complete computations, but slow to return #17876
Hi, first, apologies in advance for not providing a real code example. I have spent a long time trying to reproduce my issue in a short snippet of code, but so far no short example has shown this behaviour. I will keep the question general and hope the behaviour can be explained from the details I provide. Here is some simple pseudocode to illustrate the issue I am having:
I am experiencing some strange behaviour where my jitted function runs very quickly right up until the …

Within this jitted function, I am using the lineax library, which is a nice library for solving linear systems such as …

Finally, one detail that has particularly confused me is what happens if I add one final operation to the end of my jitted function: …

I have spent a lot of time trying to understand why this happens, but I'm struggling to see how to even begin approaching it. When I jit the function that all of this code lives within, I experience the 2-minute delay on returning from the function one level out. If I don't jit my function, the 2-minute lag on return goes away, but instead solving Ax = b itself takes over 2 minutes.

I hope my explanation was clear; please don't hesitate to ask for any other details. I understand this may be a very niche issue, and I may have missed some detail that is crucial to understanding the behaviour. I have not yet been able to reproduce it in a single short snippet of code myself. Thank you for your time.
The reason for this is that the …
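The rest of this reply did not survive extraction. A later message in this thread attributes the first delay to the tracing pass. The sketch below illustrates that idea under one assumption: the timings were taken with Python-level timers inside the jitted function. Python code inside a `jax.jit` function runs only while JAX traces it, so it measures how long it takes to build the computation, not how long it takes to run it. Compilation and execution happen after tracing finishes, so their cost shows up when the call returns. The names `solve_and_log` and `trace_log` are hypothetical, and `jnp.linalg.solve` stands in for the lineax solve.

```python
import time

import jax
import jax.numpy as jnp

# Hypothetical Python-side record; it is appended to only while JAX traces the function.
trace_log = []


@jax.jit
def solve_and_log(A, b):
    x = jnp.linalg.solve(A, b)  # stand-in for the lineax solve in the original post
    # Python code inside a jitted function runs at trace time, once per new
    # input shape/dtype. It does not run (or time) the compiled solve.
    trace_log.append(time.perf_counter())
    return x


A = jnp.eye(3) * 2.0
b = jnp.ones(3)

x1 = solve_and_log(A, b)  # first call: trace, compile, then execute
x2 = solve_and_log(A, b)  # same shapes/dtypes: reuses the compiled code, no retrace
print(len(trace_log))     # a single entry, recorded during the one and only trace
```

Because the second call hits the compilation cache, the Python body does not run again. Any timer placed inside the function therefore says nothing about the compiled solve.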
Thank you @jakevdp. Apologies for the long post; that makes much more sense now. One last point I am not clear on is why the following code:
runs suspiciously fast (2e-4 seconds), where x is replaced with zeros before returning. An obvious explanation is that JAX is smart enough to see that solving for x does not change the output of the function, so it fills the result with zeros right away. However, I checked `make_jaxpr` for both versions of the function, and both include the same step of solving for x. As I understand it, both should take just as long to run. Thanks again for your help.
Thank you @jakevdp for your help on this issue; it has now been resolved. I've run into a potentially related issue, but I can't make sense of it using what you explained above. Let me know if I should open a separate thread, as I'm not sure whether it is related. I have a jitted function … The following code setup gives the expected results:
In the above code, I'm evaluating the … This appears to be similar to the issue I posted above, where the code seems to get "stuck" on the return from a function. However, unlike the previous case, my `timing_function` is not jitted, so the tracing pass can't explain this apparent time difference. Running the … Thanks again for all of your help; it has been really useful for understanding how JAX and jitting work.
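The thread does not include an answer to this last question, so what follows is an assumption rather than the confirmed cause. One JAX behaviour that often produces this pattern is asynchronous dispatch. A call to a jitted function returns as soon as the work has been queued, and the wait only appears later, when something needs the actual values (printing, converting to NumPy, or `block_until_ready()`). In the minimal sketch below, `timing_function` is a hypothetical stand-in for the post's function of the same name, whose real body isn't shown, and `heavy` is a hypothetical jitted solve.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def heavy(A, b):
    # Hypothetical stand-in for the expensive jitted computation.
    return jnp.linalg.solve(A, b)


def timing_function(A, b):
    # Hypothetical, unjitted wrapper that times the call in two stages.
    t0 = time.perf_counter()
    x = heavy(A, b)             # returns once the work is dispatched
    t_dispatch = time.perf_counter() - t0
    x = x.block_until_ready()   # waits until the computation has actually finished
    t_total = time.perf_counter() - t0
    return x, t_dispatch, t_total


n = 500
A = jnp.eye(n) * 2.0 + 0.01     # well-conditioned test matrix
b = jnp.ones(n)

heavy(A, b).block_until_ready()  # warm-up: compile once, outside the timed region
x, t_dispatch, t_total = timing_function(A, b)
print(f"dispatch: {t_dispatch:.2e} s, total: {t_total:.2e} s")
```

If most of the time lands after the dispatch step, the computation itself is slow, and the apparent "stuck on return" is just the first point where the result is needed.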
It looks like the lineax computation doesn’t contribute to the function output, so the compiler eliminates it.
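This also explains why `make_jaxpr` still shows the solve. A jaxpr records everything that was traced, but XLA's optimizer runs later and can remove work whose result never reaches the output. The minimal sketch below compares the traced jaxpr with the optimized, compiled program. Two things here are assumptions, not taken from the thread: `jnp.linalg.solve` stands in for the lineax solve, and checking for `getrf` (the LAPACK/cuSOLVER LU factorization routine) is a heuristic for spotting the solve in the compiled text. The function name `solve_then_discard` is hypothetical.

```python
import jax
import jax.numpy as jnp


def solve_then_discard(A, b):
    x = jnp.linalg.solve(A, b)  # stand-in for the lineax solve
    # The output depends only on x's shape and dtype, not its values,
    # so the solve is dead code as far as the result is concerned.
    return jnp.zeros_like(x)


A = jnp.eye(4) * 3.0
b = jnp.ones(4)

# The jaxpr is a record of what was traced, so the solve still appears here.
jaxpr_text = str(jax.make_jaxpr(solve_then_discard)(A, b))

# The compiled program is what actually runs after XLA's optimization passes.
compiled_text = jax.jit(solve_then_discard).lower(A, b).compile().as_text()

print("solve" in jaxpr_text)
print("getrf" in compiled_text.lower())
```

You would expect the first check to be true and the second false. To time the real solve, return something that depends on x's values, for example x itself.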