Taking Derivatives of Gradients w.r.t. intermediate variables? #18383
karan-dalal asked this question in Q&A · Unanswered
Replies: 1 comment 5 replies
- Reply:
You can only take derivatives with respect to function arguments. One way to make this work is the approach in #5336 (comment); i.e., track a perturbation to your original variable. For example:

    from jax import grad

    # MSE, A, x, and W are assumed to be defined elsewhere.
    def f(x, W, dx):
        x_hat = (x @ W) @ A + dx   # dx perturbs the intermediate x_hat
        loss = MSE(x, x_hat)
        return loss

    # The gradient w.r.t. the perturbation, evaluated at dx = 0, is the sensitivity
    # of the loss to x_hat (pass dx with the same shape as x_hat to get an
    # element-wise gradient).
    loss_fn = grad(f, argnums=2)
    G = loss_fn(x, W, 0.0)

The answer in your example is trivial, though, because …
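One way to extend that approach to the dG/dx_hat asked about in the question below is to differentiate the W-gradient itself with respect to the perturbation. The following is only a hedged, self-contained sketch: the toy shapes, the mse helper, and the choice of (x @ W) @ A as the perturbed intermediate are assumptions, not code from the thread.

    import jax
    import jax.numpy as jnp

    # Toy stand-ins for the real data (shapes are assumptions).
    x = jax.random.normal(jax.random.PRNGKey(0), (4, 3))
    W = jax.random.normal(jax.random.PRNGKey(1), (3, 3))
    A = jax.random.normal(jax.random.PRNGKey(2), (3, 3))

    def mse(a, b):
        return jnp.mean((a - b) ** 2)

    def loss(x, W, dx):
        # dx perturbs the intermediate; at dx = 0 this is the unmodified loss.
        x_hat = (x @ W) @ A + dx
        return mse(x, x_hat)

    # G(dx): gradient of the loss w.r.t. W, viewed as a function of the perturbation.
    grad_W = jax.grad(loss, argnums=1)

    # dG/dx_hat: Jacobian of G w.r.t. dx, evaluated at zero perturbation.
    dx0 = jnp.zeros_like((x @ W) @ A)
    dG_dxhat = jax.jacobian(grad_W, argnums=2)(x, W, dx0)
    print(dG_dxhat.shape)  # (3, 3, 4, 3): one (4, 3) sensitivity map per entry of G

Each slice dG_dxhat[i, j] is the gradient of G[i, j] with respect to the intermediate. jax.jacobian materializes the full tensor, so for larger problems a VJP of grad_W against a chosen cotangent is the cheaper alternative.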
- Original question (karan-dalal):
I have the following function, and I want to somehow get dG/dx_hat, where G is the gradient of the loss with respect to W and x_hat is an intermediate variable. Assume A is some fixed linear transformation matrix. (Note: this is a simplified version of the actual project.) I've tried reorganizing the computation as a VJP of a VJP (#18276), but that doesn't seem to yield the correct result. Could I write a function that returns G and then perturb x_hat? Any alternative suggestions?
Thank you!
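For reference, a rough sketch of the kind of function being described, using the same assumed shapes and mse helper as the sketch above; none of this is the asker's actual code.

    import jax
    import jax.numpy as jnp

    x = jax.random.normal(jax.random.PRNGKey(0), (4, 3))   # input batch (assumed shape)
    W = jax.random.normal(jax.random.PRNGKey(1), (3, 3))   # learned weights
    A = jax.random.normal(jax.random.PRNGKey(2), (3, 3))   # fixed linear transformation

    def mse(a, b):
        return jnp.mean((a - b) ** 2)

    def loss_fn(x, W):
        x_hat = (x @ W) @ A    # intermediate variable of interest
        return mse(x, x_hat)

    # G = dLoss/dW. The question asks for dG/dx_hat, but x_hat is not an argument
    # of loss_fn, which is why it cannot be differentiated against directly.
    G = jax.grad(loss_fn, argnums=1)(x, W)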