
Performance of lax.cond in JAX vs Python if-else #24110

Answered by ZedongPeng
ZedongPeng asked this question in Q&A

Hi @jakevdp,

Thank you for your suggestion! After running some tests, I found that the following implementation using while_loop significantly outperformed the standard Python while loop, running several times faster.
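
A rough way to measure a comparison like this is sketched here. It assumes a toy loop that multiplies a value by 1.5 until it passes a threshold. The names python_loop, lax_loop, time_it and LIMIT are hypothetical stand-ins, not the code from this thread. The timing helper calls each function once before timing, so the jitted version's compilation is excluded. It also uses block_until_ready() because JAX dispatches work asynchronously. No numbers are claimed here, since results depend on hardware and backend.

```python
import time
import jax
import jax.numpy as jnp

LIMIT = 1e6  # hypothetical threshold; a larger value means more loop iterations

def python_loop(x):
    # Plain Python loop: every comparison and multiply is dispatched eagerly,
    # and `x < LIMIT` is converted to a Python bool on every iteration.
    while x < LIMIT:
        x = x * 1.5
    return x

@jax.jit
def lax_loop(x):
    # cond_fun and body_fun are traced once and compiled into a single XLA loop.
    return jax.lax.while_loop(lambda v: v < LIMIT, lambda v: v * 1.5, x)

def time_it(fn, x, repeats=10):
    fn(x).block_until_ready()  # warm-up call: triggers compilation for jitted fns
    start = time.perf_counter()
    for _ in range(repeats):
        fn(x).block_until_ready()  # wait for the async result before stopping the clock
    return (time.perf_counter() - start) / repeats

x0 = jnp.float32(1.0)
print(f"Python while loop:        {time_it(python_loop, x0):.6f} s/call")
print(f"lax.while_loop under jit: {time_it(lax_loop, x0):.6f} s/call")
```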

As I understand it, when using while_loop, JAX traces and compiles the cond and body functions, so the loop iterations run inside compiled code instead of being dispatched one by one from Python. However, if the while_loop needs to be executed multiple times, it seems important to define it once, outside the outer loop, inside a function wrapped with jit. That way the compiled version is reused on every call instead of being rebuilt each time.
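
One way to see why placement matters is to count how often JAX traces the function. This is a minimal sketch that assumes a hypothetical Python-side counter, trace_count, incremented inside the function being jitted. Python side effects run only while JAX traces, so the counter records traces (and hence compilations), not executions. The helper make_step and the counting loop inside it are also hypothetical.

```python
import jax
import jax.numpy as jnp

trace_count = {"once": 0, "per_iter": 0}  # hypothetical counters for illustration

def make_step(key):
    def step(x):
        trace_count[key] += 1  # runs only while tracing, not on cached calls
        return jax.lax.while_loop(lambda v: v < 100.0, lambda v: v + 1.0, x)
    return step

x = jnp.float32(0.0)

# Defined once, outside the outer loop: traced on the first call, then cached.
f = jax.jit(make_step("once"))
for _ in range(5):
    f(x)

# A fresh function object each iteration misses the cache and is retraced.
for _ in range(5):
    jax.jit(make_step("per_iter"))(x)

print(trace_count)
```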

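import jax
import jax.numpy as jnp
from jax import jit

@jit
def f(x):
    def cond(state):
        i, _ = state
        return i < 10  # placeholder stopping condition
    def body(state):
        i, val = state
        return i + 1, val * 2.0  # placeholder update
    _, result = jax.lax.while_loop(cond, body, (0, x))
    return result

x = jnp.float32(1.0)
for i in range(5):
    f(x)  # traced and compiled on the first call, cached afterwards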
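
To check that the loop reaches XLA as a single operation rather than being unrolled in Python, you can print the traced program with jax.make_jaxpr. This is a minimal sketch; the function count_to and its counting loop are hypothetical stand-ins for a real cond/body pair.

```python
import jax
import jax.numpy as jnp

def count_to(x):
    # Hypothetical loop: add 1 until the value reaches 10.
    return jax.lax.while_loop(lambda v: v < 10.0, lambda v: v + 1.0, x)

jaxpr = jax.make_jaxpr(count_to)(jnp.float32(0.0))
# The loop appears as one `while` primitive with nested cond and body jaxprs.
print(jaxpr)
```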

Replies: 2 comments, 5 replies (thread participants: @jakevdp and @ZedongPeng). Answer selected by ZedongPeng.
Category: Q&A | Labels: None yet | 2 participants