Compared to a Python list, can the 'at' operation on a JAX array take longer? #17283
Hi, thanks for the question. There's a relevant section of the FAQ: "FAQ: Is JAX Faster Than NumPy". Although it focuses on comparing JAX to NumPy, the discussion applies here as well.
In short, JAX has relatively expensive per-operation dispatch costs. This rarely matters in practice, because in typical use you pay that dispatch cost only once per call of a JIT-compiled program, not once per operation inside it.
By contrast, Python has very low per-operation dispatch cost. This matters because in Python you pay that cost for every operation in your program: there is no built-in JIT compiler to get around it.
Given this, when you compare very cheap operations like a single indexing op, you're essentially comparing single-operation dispatch costs, and there JAX will be slower than raw Python. But that has very little bearing on how JAX performs in real-world use cases, which generally involve longer sequences of JIT-compiled operations.
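If you want to check the dispatch-cost argument on your own machine, here is a rough benchmarking sketch. It times a single lookup on a Python list, the same lookup on a JAX array (plain indexing and `.at[...].get()`), and then 100 elementwise steps run op by op versus inside one `jax.jit`-compiled function. The helper names (`step`, `run_op_by_op`, `run_jitted`, `time_per_call`) and the choice of `sin(x) * 0.5 + 1.0` as the workload are illustrative assumptions, not anything from the original thread. Absolute timings depend on hardware, backend and JAX version. The pattern the answer predicts is that a single eager JAX op costs noticeably more than a list lookup, while the jitted function pays dispatch once per call however many ops it contains.

```python
import timeit

import jax
import jax.numpy as jnp

N_STEPS = 100  # illustrative workload size (assumption)

# The same 1000 values as a Python list and as a JAX array.
py_list = list(range(1000))
jax_array = jnp.arange(1000, dtype=jnp.float32)


def index_list():
    # Plain Python list lookup: very low per-operation overhead.
    return py_list[500]


def index_jax():
    # One eager JAX indexing op. JAX dispatches asynchronously, so
    # block_until_ready() makes the timing include the actual work.
    return jax_array[500].block_until_ready()


def index_jax_at():
    # The same lookup through the .at[...] interface.
    return jax_array.at[500].get().block_until_ready()


def step(x):
    # Hypothetical cheap elementwise workload: three jnp ops per step.
    return jnp.sin(x) * 0.5 + 1.0


def run_op_by_op():
    # Every jnp call here is dispatched separately from Python.
    x = jax_array
    for _ in range(N_STEPS):
        x = step(x)
    return x.block_until_ready()


@jax.jit
def _fused(x):
    # The Python loop is unrolled at trace time into one compiled program.
    for _ in range(N_STEPS):
        x = step(x)
    return x


def run_jitted():
    # One dispatch per call, regardless of how many ops are inside.
    return _fused(jax_array).block_until_ready()


def time_per_call(fn, number=100):
    """Average wall-clock seconds per call of fn (after one warm-up call)."""
    fn()  # warm-up; also triggers JIT compilation where relevant
    return timeit.timeit(fn, number=number) / number


if __name__ == "__main__":
    for name, fn in [
        ("list[i]", index_list),
        ("jax_array[i]", index_jax),
        ("jax_array.at[i].get()", index_jax_at),
        ("100 steps, op by op", run_op_by_op),
        ("100 steps, jit", run_jitted),
    ]:
        print(f"{name:>24}: {time_per_call(fn) * 1e6:10.1f} us per call")
```

The warm-up call in `time_per_call` keeps one-time compilation out of the measurement, which is the cost the answer says you pay once per JIT-compiled program rather than per operation.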