A function with default arguments was much faster in JAX. Why? And how does JAX trace its gradient? #18259
-
Hi all, I noticed a performance gain when I used default arguments in a function. This is related to a discussion on Stack Overflow (https://shorturl.at/ahmRS). In the example below, I compared the performance of a function with and without default arguments: the former took 94.3 µs and the latter 220 µs on my computer. So I checked their jaxprs and noticed that XLA optimized
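The example code from the question is not included here. The sketch below is a hypothetical reconstruction: the names `f`, `f2`, the formula, the default values and the array size are all assumptions chosen for illustration. It times a jitted function that takes `alpha` and `beta` explicitly against one that reads them from default arguments. It also shows that `jax.grad` traces through the default-argument version without trouble. By default it differentiates only with respect to the first positional argument, and the defaults are plain Python floats, so they act as constants. Your timings will differ by hardware and JAX version.

```python
import timeit

import jax
import jax.numpy as jnp


# Hypothetical: alpha and beta are explicit arguments, so jit traces them
# and they become runtime inputs of the compiled executable.
def f(x, alpha, beta):
    return jnp.sum(alpha * jnp.sin(x) + beta * x**2)


# Hypothetical: same computation, but alpha and beta come from defaults,
# so they are plain Python floats (constants) when jit traces the function.
def f2(x, alpha=0.5, beta=2.0):
    return jnp.sum(alpha * jnp.sin(x) + beta * x**2)


jitted_f = jax.jit(f)
jitted_f2 = jax.jit(f2)

x = jnp.linspace(0.0, 1.0, 1000)
alpha, beta = 0.5, 2.0

# Call once to compile, so compilation time is not included in the timings.
# block_until_ready() waits for JAX's asynchronous dispatch to finish.
jitted_f(x, alpha, beta).block_until_ready()
jitted_f2(x).block_until_ready()

n = 1000
t_f = timeit.timeit(lambda: jitted_f(x, alpha, beta).block_until_ready(), number=n)
t_f2 = timeit.timeit(lambda: jitted_f2(x).block_until_ready(), number=n)
print(f"jitted_f:  {t_f / n * 1e6:.1f} us per call")
print(f"jitted_f2: {t_f2 / n * 1e6:.1f} us per call")

# jax.grad differentiates w.r.t. the first positional argument (x) by default.
grad_f = jax.grad(f)    # alpha and beta must still be supplied at call time
grad_f2 = jax.grad(f2)  # alpha and beta are the default constants
g = grad_f(x, alpha, beta)
g2 = grad_f2(x)

# A gradient w.r.t. alpha is only possible when alpha is an explicit argument.
d_alpha = jax.grad(f, argnums=1)(x, alpha, beta)
```

If you want the constant specialization while still passing `alpha` and `beta` explicitly, marking them static (for example `jax.jit(f, static_argnums=(1, 2))`) has a similar effect, at the cost of a recompile for every new value.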
-
Thanks for the question!

I would guess that `jitted_f2(x)` can be faster because the compiler gets to specialize on the values of `alpha` and `beta`, treating them as constants rather than taking them as arguments at execution time. By contrast, `jitted_f(x, alpha, beta)` basically means "ask the compiler to make an executable that takes `alpha` and `beta` as arguments, whose values are only specified when it's called".

We could check by looking at the optimized program, as described here.
What do you think?