What is the "JAX" way of defining elementwise vector functions #14941

Answered by jakevdp
ipcamit asked this question in General

There's no way to make vmap(grad(fv)) work for this function. grad requires a scalar-output function, and your fv function has a vector output that is not an elementwise map of a scalar function over its input. If you're interested in the gradient of a single output, you could do something like this:

jax.grad(lambda x: fv(x)[0])(x)
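
The question's actual fv isn't reproduced in this thread, so the sketch below uses a hypothetical stand-in, fv(x) = x * sum(x), where every output depends on every input. It shows that jax.grad rejects the vector-valued function directly, and that selecting one output first gives an ordinary scalar gradient.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the question's fv (not shown in the thread):
# each output element depends on *all* input elements, so it is not elementwise.
def fv(x):
    return x * jnp.sum(x)

x = jnp.array([1.0, 2.0, 3.0])

# jax.grad only accepts scalar-output functions, so this raises a TypeError.
grad_failed = False
try:
    jax.grad(fv)(x)
except TypeError as err:
    grad_failed = True
    print("grad failed:", err)

# Gradient of the first output only: d fv(x)[0] / dx.
# fv(x)[0] = x0 * (x0 + x1 + x2), so the gradient is [sum(x) + x0, x0, x0].
g0 = jax.grad(lambda x: fv(x)[0])(x)
print(g0)
```

Indexing inside the lambda turns the vector function into a scalar one, which is exactly what grad needs.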

If you're interested in computing the gradients of all outputs at once (i.e. the full Jacobian), you could do something like this:

jax.jacrev(fv)(x)
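
A sketch of both cases, again assuming the hypothetical fv(x) = x * sum(x) in place of the question's function: jacrev returns the full Jacobian, whose row i is the gradient of output i. For comparison, it also shows one common JAX pattern (an assumption about what "elementwise" meant in the question): write the function for a single scalar, vectorize it with vmap, and then vmap(grad(f)) does work, matching the diagonal of the Jacobian.

```python
import jax
import jax.numpy as jnp

# Hypothetical non-elementwise function: output i depends on all inputs.
def fv(x):
    return x * jnp.sum(x)

x = jnp.array([1.0, 2.0, 3.0])

# Full Jacobian: J[i, j] = d fv(x)[i] / d x[j]. Row 0 equals grad(lambda x: fv(x)[0]).
J = jax.jacrev(fv)(x)
print(J)

# Hypothetical genuinely elementwise function, written for a single scalar.
def f(xi):
    return jnp.sin(xi) * xi

fv_elem = jax.vmap(f)  # vector version: applies f to each element independently

# Because f has scalar output, grad(f) is valid and can be vmapped over x.
d_vmap = jax.vmap(jax.grad(f))(x)

# For an elementwise function the Jacobian is diagonal, so its diagonal matches.
d_diag = jnp.diag(jax.jacrev(fv_elem)(x))
print(d_vmap, d_diag)
```

The vmap(grad(f)) route avoids materializing the full n-by-n Jacobian, which matters for large inputs; it only applies when each output depends on its own input element alone.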

Answer selected by ipcamit