Some sparse recovery algorithms built using JAX #6953
-
Hello, I am new to JAX. I have built some sparse recovery algorithms using JAX as part of a package CR.Sparse.
My biggest apprehension was whether I would be able to recast these algorithms in a functional programming style so that they can take full advantage of XLA JIT compilation. So far the results have been positive, and it has been quite a learning experience. I hope some of you find this work interesting. Please forgive me if the documentation is still a bit clunky. I would be happy to hear any feedback.
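To give a flavour of what "functional style" means for an iterative sparse recovery routine, here is a minimal sketch of iterative hard thresholding written with `lax.fori_loop`. This is an illustration only and is not taken from CR.Sparse; the names `hard_threshold` and `iht`, the fixed iteration count, and the unit step size are all assumptions made for the example.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax


def hard_threshold(x, k):
    """Keep the k largest-magnitude entries of x and zero out the rest."""
    _, idx = lax.top_k(jnp.abs(x), k)  # k must be a static Python int
    return jnp.zeros_like(x).at[idx].set(x[idx])


@partial(jax.jit, static_argnames=("k", "num_iters"))
def iht(A, y, k, num_iters=100, step=1.0):
    """Hypothetical iterative hard thresholding: x <- H_k(x + step * A^T (y - A x))."""

    def body(_, x):
        # Gradient step on 0.5 * ||y - A x||^2, followed by projection
        # onto the set of k-sparse vectors.
        return hard_threshold(x + step * A.T @ (y - A @ x), k)

    x0 = jnp.zeros(A.shape[1], dtype=A.dtype)
    # The loop state is passed explicitly; nothing is mutated in place,
    # so the whole loop can be traced and compiled by XLA.
    return lax.fori_loop(0, num_iters, body, x0)


# Usage on a trivial problem where A is the identity.
A = jnp.eye(5)
y = jnp.array([0.0, 3.0, 0.0, -2.0, 0.0])
x_hat = iht(A, y, k=2, num_iters=10)
```

The sparsity level `k` is marked static because `lax.top_k` needs a concrete integer at trace time; convergence for a general `A` depends on the step size and on properties of `A` that this sketch does not check.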
Replies: 1 comment 1 reply
-
This is great, thanks for sharing! I took a look through the repo and the documentation... all the code looks pretty idiomatic & well-documented. There are a few places where you fall back to scipy routines (e.g. in `distance.py`) so I suspect those will give you issues with JAX transformations, but it's good to see that the core routines are JIT compatible and show solid performance!

You might consider adding your package to this list for more exposure: https://github.com/n2cholas/awesome-jax
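To illustrate the scipy point: scipy functions work on NumPy arrays and cannot be traced, so calling them inside `jit`, `grad`, or `vmap` generally fails. A pure `jax.numpy` version avoids that. The sketch below assumes the scipy calls compute something like pairwise distances (I have not checked what `distance.py` actually calls), and the name `pairwise_sqeuclidean` is made up for the example.

```python
import jax
import jax.numpy as jnp


@jax.jit
def pairwise_sqeuclidean(X, Y):
    """Squared Euclidean distances between the rows of X and the rows of Y."""
    x2 = jnp.sum(X**2, axis=1)[:, None]  # shape (n, 1)
    y2 = jnp.sum(Y**2, axis=1)[None, :]  # shape (1, m)
    d2 = x2 + y2 - 2.0 * X @ Y.T         # shape (n, m)
    # Clip tiny negative values caused by floating point round-off.
    return jnp.maximum(d2, 0.0)


X = jnp.array([[0.0, 0.0], [3.0, 4.0]])
Y = jnp.array([[0.0, 0.0], [1.0, 0.0]])
D2 = pairwise_sqeuclidean(X, Y)

# Because everything is written in jax.numpy, transformations compose:
# here we differentiate the sum of all distances with respect to X.
grad_X = jax.grad(lambda X_: jnp.sum(pairwise_sqeuclidean(X_, Y)))(X)
```

Squared distances are used on purpose: taking `jnp.sqrt` of an exact zero gives an infinite or NaN gradient, so if you need the unsquared distance under `grad` you would have to guard the zero case yourself.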