Getting the Hessian vector Product of a Flax NN output #10952

Answered by jakevdp
gitvicky asked this question in General

I just answered this question on StackOverflow (https://stackoverflow.com/a/72493336/2937831); copied here for completeness:


The issue is that your u_function maps a length-3 vector to a scalar. Its first derivative is a length-3 vector, and its second derivative is a 3x3 Hessian matrix. You cannot get that matrix by applying jax.grad a second time: jax.grad is only designed for scalar-output functions, and the gradient function it returns has a vector output. Fortunately, JAX provides the jax.hessian transform to compute these general second derivatives:

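from jax import hessian, vmap

# Hessian with respect to the input (argnums=1), batched over the rows of X.
u_XX = vmap(hessian(u_function, argnums=1), in_axes=(None, 0), out_axes=0)(params, X)
print(u_XX.shape)
# (32, 3, 3)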
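
Since the question title asks for a Hessian-vector product, here is a minimal sketch of one way to get it without building the full 3x3 matrix per sample, using forward-over-reverse differentiation (jax.jvp applied to jax.grad). The u_function below is a hypothetical stand-in for the Flax network, and the params, X, and V arrays are made-up illustrative data, not the original poster's setup.

```python
import jax
import jax.numpy as jnp
from jax import grad, hessian, jvp, vmap

# Hypothetical stand-in for the network: params is a (3, 3) matrix,
# x is a length-3 input, and the output is a scalar.
def u_function(params, x):
    return jnp.sum(jnp.tanh(params @ x))

def hvp(params, x, v):
    # Forward-over-reverse: push the direction v through the gradient
    # with respect to x, which yields H(x) @ v without forming H(x).
    grad_x = grad(u_function, argnums=1)
    return jvp(lambda x_: grad_x(params, x_), (x,), (v,))[1]

# Illustrative data (assumed shapes: batch of 32 inputs of length 3).
key_p, key_x, key_v = jax.random.split(jax.random.PRNGKey(0), 3)
params = jax.random.normal(key_p, (3, 3))
X = jax.random.normal(key_x, (32, 3))
V = jax.random.normal(key_v, (32, 3))

# Full per-sample Hessians, as in the answer above.
u_XX = vmap(hessian(u_function, argnums=1), in_axes=(None, 0))(params, X)

# Per-sample Hessian-vector products, batched over X and V together.
u_XX_v = vmap(hvp, in_axes=(None, 0, 0))(params, X, V)
```

Multiplying each full Hessian in u_XX by the matching row of V should give the same result as u_XX_v up to floating-point tolerance. The jvp-based version may be preferable when the input dimension is large, since it never materializes the full matrix.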

Answer selected by gitvicky