Is it possible to write a SpMV kernel for TPU using JAX Pallas? #17493
Replies: 2 comments 5 replies
-
It's definitely possible to write such a kernel, but unless you have some kind of highly structured sparsity, I doubt you'd be able to write a very performant kernel given the computational characteristics of a TPU. I know @sharadmv has thought a bit about this – he may be able to say more.
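For readers new to Pallas, a dense matvec kernel is the natural starting point before worrying about sparsity. The sketch below is illustrative only: the names `matvec_kernel` and `matvec` are made up here, it loads whole operands instead of tiling them over a grid with `BlockSpec`s, and it uses `interpret=True` so it also runs off-TPU for experimentation.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def matvec_kernel(a_ref, x_ref, o_ref):
    # Read the full operands and do one dense contraction.
    # A real TPU kernel would tile this over a grid of blocks.
    o_ref[...] = jnp.dot(a_ref[...], x_ref[...])

def matvec(a, x):
    return pl.pallas_call(
        matvec_kernel,
        out_shape=jax.ShapeDtypeStruct((a.shape[0],), a.dtype),
        interpret=True,  # interpreter mode: works without a TPU
    )(a, x)
```

On an actual TPU you would drop `interpret=True` and supply `grid`/`BlockSpec` arguments so each program instance sees one tile of `a`.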
-
As Jake mentioned, TPUs are best at exploiting certain kinds of structured sparsity (block sparsity most easily). If you have arbitrary sparsity, it will be more difficult to utilize the hardware to speed it up. It is possible to express block-sparse matrix multiplies with Pallas. We have examples/documentation on how to do so on the way (the key is to use the …
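To make the block-sparse idea concrete, here is a sketch in plain JAX of the storage layout such a kernel would consume: the nonzeros of A are stored as dense (bs, bs) tiles plus their block coordinates, and the matvec becomes a gather, a per-tile dense contraction (the part a TPU is good at), and a scatter-add. All names here (`bsr_matvec`, `block_vals`, `brows`, `bcols`) are illustrative, not a Pallas API; a Pallas version would map each stored tile to one grid step.

```python
import jax.numpy as jnp

def bsr_matvec(block_vals, brows, bcols, x, n_rows):
    """Block-sparse y = A @ x.

    block_vals: (nnzb, bs, bs) dense nonzero tiles of A
    brows, bcols: (nnzb,) block-row / block-column index of each tile
    n_rows: number of rows of A (a multiple of bs)
    """
    bs = block_vals.shape[1]
    x_tiles = x.reshape(-1, bs)[bcols]                  # gather input tiles
    partial = jnp.einsum('nij,nj->ni', block_vals, x_tiles)  # dense per-tile matvec
    y = jnp.zeros((n_rows // bs, bs), x.dtype)
    return y.at[brows].add(partial).reshape(n_rows)     # scatter-add results
```

The per-tile `einsum` is exactly the dense work a TPU kernel would do on each block, which is why block sparsity maps onto the hardware so much better than arbitrary sparsity does.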
-
For many spiking neural network models, we work with sparse matrices and rely heavily on SpMV operations. I am wondering how to write a custom SpMV kernel for TPU devices using JAX Pallas. If it is possible, how can I achieve that?
Many thanks!
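It may be worth noting that `jax.experimental.sparse` already provides a BCOO sparse-matrix type whose matvec works out of the box, which can serve as a functional baseline while a custom Pallas kernel is being explored:

```python
import jax.numpy as jnp
from jax.experimental import sparse

a = jnp.array([[0., 2., 0.],
               [1., 0., 0.],
               [0., 0., 3.]])
a_sp = sparse.BCOO.fromdense(a)  # COO storage: .data and .indices
x = jnp.array([1., 2., 3.])
y = a_sp @ x                     # SpMV; y == [4., 1., 9.]
```

This is not a TPU-optimized kernel, but it gives a correct reference implementation to validate any custom kernel against.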