VJPs of VJPs? #18276
-
@jakevdp @patrick-kidger any help appreciated!
-
It seems like #18383 is a better way to pose this question.
-
I've reduced the scope of the implementation to make the problem easier to follow.
I have a function that takes an input x and model parameters W and produces an output y. The loss is computed as the MSE between x and y. My goal is to take the derivative of the gradient of the loss with respect to the model parameters, with respect to the output. To illustrate mathematically:
$$\frac{d}{dy}\left(\frac{d\,loss}{dW}\right)$$
My thought process for the implementation is to decompose $\frac{d\,loss}{dW}$ into $\frac{d\,loss}{dy}$ and $\frac{dy}{dW}$, expressed as a VJP with an upstream gradient, and then take the derivative of that:
$$\frac{d\,loss}{dW} = \frac{d\,loss}{dy} \cdot \frac{dy}{dW}$$
Implemented as:
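A minimal sketch of one way this decomposition could be written with `jax.vjp`; the toy linear model `apply`, the loss `loss_fn`, and the helper `dloss_dW` are illustrative assumptions, not the original snippet:

```python
import jax
import jax.numpy as jnp

# NOTE: apply, loss_fn, and dloss_dW are placeholder names for illustration,
# not the code from the original post.

def apply(W, x):
    # Toy linear "model": y = W @ x
    return W @ x

def loss_fn(y, x):
    # MSE between the output y and the input x, as described above
    return jnp.mean((y - x) ** 2)

def dloss_dW(y, W, x):
    # d loss / dW = (d loss / dy) pulled back through dy / dW.
    # Taking y as an explicit argument lets us differentiate w.r.t. it later.
    _, pullback = jax.vjp(lambda W_: apply(W_, x), W)  # VJP of the model w.r.t. W
    upstream = jax.grad(loss_fn)(y, x)                 # upstream gradient: d loss / dy
    (grad_W,) = pullback(upstream)
    return grad_W
```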
The goal is to compute some outer-loop (upstream) gradient with respect to the output of an inner loop. Is it possible to nest VJPs like this, and will it compute the correct $\frac{d}{dy}\left(\frac{d\,loss}{dW}\right)$?
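Under the same assumed setup as the sketch above, nesting would amount to differentiating `dloss_dW` with respect to its `y` argument, either by materialising the full Jacobian or by applying a second, outer `jax.vjp`:

```python
# Illustrative shapes only, continuing the placeholder setup from the sketch above.
W = jax.random.normal(jax.random.PRNGKey(0), (3, 3))
x = jnp.ones(3)
y = apply(W, x)

# Full derivative d/dy (d loss / dW): one (3, 3) block per component of y.
d_dy_of_dloss_dW = jax.jacobian(dloss_dW)(y, W, x)   # shape (3, 3, 3)

# Equivalently, a second (outer) VJP with a cotangent v shaped like d loss / dW,
# which contracts v against d/dy (d loss / dW) without materialising the Jacobian.
v = jnp.ones((3, 3))
_, outer_pullback = jax.vjp(lambda y_: dloss_dW(y_, W, x), y)
(v_dot_ddy,) = outer_pullback(v)                     # shape (3,)

print(d_dy_of_dloss_dW.shape, v_dot_ddy.shape)
```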
Thanks, let me know if I can clarify anything!