
Why does jit make this function 600 times slower? (bug!) #16934

Closed Answered by jakevdp
RezaRob asked this question in Q&A

It's not clear why jit needs to be ~600x slower when using deep pytrees and accessing Python structures. We're comparing it to raw Python (already the slowest language on the planet!), and yet compiled jit code makes things even slower in some situations.

The compiled JIT execution only happens once all values are placed on the device, and this device placement happens in Python, the slowest language on the planet 😁
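To make the device-placement cost concrete, here is a minimal sketch. It is not from the thread, and it assumes a recent JAX release; `N = 1000` and the variable names are hypothetical. `jax.device_put` maps over a pytree, so a Python list of N scalars becomes N separate device arrays, while a single NumPy array is placed as one buffer.

```python
import numpy as np
import jax

N = 1000  # hypothetical size for illustration

values = list(range(N))  # a Python list of Python ints

# device_put maps over the pytree: every leaf becomes its own device array,
# so N separate host-to-device placements are driven from Python.
per_leaf = jax.device_put(values)

# Packing into one host array first means a single buffer is placed on device.
packed = jax.device_put(np.asarray(values))

print(len(jax.tree_util.tree_leaves(per_leaf)))  # prints 1000
print(len(jax.tree_util.tree_leaves(packed)))    # prints 1
```

The same per-leaf overhead applies to the arguments of a jitted function, which is why passing one array instead of a deep pytree of scalars is usually much cheaper.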

The non-jit version here is "construct a Python list of 10000 Python integers". The JIT version is "construct a Python list of 10000 Python integers, then allocate space on the XLA device for each of them and copy the bits over". There's no possible way the second pr…
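A rough way to see this on the output side is sketched below. It is an illustration under stated assumptions, not code from the thread: it assumes a recent JAX release that provides `jax.block_until_ready`, and `N = 1000`, `as_python_list`, and `as_single_array` are hypothetical names. Returning a list from a jitted function yields one device array per element, whereas returning `jnp.arange(N)` yields a single buffer. Absolute timings depend on hardware, so none are claimed here.

```python
import time

import jax
import jax.numpy as jnp

N = 1000  # hypothetical size; the question used 10000

def as_python_list():
    # N Python ints in a list: under jit this becomes N separate output arrays.
    return [i for i in range(N)]

def as_single_array():
    # The same values as one array: a single output buffer.
    return jnp.arange(N)

jit_list = jax.jit(as_python_list)
jit_array = jax.jit(as_single_array)

def best_time(fn, repeats=5):
    """Best wall-clock time over several calls, waiting for device results."""
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        jax.block_until_ready(fn())
        best = min(best, time.perf_counter() - start)
    return best

# Warm-up calls so that tracing and compilation are excluded from the timings.
list_out = jit_list()
array_out = jit_array()

for name, fn in [("plain Python", as_python_list),
                 ("jit, list output", jit_list),
                 ("jit, array output", jit_array)]:
    print(f"{name:>18}: {best_time(fn) * 1e6:.1f} us")

print(len(jax.tree_util.tree_leaves(list_out)))   # prints 1000
print(len(jax.tree_util.tree_leaves(array_out)))  # prints 1
```

The design point is that jit removes Python overhead from the computation itself, but it cannot remove per-leaf dispatch and transfer; keeping data in a few large arrays rather than many scalar leaves is what lets the compiled path pay off.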

Answer selected by RezaRob