2nd order derivative is returning a Traced array with DynamicJaxprTrace instead of a Traced array with BatchTrace #17465
-
I am working on a Physics-Informed Neural Network project in JAX, where I want to compute the second-order derivative of the output of a neural network with respect to one of its inputs (x). The code goes as below:

```python
def residual_net(self, params, E, E_physics, x, y):
    ux_x = grad(self.ux_fn, argnums=2)(params, E, x, y)
    ux_xx = grad(grad(self.ux_fn, argnums=2), argnums=2)(params, E, x, y)
    return ux_x, ux_xx
```

Note: residual_net is called via vmap to compute the losses.
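For reference, here is a self-contained sketch of this setup; the ux_fn below (a tiny tanh MLP), its parameter shapes, and the batch size are all assumptions since they are not shown above, and self/E_physics are dropped because the quoted lines don't use them:

```python
import jax
import jax.numpy as jnp
from jax import grad, vmap

# Hypothetical stand-in for the network output ux_fn(params, E, x, y).
def ux_fn(params, E, x, y):
    w1, b1, w2, b2 = params
    h = jnp.tanh(w1 @ jnp.stack([E, x, y]) + b1)
    return (w2 @ h + b2)[0]  # scalar output

def residual_net(params, E, x, y):
    ux_x = grad(ux_fn, argnums=2)(params, E, x, y)                    # du/dx
    ux_xx = grad(grad(ux_fn, argnums=2), argnums=2)(params, E, x, y)  # d2u/dx2
    return ux_x, ux_xx

key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)
params = (0.1 * jax.random.normal(k1, (8, 3)), jnp.zeros(8),
          0.1 * jax.random.normal(k2, (1, 8)), jnp.zeros(1))

# Batched over per-sample scalar inputs, as in the question.
batched = vmap(residual_net, in_axes=(None, 0, 0, 0))
E, x, y = jnp.ones(4), jnp.linspace(0.0, 1.0, 4), jnp.zeros(4)
ux_x, ux_xx = batched(params, E, x, y)
print(type(ux_x), ux_x.shape, ux_xx.shape)  # ordinary jax.Array outputs, shape (4,)
```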
-
The various tracer classes used are internal implementation details, and users shouldn't need to worry about them. Can you add more detail to your question, in particular, how you're encountering the traced values? (a minimal reproducible example would be helpful here). If you're simply executing transformed functions, the inputs and outputs should be normal JAX arrays. Tracers are only used internally.
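For example (a minimal sketch, not your code): a value is only a Tracer while a transformation like jit is tracing the function; what you pass in and what you get back are concrete arrays.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Inside the traced function, x is a Tracer (an internal detail).
    print(type(x))
    return x ** 2

y = jax.jit(f)(jnp.arange(3.0))
print(type(y))  # a concrete jax.Array, not a Tracer
```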
I think the second derivative is expected to be zero in this case, because your network is a sequence of matmuls and relus: matmul is linear (so its second derivative is zero), and relu has a constant gradient everywhere the gradient is defined (so its second derivative is zero wherever it exists).
By the chain rule, composing operations with zero second derivative yields a function whose second derivative is also zero.
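As a quick sanity check (a minimal sketch with made-up weights, not the network from the question):

```python
import jax
import jax.numpy as jnp
from jax import grad

# Made-up weights for a one-hidden-layer relu net in a scalar input x.
w1 = jnp.array([1.0, -2.0, 0.5])
w2 = jnp.array([0.3, 1.0, -0.7])

def net(x):
    # matmul -> relu -> matmul, i.e. piecewise linear in x
    return w2 @ jax.nn.relu(w1 * x)

x = 0.7
print(grad(net)(x))        # nonzero first derivative
print(grad(grad(net))(x))  # 0.0 away from the relu kinks
```

If you need a nonzero second derivative (as PINN residuals usually do), swapping relu for a smooth activation such as tanh avoids this.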