jax.errors.UnexpectedTracerError: Encountered an unexpected tracer #18512
ChatGPT helps a lot! The error you encountered, `jax.errors.UnexpectedTracerError`, is commonly caused by a side effect in a JAX-transformed function, where an intermediate computation escapes the functional scope. This often happens due to:

- **Modifying global or non-local state:** JAX expects functions to be "pure", meaning they have no side effects such as mutating global state. Your `one_batch_update` function updates `self.params` and `self.opt_state` as side effects, which could be causing the issue.
- **Using Python control flow with JAX tracers:** Python control flow (`if` statements, `while`/`for` loops) over values that are JAX tracers can also trigger this error.

To resolve the error:

1. **Avoid side effects:** Refactor the function so it does not modify any non-local state. Instead, return the updated values as outputs from your function.
2. **Use JAX control flow:** If your code involves conditional logic or loops over traced values, use JAX's control-flow operators (`jax.lax.cond`, `jax.lax.while_loop`, etc.) instead of Python's.

Here's a refactored version of your `one_batch_update` function that avoids these issues:

```python
from functools import partial

import jax
import optax

@partial(jax.jit, static_argnames=["self", "nn1", "nn2"])
def one_batch_update(self, t, x, xt, Penalty, nn1, nn2):
    # Compute the loss (with auxiliary outputs) and its gradients.
    (Loss, Aux), grads = jax.value_and_grad(self.loss_fn, has_aux=True)(
        self.params, t, x, xt, Penalty, nn1, nn2
    )
    # Update the parameters functionally via optax, without mutating self.
    updates, opt_state = self.optim.update(grads, self.opt_state)
    new_params = optax.apply_updates(self.params, updates)
    # Return the new parameters and optimizer state along with Loss and Aux.
    return new_params, opt_state, Loss, Aux
```

In this version, `self.params` and `self.opt_state` are not modified directly within the function. Instead, the function returns the new parameters and optimizer state, and you update `self.params` and `self.opt_state` outside this function based on the returned values. (Note that arguments marked static, like `self`, `nn1`, and `nn2`, must be hashable.) This approach maintains the functional purity expected by JAX.
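The "return new state, assign it back outside the jitted function" pattern can be shown with a minimal self-contained sketch. This is not your model: `loss_fn`, `one_step`, and the plain-SGD update stand in for your loss and for optax, purely to illustrate the functional update.

```python
import jax
import jax.numpy as jnp

# Toy loss: fit y = w * x by least squares (illustrative stand-in).
def loss_fn(params, x, y):
    pred = params["w"] * x
    return jnp.mean((pred - y) ** 2)

@jax.jit
def one_step(params, x, y, lr):
    # Pure function: takes state in, returns new state out. No mutation.
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    new_params = {"w": params["w"] - lr * grads["w"]}
    return new_params, loss

params = {"w": jnp.array(0.0)}
x = jnp.array([1.0, 2.0, 3.0])
y = 2.0 * x
for _ in range(200):
    # State is updated OUTSIDE the jitted function, from its return values.
    params, loss = one_step(params, x, y, 0.1)
```

After the loop, `params["w"]` has converged to the true slope 2.0; in your class you would do `self.params, self.opt_state, Loss, Aux = self.one_batch_update(...)` in the same spirit.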
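The second suggestion (JAX control flow instead of Python `if` on traced values) can be sketched like this; `leaky_abs` is a made-up toy function, not part of your code:

```python
import jax
import jax.numpy as jnp

@jax.jit
def leaky_abs(x):
    # A Python `if x >= 0:` here would fail on the tracer;
    # jax.lax.cond(pred, true_fn, false_fn, operand) traces both branches.
    return jax.lax.cond(x >= 0, lambda v: v, lambda v: -0.5 * v, x)
```

Both branches must return outputs of the same shape and dtype, since XLA compiles the whole conditional.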
I tried to jit the `one_batch_update()` function and got the error reported above. `nn1` and `nn2` are two neural networks, `t`, `x`, and `xt` are the data, and `Penalty = {'r': 1e6, 'b': 1e3, 'i': 1e3}` is a dict. Any comments are appreciated.