-
Hi, I am trying to get the second derivative of the output w.r.t. the input of a neural network built using Flax. The network is structured as follows:
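(The network definition itself did not survive here. As a hypothetical stand-in, consistent with the shapes discussed in the replies below, it could be a small Flax MLP mapping a length-3 input vector to a scalar:)

```python
# Hypothetical stand-in for the original network (not the asker's actual
# code): a small MLP mapping a length-3 input vector to a scalar.
import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.tanh(nn.Dense(64)(x))
        x = nn.Dense(1)(x)
        return x.squeeze()  # scalar output

model = MLP()
params = model.init(jax.random.PRNGKey(0), jnp.ones(3))

def u_function(params, x):
    # Scalar-valued in x; params is an explicit argument, as in the replies.
    return model.apply(params, x)
```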
I can get the first derivative by using vmap over grad:
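(The snippet is missing; presumably something along these lines, reusing the hypothetical `u_function` above:)

```python
from jax import grad, vmap

X = jnp.ones((32, 3))  # hypothetical batch of 32 length-3 inputs
u_X = vmap(grad(u_function, argnums=1), in_axes=(None, 0))(params, X)
print(u_X.shape)
# (32, 3)
```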
However, when I try to do this again to obtain the second derivative:
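(This snippet is also missing; presumably it nests `grad` a second time, e.g.:)

```python
u_XX = vmap(grad(grad(u_function, argnums=1), argnums=1),
            in_axes=(None, 0))(params, X)
```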
I get an error saying that the gradient is only defined for scalar-output functions.
I tried using the hvp definition from the autodiff cookbook, but with params being an input to the function, I wasn't sure how to proceed. Any help on this would be really appreciated.
-
"I tried using the hvp definition from the autodiff cookbook, but with params being an input to the function just wasnt sure how to proceed." def model_with_params(x):
return model.apply(params, x)
# evaluate hvp of model_with_params |
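(For completeness, a minimal sketch of how the cookbook hvp could then be applied to this closure; `x` and `v` are hypothetical inputs:)

```python
import jax
import jax.numpy as jnp

# hvp definition from the JAX autodiff cookbook (forward-over-reverse).
def hvp(f, x, v):
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

x = jnp.ones(3)  # hypothetical evaluation point
v = jnp.ones(3)  # hypothetical tangent vector
print(hvp(model_with_params, x, v))  # Hessian-vector product at x
```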
-
I just answered this question on StackOverflow (https://stackoverflow.com/a/72493336/2937831); copied here for completeness:

The issue is that your `u_function` maps a length-3 vector to a scalar. The first derivative of this is a length-3 vector, but the second derivative is a 3x3 Hessian matrix, which you cannot compute via `jax.grad`, which is only designed for scalar-output functions. Fortunately, JAX provides the `jax.hessian` transform to compute these general second derivatives:

```python
from jax import hessian, vmap

u_XX = vmap(hessian(u_function, argnums=1), in_axes=(None, 0), out_axes=0)(params, X)
print(u_XX.shape)
# (32, 3, 3)
```
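(As a follow-up: if only the pure second derivatives ∂²u/∂xᵢ² are needed, e.g. for a PDE residual, you can take the diagonal of each per-example Hessian:)

```python
# Extract the per-example Hessian diagonals, i.e. d²u/dx_i² for each input.
u_XX_diag = jnp.diagonal(u_XX, axis1=1, axis2=2)
print(u_XX_diag.shape)
# (32, 3)
```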