JAX conversion of custom JVPs and VJPs (and choice of forward- or reverse-mode diff) #18320
-
Hello, I have a couple of questions about custom JVPs and VJPs, as well as forward- and reverse-mode differentiation in JAX. Suppose I have a chain of function compositions and wish to invoke JAX's autodiff. An example of this would be computing the loss of a neural network's prediction for a batch of data. For illustrative purposes, let's assume that the network consists of conventional layers (i.e. those available in JAX neural-network libraries).
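For concreteness, here is a minimal sketch of the kind of chain I have in mind; `my_layer` and the toy loss below are just placeholders, with a hand-written JVP rule attached via `jax.custom_jvp`:

```python
import jax
import jax.numpy as jnp

# Placeholder layer carrying a user-supplied JVP rule; it stands in for
# one link of the composition chain described above.
@jax.custom_jvp
def my_layer(x):
    return jnp.tanh(x)

@my_layer.defjvp
def my_layer_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    y = jnp.tanh(x)
    # d/dx tanh(x) = 1 - tanh(x)**2, applied to the incoming tangent.
    return y, (1.0 - y**2) * x_dot

def loss(params, batch):
    # Toy chain of compositions: linear map -> custom layer -> scalar loss.
    preds = my_layer(batch @ params)
    return jnp.mean(preds**2)

params = jnp.ones(4)
batch = jnp.ones((8, 4))
grads = jax.grad(loss)(params, batch)  # reverse-mode through the chain
```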
Many thanks for your help; please do not hesitate to let me know if any of the questions are unclear.
-
`jax.grad` and `jax.value_and_grad` always use reverse-mode AD. Forward-mode AD is exposed as `jax.jacfwd` and `jax.jvp`. (See e.g. here: `grad` doesn't work if a VJP / reverse-mode rule isn't defined on a primitive.)
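To illustrate the distinction, a small sketch (not from the original reply) of both modes applied to a toy function:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(jnp.tanh(x) ** 2)

x = jnp.arange(3.0)

# Reverse-mode: both of these build and evaluate a VJP of f.
g = jax.grad(f)(x)
val, g_again = jax.value_and_grad(f)(x)

# Forward-mode: jvp pushes a tangent vector through f in one pass;
# jacfwd assembles the Jacobian (here a gradient row) from JVPs.
y, y_dot = jax.jvp(f, (x,), (jnp.ones_like(x),))
jac = jax.jacfwd(f)(x)
```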