A functional implementation of Fast Walsh Hadamard Transform using JAX #7308
shailesh1729 asked this question in General
Hi,
I have come up with a JAX-based functional implementation of the Walsh-Hadamard Transform that works nicely with @jit. However, it involves three levels of lax.while_loop. I would like to share the implementation and ask whether there is a better and more efficient way to express the logic.
Code
The first is a reference version that I wrote. It uses Python for and while loops and is slower. It took me a day to get the logic right. It may differ from other implementations by a scaling factor.
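A minimal NumPy sketch of a loop-based reference of this kind (the name `fwht_reference` is hypothetical; this is not the code from this post). It computes the unnormalized transform in natural (Sylvester/Hadamard) order, so it can differ from other implementations by a factor such as 1/sqrt(n) or 1/n. Its three nested loops run over stages, blocks, and butterflies:

```python
import numpy as np

def fwht_reference(x):
    """Unnormalized fast Walsh-Hadamard transform in natural (Sylvester) order.

    Sketch only: assumes len(x) is a power of two. Works on a float copy of x.
    """
    a = np.array(x, dtype=np.float64)
    n = a.shape[0]
    if n & (n - 1) != 0:
        raise ValueError("length must be a power of two")
    h = 1
    while h < n:                                   # stages: h = 1, 2, 4, ..., n/2
        for start in range(0, n, 2 * h):           # blocks of size 2h
            for j in range(start, start + h):      # butterflies inside one block
                u, v = a[j], a[j + h]
                a[j], a[j + h] = u + v, u - v
        h *= 2
    return a
```

Because the unnormalized Hadamard matrix satisfies H @ H = n * I, applying this function twice returns `len(x) * x` up to rounding, which is a convenient self-check.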
After this, I converted it to a functional version:
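For readers without the original listing, here is a sketch of how those same three loops can be expressed with nested `lax.while_loop` calls (the name `fwht_while` is hypothetical; this illustrates the general pattern and is not the code from this post). Each loop counter lives in the loop carry, the array is updated functionally with `.at[...].set(...)`, and inner loops close over the traced counters of the outer ones:

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def fwht_while(x):
    """Unnormalized FWHT as three nested lax.while_loop calls.

    Sketch only: assumes x is 1-D with a power-of-two length.
    """
    n = x.shape[0]  # static under jit

    def butterflies(a, start, h):
        # innermost loop: j = start, ..., start + h - 1
        def body(state):
            a, j = state
            u, v = a[j], a[j + h]
            a = a.at[j].set(u + v).at[j + h].set(u - v)
            return a, j + 1
        a, _ = lax.while_loop(lambda s: s[1] < start + h, body, (a, start))
        return a

    def blocks(a, h):
        # middle loop: start = 0, 2h, 4h, ... while start < n
        def body(state):
            a, start = state
            return butterflies(a, start, h), start + 2 * h
        a, _ = lax.while_loop(lambda s: s[1] < n, body, (a, jnp.int32(0)))
        return a

    def stages(state):
        # outer loop body: h = 1, 2, 4, ... while h < n
        a, h = state
        return blocks(a, h), 2 * h

    a, _ = lax.while_loop(lambda s: s[1] < n, stages, (x, jnp.int32(1)))
    return a
```

Written this way, each innermost iteration touches only two entries of the array, so the compiled program performs about (n/2) log2(n) small sequential steps instead of one whole-array update per stage.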
The three levels of lax.while_loop reflect the original logic. However, is it possible to write it in a more expressive manner?
Benchmarks
Some code that I ran in a Jupyter notebook for benchmarking:
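A minimal timing sketch (the helper `time_call` and the demo function `sin_sum` are hypothetical; this is not the notebook code from this post). JAX dispatches work asynchronously, and the first call to a jitted function includes compilation, so a fair comparison makes one warm-up call and waits for each result with `block_until_ready()` before stopping the clock. Skipping either step can make speedup ratios look unusually large or small.

```python
import time
import jax
import jax.numpy as jnp

def _wait(out):
    # JAX arrays expose block_until_ready(); NumPy arrays are already computed.
    if hasattr(out, "block_until_ready"):
        out.block_until_ready()
    return out

def time_call(f, *args, repeats=20):
    """Mean wall-clock seconds per call of f(*args), excluding one warm-up call."""
    _wait(f(*args))                      # warm-up: triggers JIT compilation if f is jitted
    start = time.perf_counter()
    for _ in range(repeats):
        _wait(f(*args))
    return (time.perf_counter() - start) / repeats

# Usage: the same function, run eagerly and under jax.jit.
def sin_sum(x):
    return jnp.sum(jnp.sin(x) ** 2)

x = jnp.linspace(0.0, 1.0, 4096)
t_eager = time_call(sin_sum, x)
t_jit = time_call(jax.jit(sin_sum), x)
print(f"eager: {t_eager * 1e6:.1f} us/call, jit: {t_jit * 1e6:.1f} us/call")
```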
The gain of 11000x seems surreal. Normally, I see gains in the range of 50-200x from JIT compilation. Is this expected?
Here is another benchmark comparing a direct matrix-vector product with the transform operation:
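A self-contained sketch of such a comparison (the helpers `sylvester_hadamard` and `fwht_recursive` are hypothetical; this is not the benchmark code from this post). It assumes the transform is unnormalized and in Sylvester (natural) order, so the dense product `H @ y` and the fast transform should agree column by column. By operation count, the dense product costs about n^2 multiply-adds per column, while the fast transform needs n log2(n) additions and subtractions; for n = 1024 that is 1,048,576 versus 10,240. How much of that gap shows up in wall-clock time depends on how well each version maps onto the hardware: dense matrix products are heavily optimized, whereas a transform made of many small dependent steps may not be.

```python
import jax
import jax.numpy as jnp
from jax import lax

def sylvester_hadamard(n):
    """Dense n x n Hadamard matrix in Sylvester order (n assumed to be a power of two)."""
    h = jnp.ones((1, 1))
    while h.shape[0] < n:
        h = jnp.kron(jnp.array([[1.0, 1.0], [1.0, -1.0]]), h)  # H_2m = [[H, H], [H, -H]]
    return h

def fwht_recursive(y):
    """Unnormalized FWHT along axis 0 by divide and conquer (unrolled when traced by jit)."""
    n = y.shape[0]
    if n == 1:
        return y
    a = fwht_recursive(y[: n // 2])
    b = fwht_recursive(y[n // 2 :])
    return jnp.concatenate([a + b, a - b], axis=0)

n, k = 256, 8                                   # signal length, number of columns of y
y = jax.random.normal(jax.random.PRNGKey(0), (n, k))
H = sylvester_hadamard(n)

# Dense product: about n^2 * k multiply-adds; full precision so results are comparable.
dense = jax.jit(lambda m, v: jnp.matmul(m, v, precision=lax.Precision.HIGHEST))
# Fast transform: about n * log2(n) * k additions/subtractions.
fast = jax.jit(fwht_recursive)

agree = bool(jnp.allclose(dense(H, y), fast(y), atol=1e-3))
print("dense and fast results agree:", agree)
```

Because `jax.jit` unrolls the recursion at trace time into roughly 2n calls, compile time grows with n; for speed comparisons, the first (compiling) call of each version should be excluded.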
The gain for the matrix-vector product is relatively modest (just 6x). As I increase the number of columns in y, the gain shrinks further. This seems to indicate that there is some opportunity to improve the JAX implementation.
This implementation is available as part of CR-Sparse.
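On the question of a more expressive formulation, one common alternative (a sketch; the name `fwht_reshape` is hypothetical and this is not the CR-Sparse code) writes each of the log2(n) stages as a reshape followed by a whole-array add and subtract. The stage loop is an ordinary Python loop that runs once at trace time, so the compiled program contains no `lax.while_loop` at all. Because it transforms along axis 0, the columns of a matrix such as y are all processed in the same call.

```python
import jax
import jax.numpy as jnp

@jax.jit
def fwht_reshape(x):
    """Unnormalized FWHT along axis 0 (Sylvester order) using reshapes.

    Sketch only: assumes x.shape[0] is a power of two. Trailing axes,
    e.g. the columns of a matrix, are transformed independently.
    """
    n = x.shape[0]
    if n & (n - 1) != 0:
        raise ValueError("leading dimension must be a power of two")
    rest = x.shape[1:]
    h = 1
    while h < n:  # runs at trace time: log2(n) unrolled stages
        # View as (blocks, 2, h, ...): [:, 0] is the top half and [:, 1] the
        # bottom half of every block of size 2h.
        a = x.reshape((n // (2 * h), 2, h) + rest)
        top, bottom = a[:, 0], a[:, 1]
        x = jnp.stack([top + bottom, top - bottom], axis=1).reshape((n,) + rest)
        h *= 2
    return x
```

For example, `fwht_reshape(jnp.eye(4))` returns the 4 x 4 Sylvester-ordered Hadamard matrix, since each column of the identity is mapped to the corresponding column of H.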