-
This code shows a simple function that becomes roughly 600 times slower when jit'ed. Note that the function is "static," meaning that it returns a constant value and has no inputs, and the slowdown is measured after compilation. If, instead of returning such a list, you pass it into the function as an input (a more typical pattern), accessing it is still very slow. "Call overhead" is often blamed for this sort of thing, but I cannot imagine how any call overhead could possibly account for a 600X slowdown factor. What explains it? Placing the compute on CPU doesn't solve the issue. I wanted to report this as a bug, but first I want to understand the situation better. Thanks.
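The original snippet isn't reproduced above, but here is a minimal sketch of the kind of benchmark being described; the function body, the list size of 10000, and the timing harness are my assumptions rather than the poster's actual code:

```python
import timeit

import jax

N = 10_000  # illustrative size; the exact code from the question isn't shown


def f():
    # "Static" function: no inputs, returns a constant list of N Python ints.
    # Under jit this becomes a pytree output with N separate array leaves.
    return list(range(N))


f_jit = jax.jit(f)
jax.block_until_ready(f_jit())  # trigger compilation before timing

t_py = timeit.timeit(f, number=10)
t_jit = timeit.timeit(lambda: jax.block_until_ready(f_jit()), number=10)
print(f"plain Python: {t_py:.4f}s  jitted: {t_jit:.4f}s  ratio: ~{t_jit / t_py:.0f}x")
```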
-
Hi - thanks for the question! There are a couple things going on here:
```python
values = [jax.device_put(i) for i in range(10000)]
```

With these things in mind, I'd consider your example program working as expected. If I were trying to optimize the actual function being used here, I'd write it like this instead:

```python
def f():
    return jnp.arange(10000)
```

This returns a single array rather than a list of scalars, and avoids both problems (1) and (2) above. I suspect your real use-case is more complicated than this toy function, but hopefully the ideas mentioned here can help you figure out how to more effectively implement your own use-case. Hope that helps!
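If the data really does start life as a Python list, one way to get the same benefit on the input side is to stack it into a single array once and pass that array to the jitted function. Here is a sketch of that idea under my own assumptions (the toy function `g` and the list contents are not from the original reply):

```python
import jax
import jax.numpy as jnp

values = list(range(10_000))   # plain Python data
batched = jnp.asarray(values)  # one host-to-device transfer, one array

@jax.jit
def g(x):
    return x * 2               # toy body; any jitted computation works

result = g(batched)            # the pytree passed to jit has a single leaf
```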
-
Thanks for replying, but this doesn't answer the question. Regarding (1) and (2): for (2), I specified in my question that placing the compute on CPU doesn't solve the issue.
This code is testing the efficiency of jit pytree access; I just want to know why that is so slow.
-
Thank you very much for your reply. That's really interesting! Actually, that's comparing apples and oranges: those two expressions behave very differently, and jit performs really well on the first one.

Results:

Here's my guess as to what's happening (I'm not sure how to disassemble or view the generated code, so I'll take a guess here)... In the 2nd expression

The question is why jit is, it seems, very slow at accessing Python structures. I think we both agree that this is also an issue when the list is an input to (rather than an output of) the function, which is the more typical use case for JAX. Let me stress again: the differences in performance are extremely large in percentage terms, not negligible. You're absolutely right that this can be mitigated by batching (large arrays) and small pytree sizes.

So let me say why I'm a little concerned about this. We have very deep ResNet models, over 100 layers. There is even a report of a 1000+ layer ResNet trained on CIFAR! (I know, that sounds really weird!) But people experiment with all sorts of things. Even for "shallow" ResNets, the convolutions are deep and complex enough that jax/lax offer specialized ops just for generalized convolutions; NumPy doesn't even have such an op. Consequently, it's very conceivable that without such special ops, JAX code (pytrees) might get deeper and more complex, and the overhead could become non-negligible.

It's not clear why jit needs to be ~600X slower when using deep pytrees and accessing Python structures. We're comparing it to raw Python (already the slowest language on the planet!), and here compiled jit is making things even slower in some situations. Is there a huge cost to removing this overhead? What's causing it?
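To make the input-side concern concrete, here is a sketch added for illustration (the leaf count, dtype, and toy reduction bodies are arbitrary choices, not the poster's code) comparing a jitted call that receives a pytree with many small leaves against one that receives a single stacked array:

```python
import timeit

import jax
import jax.numpy as jnp

N = 1_000  # arbitrary stand-in for a "deep" pytree

many_leaves = [jnp.asarray(i, dtype=jnp.float32) for i in range(N)]  # N array leaves
one_leaf = jnp.arange(N, dtype=jnp.float32)                          # one array leaf

@jax.jit
def sum_list(xs):
    return sum(xs)     # toy reduction over the list's leaves

@jax.jit
def sum_array(x):
    return jnp.sum(x)  # same reduction over a single array

sum_list(many_leaves).block_until_ready()   # compile before timing
sum_array(one_leaf).block_until_ready()

t_list = timeit.timeit(lambda: sum_list(many_leaves).block_until_ready(), number=100)
t_arr = timeit.timeit(lambda: sum_array(one_leaf).block_until_ready(), number=100)
print(f"pytree with {N} leaves: {t_list:.4f}s   single array: {t_arr:.4f}s")
```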
-
Jake, I really appreciate your help with this. The code below confirms that you're apparently right. In summary, device array creation has significant overhead relative to just accessing existing device arrays, and is significantly slower than raw Python code. However, this is a very Mickey Mouse test intended to stress just the overhead; in practice, large arrays on device will hide (much of) it. I believe you also said that folks are working on making this array creation more efficient. I'm still a little surprised by the size of the difference between raw Python and device arrays, but I understand that memory block alignment and the like can make a difference, and my test is unrealistically tiny.

This code inputs the arrays into the function after they're created (for

Also note, this code is crashing on my CPU. It works fine on GPU; not sure why. Thank you so much for JAX, jit, and for your help.
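The code referred to in this post isn't reproduced above, so the following is only a sketch of the comparison being described; the array count, the trivial function body, and the timing harness are my assumptions:

```python
import timeit

import jax

N = 1_000  # kept modest so that compiling a function with N inputs stays quick

@jax.jit
def first(xs):
    return xs[0]  # toy body: just touch the input pytree

# Cost of *creating* N scalar device arrays (one placement per element):
t_create = timeit.timeit(lambda: [jax.device_put(i) for i in range(N)], number=1)

# Cost of calling a jitted function on N *pre-created* device arrays:
values = [jax.device_put(i) for i in range(N)]
first(values).block_until_ready()  # compile once, outside the timing
t_call = timeit.timeit(lambda: first(values).block_until_ready(), number=1)

print(f"create {N} device arrays: {t_create:.4f}s   jitted call on them: {t_call:.4f}s")
```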
-
The compiled JIT execution only happens once all values are placed on the device, and this device placement happens in Python, the slowest language on the planet 😁
The non-jit version here is "construct a Python list of 10000 Python integers". The JIT version is "construct a Python list of 10000 Python integers, and then allocate a space on the XLA device for each of them & copy the bits over". There's no possible way the second pr…
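A tiny sketch (the size is arbitrary, matching the example in the thread) of the contrast this reply is drawing, i.e. ten thousand separate device placements driven from Python versus a single bulk placement:

```python
import timeit

import jax
import numpy as np

N = 10_000

# Per-element placement: N allocations and N host-to-device copies, looped in Python.
t_many = timeit.timeit(lambda: [jax.device_put(i) for i in range(N)], number=1)

# Bulk placement: one allocation and one copy for the same data.
t_one = timeit.timeit(lambda: jax.device_put(np.arange(N)).block_until_ready(), number=1)

print(f"{N} scalar device_puts: {t_many:.4f}s   one bulk device_put: {t_one:.6f}s")
```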