
Is it safe to use dict.pop in jit, vmap, and other jax transformations? #27856

Answered by jakevdp
HeavyCrab asked this question in Q&A

Whether or not it's safe to use dict.pop within jax.jit depends on the context.

In the snippet you shared, it is safe. Explicit dict arguments to a jit-compiled function are passed through jax.tree_util flattening, so the dict your function operates on is a new dict rebuilt from the flattened leaves, in effect a copy of the dict you passed in. Mutating it (for example with dict.pop) does not affect the dict in the outer scope.
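Since the original snippet isn't reproduced here, the following is a minimal sketch of the same idea under assumed inputs: the function `f` and the keys `"x"` and `"scale"` are hypothetical. It pops a key inside a jitted function and checks that the caller's dict still has both keys afterwards.

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(params):
    # Hypothetical function. Under jax.jit, `params` is a new dict rebuilt
    # from the flattened pytree leaves, so pop() only mutates that copy.
    scale = params.pop("scale")
    return params["x"] * scale

# Hypothetical inputs for illustration.
params = {"x": jnp.arange(3.0), "scale": jnp.float32(2.0)}
out = f(params)
print(sorted(params))  # ['scale', 'x'] -- the caller's dict is unchanged
```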

However, if you remove jax.jit, or run the function inside a jax.disable_jit() context, the function receives the caller's dict object itself, so dict.pop will mutate the dict in the outer scope. Keep this in mind if you rely on this kind of pattern.
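To illustrate the un-jitted case, here is a hedged sketch (again with hypothetical names `f`, `f_safe`, `"x"`, `"scale"`) showing a plain Python call mutating the caller's dict, plus one common defensive pattern: shallow-copying the dict with `dict(params)` before popping, so the behavior is the same with or without jit.

```python
import jax
import jax.numpy as jnp

def f(params):
    # Hypothetical function, called without jax.jit: `params` IS the
    # caller's dict, so pop() removes the key from it.
    scale = params.pop("scale")
    return params["x"] * scale

def f_safe(params):
    # Defensive shallow copy: pop from a new dict so the caller's dict is
    # left untouched whether or not the function is jit-compiled.
    params = dict(params)
    scale = params.pop("scale")
    return params["x"] * scale

p1 = {"x": jnp.arange(3.0), "scale": 2.0}
f(p1)
print(sorted(p1))  # ['x'] -- "scale" was removed from the caller's dict

p2 = {"x": jnp.arange(3.0), "scale": 2.0}
f_safe(p2)
jax.jit(f_safe)(p2)
print(sorted(p2))  # ['scale', 'x'] -- unchanged
```

The copy is shallow, which is enough here because only the dict's keys are being modified, not the arrays it holds.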

Replies: 2 comments

Answer selected by HeavyCrab