Spec undefined behaviour: cloning semantics for jit-related functions #7753
Hey JAX team! @shoyer was kind enough to point out in this Twitter thread that I am depending on undefined behaviour when updating incoming pytrees in place. Example:

    @jax.jit
    def train_step(model, x, y, optimizer):
        ...
        params = optimizer.update(grads, params)
        ...
        return model, loss, optimizer

The following is happening:
The problem I was not aware of is that
Is it worth creating a spec for this? Too early to know? Thanks!
Replies: 3 comments 2 replies
BTW: a discussion around this would immensely help to steer Treex's API, as state management is a core issue.
I think this specific code is okay. In general performing in-place updates of Python objects is fine as long as you:
These two things aren't directly related; they're just both constraints relating to in-place object manipulation.
I see. Thanks @jekbradbury! I've been intuitively counting on the rules you describe. Are they established somewhere in the docs? If not, maybe adding a brief section about "Object mutation inside JIT operations" with this information would be valuable.
I think this specific code is okay. In general, performing in-place updates of Python objects is fine as long as you:
`jit`, `vmap`, `scan`, `named_call`, etc. Failing to do this violates the "contract" of JAX transformations (they need the functions they transform to be functionally pure), and it's this requirement that gives rise to the wrapped transformation APIs in libraries like haiku, flax, and objax (which maintain state in mutable Python objects, but then lift these objects to be explicit arguments and return values of transformed functions).
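To make the purity requirement concrete, here is a minimal sketch rather than anything taken from haiku, flax, or objax. `Counter`, `pure_step`, `step`, `impure`, and `trace_log` are hypothetical names invented for illustration. The first part keeps state in a mutable Python object but passes it into and out of the jitted function explicitly. The second part shows a Python side effect inside a jitted function, which runs while JAX traces the function and not on later cached calls.

```python
import jax
import jax.numpy as jnp


class Counter:
    """Hypothetical mutable Python object that holds state outside JAX."""

    def __init__(self):
        self.count = jnp.zeros((), dtype=jnp.int32)


@jax.jit
def pure_step(count, x):
    # State enters as an argument and leaves as a return value,
    # so the transformed function stays functionally pure.
    return count + 1, x * 2.0


def step(counter, x):
    # The in-place update of the Python object happens outside the jitted function.
    counter.count, y = pure_step(counter.count, x)
    return y


counter = Counter()
for _ in range(3):
    y = step(counter, jnp.float32(1.5))

# Counter-example: a Python side effect inside a jitted function.
trace_log = []


@jax.jit
def impure(x):
    trace_log.append("traced")  # executes during tracing only
    return x + 1


x = jnp.ones(3)
impure(x)
impure(x + 1)  # same shape and dtype: the cached computation is reused

print(int(counter.count), float(y), len(trace_log))
```

In this sketch the wrapper `step` is the only place where the Python object is mutated, which is one way to read the "lift state to explicit arguments and return values" pattern described above.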