
A function with default arguments was much faster in JAX. How did that happen? How does JAX trace its gradient, too? #18259

Answered by mattjj
ToshiyukiBandai asked this question in Q&A

Thanks for the question!

I would guess that jitted_f2(x) can be faster because the compiler gets to specialize on the values of alpha and beta, treating them as constants rather than as arguments supplied at execution time. By contrast, jitted_f(x, alpha, beta) basically means "ask the compiler to make an executable that takes alpha and beta as arguments, whose values are only specified when it's called".
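
A minimal sketch of the difference, assuming hypothetical bodies for f and f2 (the question's actual code isn't shown here, and the defaults alpha=2.0, beta=3.0 are made up). jax.make_jaxpr shows what jit traces: with f, alpha and beta are traced inputs; with f2 called as f2(x), the Python default floats are baked in as literals. Marking the arguments static with static_argnames is an explicit way to get the same specialization, at the cost of a recompile for every new value.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in: alpha and beta are ordinary (traced) arguments.
def f(x, alpha, beta):
    return alpha * jnp.sin(x) + beta * x ** 2

# Hypothetical stand-in: same computation, alpha and beta are Python defaults.
def f2(x, alpha=2.0, beta=3.0):
    return alpha * jnp.sin(x) + beta * x ** 2

jitted_f = jax.jit(f)
jitted_f2 = jax.jit(f2)

x = jnp.linspace(0.0, 1.0, 1000)

# alpha and beta become jaxpr input variables, known only at call time.
jaxpr_f = jax.make_jaxpr(f)(x, 2.0, 3.0)
# Only x is traced; 2.0 and 3.0 are embedded as constants.
jaxpr_f2 = jax.make_jaxpr(f2)(x)

print(len(jaxpr_f.jaxpr.invars))   # 3
print(len(jaxpr_f2.jaxpr.invars))  # 1
print(jaxpr_f2)

# Explicit alternative: jit specializes on the static values
# (and recompiles whenever alpha or beta changes).
jitted_f_static = jax.jit(f, static_argnames=("alpha", "beta"))

# All three variants compute the same numbers.
assert jnp.allclose(jitted_f(x, 2.0, 3.0), jitted_f2(x))
assert jnp.allclose(jitted_f_static(x, alpha=2.0, beta=3.0), jitted_f2(x))
```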

We could check by looking at the optimized program, as described here.
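
One way to do that, sketched with the same hypothetical f and f2 as stand-ins: lower and compile each function ahead of time with jax.jit(...).lower(...).compile(), then print the optimized HLO with as_text(). In the f version, alpha and beta should show up as entry-computation parameters; in the f2 version, the values should instead appear as constants that XLA is free to fold into the surrounding operations.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins for the functions in the question.
def f(x, alpha, beta):
    return alpha * jnp.sin(x) + beta * x ** 2

def f2(x, alpha=2.0, beta=3.0):
    return alpha * jnp.sin(x) + beta * x ** 2

x = jnp.linspace(0.0, 1.0, 1000)

# Ahead-of-time lowering and compilation for the concrete input shapes.
compiled_f = jax.jit(f).lower(x, 2.0, 3.0).compile()
compiled_f2 = jax.jit(f2).lower(x).compile()

# Optimized HLO text produced by the compiler for each executable.
hlo_f = compiled_f.as_text()
hlo_f2 = compiled_f2.as_text()

print(hlo_f)   # alpha and beta are runtime parameters here
print(hlo_f2)  # expect the default values to appear as constants here
```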

What do you think?

Answer selected by ToshiyukiBandai