-
We are using JAX for some simulations, often involving highly sparse matrices. On CPUs, JAX shows a clear performance jump with sparse matrices and sparse evaluation. However, when we move to the GPU we see a slowdown, with sparse JAX evaluation slower than dense JAX. Is this expected? I imagine there may be some optimization at the lower levels that has not taken place yet. We have tested with square and rectangular matrices, as well as 1D arrays, all with similar results. I have attached the code for the rectangular matrix as an example. Is there any expectation that GPU sparse evaluation will improve?
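(The original attachment is not reproduced in this thread; the following is only a minimal sketch of the kind of dense-vs-sparse timing comparison described, with illustrative sizes and density rather than the actual benchmark.)

```python
# Not the original attachment: a minimal sketch of a dense-vs-sparse matvec
# timing comparison in JAX. Matrix size and density are illustrative.
import time

import jax
import jax.numpy as jnp
from jax.experimental import sparse

M, N, density = 4000, 2000, 0.001  # rectangular matrix, ~0.1% non-zero

key1, key2, key3 = jax.random.split(jax.random.PRNGKey(0), 3)
mask = jax.random.uniform(key1, (M, N)) < density
A_dense = jnp.where(mask, jax.random.uniform(key2, (M, N)), 0.0)
A_sparse = sparse.BCOO.fromdense(A_dense)
x = jax.random.uniform(key3, (N,))

matvec = jax.jit(lambda A, x: A @ x)  # traced separately for dense and BCOO operands

# Warm up (compile), then time repeated evaluation on the default backend.
for A, label in [(A_dense, "dense"), (A_sparse, "sparse")]:
    matvec(A, x).block_until_ready()
    t0 = time.perf_counter()
    for _ in range(100):
        matvec(A, x).block_until_ready()
    print(label, time.perf_counter() - t0)
```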
-
Hi - thanks for the question. In general, unless you have very sparse matrices, I would not expect sparse versions of matrix products to be faster than dense versions, particularly on accelerators like GPU and TPU. This is not just a statement about JAX – I would expect this to hold for virtually any sparse and dense matrix algebra libraries.

Why? Accelerators like GPU and TPU are specifically designed for dense linear algebra, and take advantage of decades of engineering best practices for those specific operations. These optimizations rely on things like data locality guarantees, look-aheads, and the ability to blindly scan over standard data layouts in parallel. Sparse methods don't satisfy any of these constraints (e.g. they tend to be very non-local), and so will be much slower. Again, this has nothing to do with JAX: it will be true of any sparse algorithm implemented on hardware fundamentally designed for dense operations.

So why does JAX have experimental sparse support at all? Well, there are occasions when it is useful, particularly for extremely sparse matrices, or for situations where the dense computation could not be done at all because of memory constraints. If you're in that regime, I'd suggest using JAX sparse. If you're mainly concerned with 1000x1000 diagonal matrices, just use dense representations and let the hardware do its thing.

That being said, you can expect JAX sparse operations on GPU to become an order of magnitude faster in the very near future. We're actively working on GPU lowerings to cusparse for supported operations; the first part of that work will hopefully land in the main branch this afternoon (see #8514).
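(To make the "extremely sparse or memory-constrained" regime concrete, here is a rough sketch based on my reading of the advice above; the size and structure are illustrative, not taken from the original question.)

```python
# Illustrative only: a single-super-diagonal n x n matrix whose dense form
# (~4 TB of float32 for n = 1e6) could never be materialized, represented as a
# BCOO so the matvec never builds the dense matrix.
import jax
import jax.numpy as jnp
from jax.experimental import sparse

n = 1_000_000
rows = jnp.arange(n - 1)
cols = jnp.arange(1, n)
data = jnp.ones(n - 1, dtype=jnp.float32)
indices = jnp.stack([rows, cols], axis=1)  # shape (nse, 2): one (row, col) per non-zero
A = sparse.BCOO((data, indices), shape=(n, n))

x = jnp.ones(n, dtype=jnp.float32)

@jax.jit
def matvec(A, x):
    return A @ x  # sparse-dense product, stays in the BCOO representation

y = matvec(A, x)
```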
-
@brosand, looking at your code, it appears that your sparse matrices are structured, in the sense that they contain only a single sub/super-diagonal with non-zero entries while the rest of the matrix is all zeros. Have you considered a different approach, where structured matrices are represented as linear operators that provide a custom, fast implementation of the …
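(The reply above is cut off; if the suggestion is a custom matrix-vector product for the structured case, which is my assumption, a sketch might look like the following. The diagonal offset and padding convention are illustrative.)

```python
# Sketch of the linear-operator idea (my interpretation of the truncated reply):
# for a matrix whose only non-zeros are a single super-diagonal d, the product
# A @ x is just an elementwise multiply and a shift, with no matrix stored at all.
import jax
import jax.numpy as jnp

def superdiag_matvec(d, x):
    """Apply A @ x where A[i, i + 1] = d[i] and every other entry is zero."""
    # (A @ x)[i] = d[i] * x[i + 1] for i < n - 1, and 0 in the last row.
    return jnp.concatenate([d * x[1:], jnp.zeros(1, dtype=x.dtype)])

n = 1000
d = jnp.arange(1, n, dtype=jnp.float32)  # super-diagonal entries, length n - 1
x = jnp.ones(n, dtype=jnp.float32)

y = jax.jit(superdiag_matvec)(d, x)

# Sanity check against the explicit dense construction (small n only).
assert jnp.allclose(y, jnp.diag(d, k=1) @ x)
```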