Poor Jax Sparse Performance with GPU #8531

Answered by jakevdp
brosand asked this question in General

Hi - thanks for the question. In general, unless you have very sparse matrices, I would not expect sparse versions of matrix products to be faster than dense versions, particularly on accelerators like GPU and TPU. This is not just a statement about JAX: I would expect it to hold for virtually any sparse and dense matrix algebra library.

Why? Accelerators like GPU and TPU are specifically designed for dense linear algebra, and take advantage of decades of engineering best practices for those specific operations. These optimizations rely on things like data locality guarantees, look-aheads, and the ability to blindly scan over standard data layouts in parallel. Sparse …
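
To make this concrete, here is a minimal benchmark sketch using JAX's experimental BCOO sparse format. The matrix size, the 10% density, and the repetition count are illustrative assumptions, not values from this discussion; at this density you would generally expect the dense matvec to win on a GPU.

```python
import timeit

import jax
import jax.numpy as jnp
from jax.experimental import sparse

# Illustrative size and density (assumptions, not from the original question).
N = 4096
DENSITY = 0.1

key = jax.random.PRNGKey(0)
k_mask, k_vals, k_vec = jax.random.split(key, 3)

# A dense matrix with ~10% nonzero entries, plus a vector to multiply by.
mask = jax.random.bernoulli(k_mask, p=DENSITY, shape=(N, N))
dense = jnp.where(mask, jax.random.normal(k_vals, (N, N)), 0.0)
vec = jax.random.normal(k_vec, (N,))

# The same matrix in BCOO sparse format.
sp = sparse.BCOO.fromdense(dense)

dense_mv = jax.jit(lambda m, v: m @ v)   # dense matvec
sparse_mv = jax.jit(lambda m, v: m @ v)  # BCOO matvec (sparse rules apply)

# Warm up to exclude compilation time, then time with block_until_ready
# so we measure execution rather than async dispatch.
dense_mv(dense, vec).block_until_ready()
sparse_mv(sp, vec).block_until_ready()

print("dense: ", timeit.timeit(lambda: dense_mv(dense, vec).block_until_ready(), number=100))
print("sparse:", timeit.timeit(lambda: sparse_mv(sp, vec).block_until_ready(), number=100))
```

At much higher sparsity levels the balance can shift; that is the "very sparse" regime mentioned above.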
