Higher Order and Mixed Derivatives of Multivariate Functions (jax.experimental.jet) #25472
Replies: 3 comments
-
Pretty sure it is explained here, at least a little bit better: #25700. "Each element of the outer list corresponds to the Taylor series of one of the elements of the primal list, and the length of the element dictates the degree of the truncated Taylor polynomial. For example: `primals = (x, y)`, `series = ((dx, ddx), (dy, ddy))`." There is also an example of someone calculating partial derivatives here: #5152. Although, I too don't quite understand how they are doing it. Did you ever figure this out?
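Here is a minimal sketch of how I read that series structure. It assumes the function $\sin^2 x / y$ from the question and an arbitrary evaluation point (0.7, 1.3). Giving one variable a first-order term of 1 and setting every other term to 0 moves along a straight line in that coordinate. The two output terms are then the first and second partial derivatives in that variable. The zero second-order terms are what keep the path straight; that part is my interpretation, not something stated in #25700.

```python
import jax.numpy as jnp
from jax.experimental.jet import jet


def f(x, y):
    # f(x, y) = sin(x)**2 / y, written as a product of sines
    s = jnp.sin(x)
    return s * s / y


x0 = jnp.asarray(0.7, dtype=jnp.float32)  # arbitrary evaluation point
y0 = jnp.asarray(1.3, dtype=jnp.float32)
one = jnp.ones_like(x0)
zero = jnp.zeros_like(x0)

# primals = (x, y), series = ((dx, ddx), (dy, ddy))
# Move along x only: x(t) = x0 + t, y(t) = y0.
f0, (f_x, f_xx) = jet(f, (x0, y0), ((one, zero), (zero, zero)))

# Move along y only: x(t) = x0, y(t) = y0 + t.
_, (f_y, f_yy) = jet(f, (x0, y0), ((zero, zero), (one, zero)))

print(f0, f_x, f_xx, f_y, f_yy)
```

This only gives the "pure" partials. Mixed ones like $\partial^2 f / \partial x \partial y$ need a direction that moves both variables at once.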
-
I had another look at this, and with some Copilot help I was able to get an example that works, I think.
The tricky part (if my understanding of jet is correct) is that jet returns the coefficients of an unscaled Taylor expansion instead of handing back a gradient the way jax.grad or jax.jacobian do, so you need to extract the partial derivatives from those coefficients yourself. There may be a cleaner way to do the extraction, but for now this is all I can come up with. For this example I have just found
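A quick 1D check of what "unscaled" means here, as a sketch (the evaluation point 0.7 is an arbitrary choice). With the input path x(t) = x0 + t, the k-th output term is the k-th derivative of sin at x0 itself. No 1/k! factor is applied, so dividing by k! is what turns the terms into ordinary Taylor coefficients.

```python
import math

import jax.numpy as jnp
from jax.experimental.jet import jet

x0 = jnp.asarray(0.7, dtype=jnp.float32)
one = jnp.ones_like(x0)
zero = jnp.zeros_like(x0)

# Input path x(t) = x0 + t: first derivative 1, all higher derivatives 0.
primal_out, terms = jet(jnp.sin, (x0,), ((one, zero, zero),))

# terms[k - 1] is d^k/dt^k sin(x(t)) at t = 0, which here is just the k-th
# derivative of sin at x0: cos(x0), -sin(x0), -cos(x0).
expected = (jnp.cos(x0), -jnp.sin(x0), -jnp.cos(x0))

# Dividing by k! gives the usual Taylor coefficients of sin around x0.
taylor_coeffs = [t / math.factorial(k) for k, t in enumerate(terms, start=1)]
print(terms)
print(taylor_coeffs)
```

The same convention carries over to several variables. Along a straight line in direction v, the second output term is $v^\top H v$, not $\frac{1}{2} v^\top H v$.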
-
Never mind my implementation; a much better answer is given here:
-
Hello,
I'm Josiel de Souza, a postdoc in physics working on gravitational-wave data analysis.
I'm interested in computing high-order derivatives of functions.
I heard that jax.experimental.jet is more efficient than applying jax.grad recursively (the latter is very slow), and indeed I saw that it works very well for one-dimensional functions $f(x)$. However, I don't know how to use jax.experimental.jet to compute mixed derivatives such as $\frac{\partial^2 f}{\partial x \partial y}$, or the individual derivatives. That is, I'd like to compute $\frac{\partial f}{\partial x}$, $\frac{\partial f}{\partial y}$, $\frac{\partial^2 f}{\partial x \partial y}$, $\frac{\partial^2 f}{\partial x^2}$, $\frac{\partial^2 f}{\partial y^2}$, and so on using jax.experimental.jet. I don't understand how the input series work in this case.
Could anyone help me with a practical example, for instance for the function $\sin^2 x / y$?
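Here is a sketch of one way to get mixed partials out of jet. It is not necessarily what the linked answers do. The idea is to run jet along a few straight-line directions and combine the results by polarization. Along v = (a, b), the k-th output term is the k-th directional derivative, e.g. $D_v^2 f = a^2 f_{xx} + 2ab\, f_{xy} + b^2 f_{yy}$. The four directions (1, 0), (0, 1), (1, 1) and (1, -1) are therefore enough to recover every partial up to third order. The helper `directional_derivs` and the point (0.7, 1.3) are hypothetical choices for illustration. `jax.hessian` and `jax.jacfwd` are only there as a cross-check.

```python
import jax
import jax.numpy as jnp
from jax.experimental.jet import jet


def f(x, y):
    # f(x, y) = sin(x)**2 / y, written as a product of sines
    s = jnp.sin(x)
    return s * s / y


def directional_derivs(fun, x0, y0, vx, vy, order=3):
    """Hypothetical helper: [D_v f, ..., D_v^order f] at (x0, y0), v = (vx, vy)."""
    def line(p, slope):
        # Straight line p(t) = p0 + slope * t: only the first term is non-zero.
        return (jnp.full_like(p, slope),) + tuple(
            jnp.zeros_like(p) for _ in range(order - 1))

    _, terms = jet(fun, (x0, y0), (line(x0, vx), line(y0, vy)))
    return terms


x0 = jnp.asarray(0.7, dtype=jnp.float32)
y0 = jnp.asarray(1.3, dtype=jnp.float32)

f_x, f_xx, f_xxx = directional_derivs(f, x0, y0, 1.0, 0.0)
f_y, f_yy, f_yyy = directional_derivs(f, x0, y0, 0.0, 1.0)
d_plus = directional_derivs(f, x0, y0, 1.0, 1.0)    # along (1, 1)
d_minus = directional_derivs(f, x0, y0, 1.0, -1.0)  # along (1, -1)

# Second order: D_(1,1)^2 f = f_xx + 2 f_xy + f_yy
f_xy = (d_plus[1] - f_xx - f_yy) / 2

# Third order:
#   D_(1,1)^3 f  = f_xxx + 3 f_xxy + 3 f_xyy + f_yyy
#   D_(1,-1)^3 f = f_xxx - 3 f_xxy + 3 f_xyy - f_yyy
f_xxy = (d_plus[2] - d_minus[2] - 2 * f_yyy) / 6
f_xyy = (d_plus[2] + d_minus[2] - 2 * f_xxx) / 6


# Cross-check against nested autodiff on the stacked input v = (x, y).
def g(v):
    return f(v[0], v[1])


v0 = jnp.stack([x0, y0])
H = jax.hessian(g)(v0)              # H[i, j]    = d^2 f / dv_i dv_j
T = jax.jacfwd(jax.hessian(g))(v0)  # T[i, j, k] = d^3 f / dv_i dv_j dv_k
print(f_xy, H[0, 1])
print(f_xxy, T[0, 0, 1])
print(f_xyy, T[0, 1, 1])
```

Going one order higher only adds one more term to each jet call. In two variables, k + 1 pairwise non-parallel directions are enough to pin down all partials of order k.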