For your first question: if you're using the FFT within a sequence of other operations, JIT-compiling the full sequence of operations may yield an improvement, but wrapping a single, already-compiled operation in `jit` won't yield any speedup. For your second question: the problem is that you're trying to index into non-JAX arrays using traced indices. Try using JAX arrays instead:

    a = jnp.array([5, 5, 5])
    b = jnp.array([5, 5, 5])

Then you should be able to re-express this in terms of …
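A minimal sketch of the fix described in the reply, assuming the body only needs to read `a[i]` and `b[i]` and accumulate a running total. The original loop from the question was not preserved, so the accumulation (`carry + a[i] * b[i]`) and the name `body` are hypothetical. The point is that indexing JAX arrays with a traced index works inside `lax.scan`:

```python
import jax.numpy as jnp
from jax import lax

# Arrays from the reply; they must be JAX arrays (not lists or NumPy arrays)
# so they can be indexed with a traced integer inside the scanned function.
a = jnp.array([5, 5, 5])
b = jnp.array([5, 5, 5])

def body(carry, i):
    # `i` is a traced integer supplied by scan from `xs`.
    new_carry = carry + a[i] * b[i]
    return new_carry, new_carry  # (next carry, per-step output)

# Scan over the indices 0..len(a)-1; the carry starts at zero with a's dtype.
total, partials = lax.scan(body, jnp.array(0, dtype=a.dtype), jnp.arange(a.shape[0]))
```

Here `total` is the final carry and `partials` stacks the per-step outputs.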
Hi JAX Team,
Thanks for the great package. I've run into a couple of questions while implementing an algorithm and was hoping a JAX expert could help me out.
When using pmap, we get down to around 250 ms per loop, which is roughly 8x faster than without pmap.
However, after JIT-compiling the FFT, the speed doesn't change at all:
Unless I'm misunderstanding JIT, shouldn't the JIT-compiled version of the function be more than 1 ms faster?
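Following the reply's point that gains come from JIT-compiling a whole sequence of operations rather than a single FFT call, here is a minimal sketch. The function `filter_signal`, the FFT-multiply-inverse-FFT pipeline, and the identity kernel are hypothetical illustrations, not the original poster's code:

```python
import jax
import jax.numpy as jnp

def filter_signal(x, kernel):
    # Several operations in one function: forward FFTs, a pointwise product,
    # an inverse FFT, and taking the real part. Under jit, XLA compiles these
    # as one program instead of dispatching each op separately.
    X = jnp.fft.fft(x)
    K = jnp.fft.fft(kernel, n=x.shape[0])  # zero-pad kernel to the signal length
    return jnp.real(jnp.fft.ifft(X * K))

filter_signal_jit = jax.jit(filter_signal)

x = jnp.arange(8.0)
kernel = jnp.array([1.0, 0.0, 0.0])  # identity kernel (hypothetical data)

# block_until_ready() waits for the asynchronously dispatched result.
y = filter_signal_jit(x, kernel).block_until_ready()
```

When benchmarking, call `.block_until_ready()` on the result, since JAX dispatches work asynchronously, and exclude the first call, which includes compilation time.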
Are there any other JAX features (besides jit and pmap) that I should look at when implementing huge FFTs across TPUs?
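For context on the pmap timings above, a minimal sketch of running independent FFTs on each device with `jax.pmap`, assuming the workload is a batch of separate signals that can be split along a leading device axis. The array shapes are hypothetical. Splitting one single, very large FFT across devices is a different problem that this sketch does not address:

```python
import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()

# Hypothetical workload: (devices, signals per device, signal length).
signals = jnp.ones((n_dev, 4, 16))

# pmap maps the function over the leading axis, one slice per device;
# each device computes FFTs along the last axis of its own slice.
batched_fft = jax.pmap(lambda s: jnp.fft.fft(s, axis=-1))
spectra = batched_fft(signals)
```

On a machine with a single device (for example, CPU) this still runs, with `n_dev == 1`.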
Finally, I'm having some trouble with `lax.scan` when my scanned function needs access to multiple arrays.
For example:
I tried converting the above into a `lax.scan` by passing `a` and `b` as a tuple along with the index `i`, but it kept throwing errors. Is there a correct (and ideally computationally efficient) way to implement something like this?
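One pattern that may help: `lax.scan` accepts a tuple (or any pytree) of arrays as `xs` and slices each one along its leading axis, so the body receives `(a[i], b[i])` directly. A minimal sketch, assuming both arrays share the same leading length; the values and the name `step` are hypothetical, since the question's original snippet was not preserved:

```python
import jax.numpy as jnp
from jax import lax

# Hypothetical stand-ins for the question's arrays.
a = jnp.array([1.0, 2.0, 3.0])
b = jnp.array([4.0, 5.0, 6.0])

def step(carry, xs):
    # scan slices every array in the tuple along axis 0, so xs is (a[i], b[i]).
    a_i, b_i = xs
    carry = carry + a_i * b_i
    return carry, carry  # (next carry, per-step output)

final, running = lax.scan(step, jnp.array(0.0), (a, b))
```

If the body also needs the step index `i`, it can be added to the tuple as `jnp.arange(a.shape[0])`.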
-Sarah