Performance of jax.grad when taking the derivative with respect to many inputs
#26562
-
Hi, I am currently using Initially, I was taking the derivative with respect to each input by calling Now, my question is: why does this happen? Is there an underlying principle of automatic differentiation that explains this or is it mostly hyper-optimization of the Thank you for your help!
Note: for the comparisons I did not JIT any function to try and make it "fair".
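The question's exact code is not shown, so this sketch uses a hypothetical scalar function `f` of three scalar inputs to illustrate the two approaches it describes. The first calls `jax.grad` once per input with `argnums=i`. The second makes a single `jax.grad` call with `argnums=(0, 1, 2)`, which returns all three partial derivatives as a tuple. Neither is jitted, matching the note above. The sketch checks only that both give the same values and makes no claim about timings.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in for the function in the question: the expensive
# part (the repeated matrix products) mixes all three inputs together.
def f(x, y, z):
    m = jnp.array([[x, y], [z, 1.0]])
    for _ in range(10):
        m = jnp.tanh(m @ m)
    return jnp.sum(m)


x, y, z = 0.1, 0.2, 0.3

# Approach 1: one jax.grad call per input, i.e. three separate
# forward + backward passes.
separate = [jax.grad(f, argnums=i)(x, y, z) for i in range(3)]

# Approach 2: a single jax.grad call that returns a tuple with all three
# partial derivatives from one forward + backward pass.
combined = jax.grad(f, argnums=(0, 1, 2))(x, y, z)

for i, (s, c) in enumerate(zip(separate, combined)):
    print(f"df/d(arg {i}): separate={float(s):.6f} combined={float(c):.6f}")
```

To reproduce the timing comparison, each variant would need to be timed separately, ideally after one warm-up call.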
-
The reason for this is likely that the gradient with respect to a single element requires just about as much computation as the gradient with respect to all three, because intermediate computations will be reused in the latter case. We can show this with a simplified example by looking at the jaxpr of the computation:
I suspect your case is similar: the bulk of the expensive computation is shared by all three outputs, so computing all three gradients at once is not much more expensive than computing a single gradient.
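The reply's original jaxpr example is not preserved here. This sketch is a hypothetical substitute that performs the same kind of check with `jax.make_jaxpr`. It traces a gradient with respect to one input and a gradient with respect to all three, then counts the primitive equations in each. The equation count is only a rough proxy for cost, and the function `f` is made up for illustration.

```python
import jax
import jax.numpy as jnp


# Hypothetical function (not the one from the discussion): every step of
# the loop depends on all three inputs through the shared matrix `m`.
def f(x, y, z):
    m = jnp.array([[x, y], [z, 1.0]])
    for _ in range(10):
        m = jnp.tanh(m @ m)
    return jnp.sum(m)


args = (0.1, 0.2, 0.3)

# Trace both gradient computations to jaxprs without running them.
jaxpr_one = jax.make_jaxpr(jax.grad(f, argnums=0))(*args)
jaxpr_all = jax.make_jaxpr(jax.grad(f, argnums=(0, 1, 2)))(*args)

# The number of top-level primitive equations is a rough proxy for work.
n_one = len(jaxpr_one.jaxpr.eqns)
n_all = len(jaxpr_all.jaxpr.eqns)
print("equations, gradient w.r.t. x only:   ", n_one)
print("equations, gradient w.r.t. x, y, z:  ", n_all)
print("equations, three separate gradients: ", 3 * n_one)

# Uncomment to read the full program, including the forward pass and the
# backward pass through the loop that all three partials share:
# print(jaxpr_all)
```

The forward pass and the backward pass through the loop are needed for any of the partials. Asking for all three adds only the final steps that split the cotangent of `m` into its `x`, `y` and `z` parts.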