Hi, sorry if this question has already been answered (I am sure it must have been), but somehow I could not find any resource.
I can differentiate it explicitly using I get the following error:
Replies: 1 comment
There's no way to make vmap(grad(fv)) work for this function, because grad requires a scalar-output function, and your fv function returns a vector output that does not correspond to any vector input.

If you're interested in the gradient of a single output, you could do something like this:

jax.grad(lambda x: fv(x)[0])(x)

If you're interested in computing the gradients for all outputs at once, you could do something like this:

jax.jacrev(fv)(x)
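For anyone landing here later, here is a minimal sketch of the two suggestions above. The original fv is not shown in this thread, so the fv below is a hypothetical stand-in (scalar input, length-3 vector output) chosen only to reproduce the situation described; the input value is also an assumption.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for fv (the real one is not shown in the thread):
# one scalar input, three outputs.
def fv(x):
    return jnp.stack([x ** 2, jnp.sin(x), 3.0 * x])

x = 2.0  # assumed input value, for illustration only

# grad applied directly to fv is rejected, because the output is not a scalar.
try:
    jax.grad(fv)(x)
    grad_failed = False
except TypeError as err:
    grad_failed = True
    print("grad(fv) failed:", err)

# Gradient of a single output: select one component so the function is scalar-valued.
g0 = jax.grad(lambda x: fv(x)[0])(x)  # d/dx of x**2 at x = 2.0

# Derivatives of all outputs at once: the Jacobian, one entry per output.
jac = jax.jacrev(fv)(x)

print(g0)
print(jac.shape)
```

With a scalar input and several outputs, jax.jacfwd(fv)(x) should give the same values as jax.jacrev(fv)(x) and is usually the cheaper choice, since forward mode needs one pass per input rather than one per output.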