Resources that go in depth on JAX performance optimization #14292
Unanswered
EelcoHoogendoorn asked this question in General
Replies: 0 comments
Despite working with JAX for about two years (and having programmed just about anything for about 30 years, GPUs for 10+), I still feel like my conceptual model of JAX and its compilation model is quite weak. I'm getting quite fluent in taking existing stateful code, JAXifying it, and having it compile on the first try, but I do not feel very much in control of the performance I'm seeing.
I realised that recently, when working on my largest JAX codebase yet. I've been trying out various implementation strategies and running them on a bunch of different hardware targets, and I cannot really make sense of the results. Sadly this is proprietary code so I cannot share it here, but suffice to say it's a nice mix of nested scans and vmaps and grads and jits and all that.
There are so many questions that I'm in the dark about. What types of operations will get loop-fused? Does that depend on the size and number of vmaps that they operate within? Why does it sometimes matter for runtime performance whether I pre-jit a subexpression of a larger jitted expression? How are my vmapped axes mapped to hardware thread blocks, and how do I make sure I don't mess that up?
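To make these questions a bit more concrete, here is a minimal sketch of one way to look at what the compiler produced for a toy function. It does not answer any of the questions above; it only shows where the evidence can be found. The functions `inner`, `outer_flat`, and `outer_nested` are hypothetical stand-ins (the real code cannot be shared), and the exact text of the compiled module varies between JAX versions and backends, so nothing here should be read as a statement about what XLA will fuse.

```python
import jax
import jax.numpy as jnp


def inner(x):
    # Hypothetical elementwise chain, standing in for a real subexpression.
    return jnp.tanh(x) * 2.0 + 1.0


def outer_flat(x):
    # Calls the plain Python function, so everything is traced inline.
    return jnp.sum(inner(x) ** 2)


# The same subexpression, but jitted on its own before being used.
inner_jitted = jax.jit(inner)


def outer_nested(x):
    # Calls the pre-jitted subexpression inside a larger function.
    return jnp.sum(inner_jitted(x) ** 2)


x = jnp.arange(8.0).reshape(2, 4)

# The jaxpr is JAX's own intermediate representation, before XLA sees it.
# Comparing the two shows how a nested jit appears as a separate call.
print(jax.make_jaxpr(outer_flat)(x))
print(jax.make_jaxpr(outer_nested)(x))

# vmap over the leading axis, then lower and compile ahead of time so the
# optimized module for the current backend can be read as text.
batched = jax.vmap(outer_flat)
hlo_text = jax.jit(batched).lower(x).compile().as_text()

# A crude signal only: how often the word "fusion" occurs in the module.
print("mentions of 'fusion':", hlo_text.count("fusion"))
```

Diffing `hlo_text` between the flat and nested variants, or between different batch sizes, is one way to check whether a change in runtime corresponds to a change in the compiled program.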
I realise JAX is an evolving project and these questions may not have firm answers, but if there were a book or something that goes well beyond the basics, one that illustrates various performance gotchas by example and explains them in terms of the conceptual model of JAX, I'd love to hear about it. I imagine it does not exist yet... but it should. If anyone knows of a similar resource, that'd be great.