Regardless of the JAX version, I think this code is incorrect. You've marked … The best solution probably is to avoid marking …
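A minimal sketch of the suggestion above. The original toy code is not shown, so `inner_func`, `outer_func` and their bodies are hypothetical stand-ins shaped like the question's description: `inner_func` is jitted with `inline=True` and returns an updated `i` that is fed into a second call. The only point illustrated is that `i` is left as an ordinary traced argument rather than a static one, so it is allowed to be a tracer on the second call.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical stand-in for the toy `inner_func`: `i` is NOT listed in
# static_argnums, so it is traced like `x` and may itself be a tracer.
@partial(jax.jit, inline=True)
def inner_func(x, i):
    return x * i, i + 1


@jax.jit
def outer_func(x):
    x, i = inner_func(x, 1)  # first call: i starts as a Python int
    x, i = inner_func(x, i)  # second call: i is now a tracer, which is fine
    return x, i


x, i = outer_func(jnp.float32(2.0))
print(x, i)  # 4.0 3
```

The trade-off is that anything inside `inner_func` that genuinely needs a concrete Python value for `i` (a shape, a Python `if`, a `range` bound) will no longer work and would have to be rewritten with `jnp`/`lax` operations.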
Greetings,
Here is a toy example of using `inline=True` when jitting a function:
Now suppose I run the following code:
On jax==0.5.2 this runs perfectly fine and outputs the correct answer. But on jax==0.6.1 it crashes: the compiler tries to mark the argument `i` in the second call of `inner_func` as static, but by then `i` is a tracer, because it was returned by the first call of `inner_func`. I assumed that setting `inline=True` would effectively copy-paste `inner_func` into `outer_func`, so that this problem would be sidestepped.

Is this a bug in jax==0.6.1, or am I misunderstanding how `inline=True` is supposed to work? (Although there are obvious workarounds for this problem in this toy example, I would like to avoid them in my actual code if possible.)
OS: Windows
Device: CPU