How to keep hashable arguments static through a scan #16667
Unanswered
HenriLamarre asked this question in Q&A
Replies: 2 comments 3 replies
-
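Thanks for the question! Any of the arguments that pass through the …

```python
from jax import lax

def loop_function(itermax):
    it = 0  # renamed from `iter` so the Python builtin is not shadowed
    nonstatic = 1
    static = 2
    def scanner_closure(a, x):
        # `static` is closed over instead of carried through the scan, so it stays
        # a hashable Python int; `scanner` is the function from the question.
        a, x = scanner((a[0], a[1], static), x)
        return (a[0], a[1]), x

    a, x = lax.scan(scanner_closure, (it, nonstatic), xs=None, length=itermax)
    return a
```

Now …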
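A self-contained sketch of the same idea follows. The question's minimal example is not shown, so `scanner`, its `static_argnums=2` signature, and the doubling logic are hypothetical stand-ins. The point is that only `(it, value)` travel through the scan carry, while `static` is closed over as a plain Python int. If `static` were put in the carry instead, `scan` would turn it into a tracer, and passing a tracer as a static argument is what triggers the "Non-hashable static arguments" error.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax


# Hypothetical stand-in for the jitted function from the question.
# Argument 2 is marked static, so it must be a hashable Python value.
@partial(jax.jit, static_argnums=2)
def scanner(it, value, static):
    # `static` is a concrete Python int at trace time, so Python
    # control flow on it is allowed.
    if static > 1:
        value = value * static
    return it + 1, value


def loop_function(itermax, static=2):
    def scanner_closure(carry, _):
        it, value = carry
        # `static` comes from the enclosing scope instead of the scan carry,
        # so scan never turns it into a tracer.
        it, value = scanner(it, value, static)
        return (it, value), None

    init = (jnp.int32(0), jnp.float32(1.0))
    (it, value), _ = lax.scan(scanner_closure, init, xs=None, length=itermax)
    return it, value


it, value = loop_function(3)
print(int(it), float(value))  # 3 8.0
```

If `static` has to change from one iteration to the next, it cannot stay static inside a single `scan`; it would then need to be an ordinary traced argument.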
-
This is a very common issue. As a follow-up, you may like …
-
Hi, disclaimer: I am new to JAX but it has been awesome so far! :D
I have a function ('jitfunc') that is jit-compiled and works properly. It is called inside a Python while loop. The while loop itself is very slow, whereas the function is fast. I would like to convert my while loop to lax.while_loop, lax.fori_loop, or, most likely, lax.scan, since scan is differentiable and the next step for this project is AI stuff.
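A minimal sketch of the differentiability point, using a toy loop body (`loop_scan` and its multiply-by-1.5 step are hypothetical, not the poster's code): reverse-mode `jax.grad` works through a fixed-length `lax.scan`, whereas reverse-mode differentiation of `lax.while_loop` is not supported. `lax.fori_loop` is reverse-differentiable only when its trip count is concrete, because it is then lowered to a scan.

```python
import jax
from jax import lax


def loop_scan(x, n):
    # Hypothetical toy loop: multiply the carry by 1.5, n times.
    # n must be a concrete Python int because it sets the scan length.
    def body(carry, _):
        return carry * 1.5, None

    out, _ = lax.scan(body, x, xs=None, length=n)
    return out


# Reverse-mode gradient with respect to x only; n is passed through untouched.
grad_x = jax.grad(loop_scan)(2.0, 3)
print(grad_x)  # 3.375, i.e. 1.5 ** 3
```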
I have a maximum number of iterations, itermax, for my loop, but the reason I use a while loop instead of a for loop is that the jitted function can sometimes speed up the calculation and advance the iteration counter by more than 1.
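For comparison, the variable-step loop described above maps directly onto `lax.while_loop`. In this sketch `jitfunc` is a hypothetical stand-in that advances the counter by 2 whenever `state` is even and by 1 otherwise. The loop is compiled as a whole, but `lax.while_loop` supports only forward-mode differentiation (`jax.jvp`), not `jax.grad`.

```python
import jax.numpy as jnp
from jax import lax


def jitfunc(it, state):
    # Hypothetical stand-in for the poster's function: it may advance the
    # counter by more than one step (here by 2 whenever `state` is even).
    advance = jnp.where(state % 2 == 0, 2, 1)
    return it + advance, state + 1


def run_while(itermax, state0):
    def cond_fun(carry):
        it, _ = carry
        return it < itermax

    def body_fun(carry):
        it, state = carry
        return jitfunc(it, state)

    init = (jnp.int32(0), jnp.int32(state0))
    return lax.while_loop(cond_fun, body_fun, init)


it, state = run_while(5, 0)
print(int(it), int(state))  # 5 3
```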
In that scenario, I was planning to use scan with itermax steps and have the function do nothing once the counter has reached itermax (I think this is doable, but I have no idea whether it is optimal).
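The plan above can be sketched as a fixed-length `lax.scan` whose body becomes a no-op once the counter reaches `itermax`. This assumes each call advances the counter by at least 1, so `itermax` steps are always enough; `jitfunc` is the same hypothetical stand-in as in the while-loop sketch, and `run_scan` is an illustrative name.

```python
import jax.numpy as jnp
from jax import lax


def jitfunc(it, state):
    # Hypothetical stand-in: advances the counter by 2 when `state` is even,
    # otherwise by 1.
    advance = jnp.where(state % 2 == 0, 2, 1)
    return it + advance, state + 1


def run_scan(itermax, state0):
    def body(carry, _):
        it, state = carry
        new_it, new_state = jitfunc(it, state)
        # Once the counter has reached itermax, keep the carry unchanged so
        # the remaining scan steps do nothing.
        done = it >= itermax
        it = jnp.where(done, it, new_it)
        state = jnp.where(done, state, new_state)
        return (it, state), None

    init = (jnp.int32(0), jnp.int32(state0))
    # itermax is an upper bound on the number of real steps.
    (it, state), _ = lax.scan(body, init, xs=None, length=itermax)
    return it, state


it, state = run_scan(5, 0)
print(int(it), int(state))  # 5 3
```

With `jnp.where`, `jitfunc` is still evaluated on every step and its result discarded after the counter saturates. `lax.cond` could skip that work at runtime, although under `jax.vmap` with a batched predicate a `cond` is converted into a select that evaluates both branches anyway.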
However, when I try running lax.scan, I get the error:
ValueError: Non-hashable static arguments are not supported, as this can lead to unexpected cache-misses. Static argument (index 6) of type for function in_loop is non-hashable.
From looking around online, it seems that using lax.scan makes my hashable static variables non-hashable? (Not sure about this one.)
Is there a fix for what I am trying to do?
My code is not super readable, so here is a minimal example that replicates my error:
Thanks! :)