Hi, looking at the source code, I can see that `jax.grad` calls `jax.value_and_grad` and just throws the value away. This seems a bit wasteful. My question is whether this is actually what happens in practice, or whether the JIT compiler is able to optimize away the unnecessary value computation.
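To make the pattern concrete, here is a minimal sketch of a gradient function built on top of `jax.value_and_grad` that discards the primal value. The wrapper `grad_via_value_and_grad` and the test function `f` are hypothetical illustrations, not JAX's actual source; they only mirror the behaviour described above.

```python
import jax
import jax.numpy as jnp


def grad_via_value_and_grad(fun):
    """Hypothetical simplified wrapper (not JAX's real implementation):
    compute value and gradient together, then drop the value."""
    value_and_grad_fun = jax.value_and_grad(fun)

    def grad_fun(*args):
        _value, g = value_and_grad_fun(*args)  # value is computed, then discarded
        return g

    return grad_fun


def f(x):
    # Hypothetical scalar-valued function used only for illustration.
    return jnp.sum(jnp.sin(x) ** 2)


x = jnp.arange(3.0)
g_wrapper = grad_via_value_and_grad(f)(x)
g_builtin = jax.grad(f)(x)
print(g_wrapper, g_builtin)
```

Both calls produce the same gradient; the question is only whether the discarded `_value` costs anything at run time.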
Thanks for the question!

Typically in JAX, performance-critical code is wrapped in `jax.jit`, which traces the function and hands the resulting program to the XLA compiler. The compiler performs dead code elimination before execution, removing operations whose results are never used. So in practice, throwing away this value at trace time does not matter.
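One way to check this yourself is to compile the gradient under `jax.jit` and inspect the optimized HLO that XLA actually runs. The sketch below assumes a hypothetical test function `f`; `lower(...).compile().as_text()` returns the post-optimization HLO as a string, whose exact contents vary with JAX version and backend, so no specific output is shown here.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical scalar-valued function; its value is what jax.grad discards.
    return jnp.sum(jnp.sin(x) ** 2)


x = jnp.linspace(0.0, 1.0, 5)

grad_f = jax.jit(jax.grad(f))           # traced once, then compiled by XLA
vg_f = jax.jit(jax.value_and_grad(f))   # same computation, but the value is kept

g = grad_f(x)
value, g2 = vg_f(x)

# Optimized HLO for each program, after XLA's passes (including dead code
# elimination). Diffing the two shows what the compiled grad program omits.
grad_hlo = grad_f.lower(x).compile().as_text()
vg_hlo = vg_f.lower(x).compile().as_text()
print(grad_hlo)
```

Without `jax.jit`, operations dispatch one at a time, so there is no whole-program pass that could remove the unused value computation.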