Why *does* this work? #18430
Unanswered
robertmaxton42 asked this question in Q&A
-
I think you've got it right: the reason this works is that even though
If you're going to continue with this approach, I might actually define |
-
As part of a project I have ended up needing to roll my own wrapper around `jnp.vectorize`. Part of my code looks like this:

(Yes, I know, I'm calling `jnp.vectorize` on an internal local function every time the vectorized function is called. I'm not terribly happy about that, but in the final result the above also includes a `signature` argument that I have to modify on the fly, including the output result structure, and I won't in general know the output structure until I run it once. Presumably I could avoid this by exploiting JAX's existing tracing capabilities, but they aren't documented and I don't have time to really dig into that code. Fortunately, in my use case every vectorized function is only used once or twice, just on large inputs.)

Anyway, I've used a `nonlocal` here because I don't see any other way to communicate `out_treedef` back to `_vectorized`: `jnp.vectorize` will always try to batch outputs into arrays and then promptly fail, because JAX arrays don't support objects like `PyTreeDef`. And it works!

Except... why does it work? This is sort of the archetypal case of depending on a side effect, after all. I've defined a variable in `Quad.vectorize`'s closure, I haven't included it as an argument, and now I'm passing a function that modifies that variable into a JAX transformation (`jnp.vectorize` calls `vmap` under the hood). Presumably, since JAX has to run `_quadified` once to trace it, the tracers happen to provide the output structure that I need, but I'm still not entirely clear why it doesn't throw an error... (Not that I'm complaining, lol; I need this code to work!)
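Since the original snippet wasn't preserved, here is a minimal sketch of the pattern described above, not the author's actual code. It assumes the per-element function returns a pytree of *scalars*, so no `signature` is needed; `vectorize_pytree` and `flat_fn` are hypothetical names. Only the flattened leaves pass through `jnp.vectorize`, and the `PyTreeDef` is smuggled out through a `nonlocal`.

```python
import jax
import jax.numpy as jnp


def vectorize_pytree(fn):
    """Hypothetical helper: vectorize `fn`, which maps scalars to a pytree of scalars."""

    def wrapper(*args):
        out_treedef = None  # filled in while jnp.vectorize traces flat_fn

        def flat_fn(*scalars):
            nonlocal out_treedef
            leaves, treedef = jax.tree_util.tree_flatten(fn(*scalars))
            # Side effect: this assignment runs in Python while vmap traces
            # flat_fn with batch tracers. The treedef is static Python data.
            out_treedef = treedef
            return tuple(leaves)  # only arrays cross the vmap boundary

        flat_out = jnp.vectorize(flat_fn)(*args)
        # Tracing has already happened by this point, so out_treedef is set.
        return jax.tree_util.tree_unflatten(out_treedef, flat_out)

    return wrapper


def stats(x, y):
    return {"sum": x + y, "prod": x * y}


out = vectorize_pytree(stats)(jnp.arange(3.0), 2.0)
print(out["sum"])   # [2. 3. 4.]
print(out["prod"])  # [0. 2. 4.]
```

The treedef is read back only after `jnp.vectorize` returns, by which point `vmap` has already executed `flat_fn` once with tracers, so the write always precedes the read.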
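On the *why*: JAX transformations trace a function by running its Python body with tracer objects, so ordinary Python side effects (a `nonlocal` assignment, a list append) execute during tracing, once per trace rather than once per element. A `PyTreeDef` is plain static data, not a tracer, so storing it outside the function doesn't trigger a leaked-tracer error. The sketch below uses a hypothetical `trace_log` list to show that `jax.vmap` re-runs the Python body on every call, while `jax.jit` caches the trace and skips it on later calls. As far as I can tell, `jnp.vectorize` is built on `vmap` without an internal `jit`, which would explain why the `nonlocal` gets reassigned on every call.

```python
import jax
import jax.numpy as jnp

trace_log = []  # hypothetical: records each time f's Python body actually runs


def f(x):
    # Runs whenever JAX traces f (x is a tracer then), not once per element.
    trace_log.append(type(x).__name__)
    return x * 2.0


xs = jnp.arange(4.0)

# vmap keeps no trace cache: each call runs f's Python body exactly once,
# even though xs has four elements.
jax.vmap(f)(xs)
jax.vmap(f)(xs)
print(len(trace_log))  # 2

# jit traces on the first call, then reuses the cached computation.
jitted = jax.jit(f)
jitted(xs)
jitted(xs)
print(len(trace_log))  # 3
```

If a caller wraps the whole vectorized function in `jax.jit`, the `nonlocal` write and the read that follows both happen inside the same trace, so the two stay consistent; the side effect simply stops running on cache hits.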