Spec undefined behaviour: cloning semantics for jit-related functions #7753

Answered by jekbradbury
cgarciae asked this question in General

I think this specific code is okay. In general, performing in-place updates of Python objects is fine as long as you:

  • always pass such objects as explicit arguments and return values of every function that you wrap with a JAX transformation like jit, vmap, scan, named_call, etc. Failing to do this violates the "contract" of JAX transformations (they require the functions they transform to be functionally pure), and it is this requirement that gives rise to the wrapped transformation APIs in libraries like haiku, flax, and objax (which maintain state in mutable Python objects but then lift these objects to explicit arguments and return values of transformed functions).
  • never rely on the obj…
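As an illustration of the first point, here is a minimal sketch, assuming a hypothetical `Module` class with a single `count` field (not part of JAX, haiku, flax, or objax). The jitted function `_pure_step` is pure: the state it reads is an argument and the state it changes is a return value. The in-place update of the Python object happens only in the un-jitted wrapper `step`, which lifts the state into the call and writes the result back. This mirrors the lifting pattern described above, not any specific library's implementation.

```python
import jax
import jax.numpy as jnp


class Module:
    """Hypothetical stateful object; the class and its `count` field are
    illustrative only, not part of any library's API."""

    def __init__(self):
        self.count = jnp.zeros((), dtype=jnp.int32)


@jax.jit
def _pure_step(count, x):
    # Pure function: all state it reads is passed in as an argument,
    # and all state it changes is returned. Nothing is mutated here.
    new_count = count + 1
    return new_count, x * new_count


def step(module, x):
    # The in-place update of the Python object happens *outside* jit:
    # state is lifted to an explicit argument, then written back.
    new_count, y = _pure_step(module.count, x)
    module.count = new_count
    return y


m = Module()
ys = [float(step(m, jnp.float32(2.0))) for _ in range(3)]
print(int(m.count), ys)  # prints: 3 [2.0, 4.0, 6.0]
```

If `_pure_step` instead read `module.count` from an enclosing scope, jit would capture its value as a constant at trace time, and later mutations of the object could be silently ignored on cached calls.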

Replies: 3 comments 2 replies


cgarciae
Aug 28, 2021
Collaborator Author

0 replies
2 replies
@soraros

@jakevdp

Answer selected by jakevdp

cgarciae
Aug 30, 2021
Collaborator Author

0 replies
Labels
enhancement New feature or request
4 participants
Converted from issue

This discussion was converted from issue #7748 on August 30, 2021 01:13.