Hey, I'm new to JAX - up to now I've been doing some fairly hobby-level coding using mygrad, which I think I'm now reaching the limits of. I have one main question about JAX before I dive in: can I implement metagradient descent? That is, can I treat some function that implements gradient descent using JAX as an opaque solver inside an outer gradient descent algorithm? Will JAX's grad tracers support this, or will they get confused because the outer tracers are included in the inner computation graph? Am I barking mad?
In general JAX transformations are composable, so you can use autodiff to compute the gradient of another function that uses autodiff. That said, there may be some limitations: for example, JAX does not implement reverse-mode autodiff of unbounded loops, so if your inner solver uses a `while_loop` for convergence, you'll be limited to forward-mode autodiff.
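Here's a minimal sketch (not from this thread) of what that composition can look like, assuming the inner solver runs a fixed number of steps so reverse-mode autodiff can differentiate through it; the names `inner_loss`, `inner_solver`, and `meta_loss` are just illustrative:

```python
import jax
import jax.numpy as jnp

def inner_loss(w, x):
    # Toy inner objective: fit w to match x.
    return jnp.sum((w - x) ** 2)

def inner_solver(x, lr, n_steps=20):
    # Plain gradient descent on inner_loss, unrolled for a fixed step count.
    w = jnp.zeros_like(x)
    for _ in range(n_steps):
        w = w - lr * jax.grad(inner_loss)(w, x)
    return w

def meta_loss(lr, x, target):
    # Outer objective: how good is the inner solution for this learning rate?
    w_star = inner_solver(x, lr)
    return jnp.sum((w_star - target) ** 2)

x = jnp.array([1.0, 2.0, 3.0])
target = jnp.array([1.0, 2.0, 3.0])

# Gradient of the outer objective with respect to the inner learning rate.
meta_grad = jax.grad(meta_loss)(0.1, x, target)
print(meta_grad)
```

Because the inner loop is a plain Python loop with a fixed trip count, it gets unrolled during tracing and the outer `jax.grad` can backpropagate through every inner update; it's only when the inner solver iterates to convergence with `lax.while_loop` that the reverse-mode limitation above kicks in.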