Root finding options in JAX with adaptive bounds where jax.grad and jax.vmap work? #16675
I'm now working on a solution that uses jax.scipy.optimize.minimize; however, the documentation states: 'minimize supports jit() compilation. It does not yet support differentiation or arguments in the form of multi-dimensional arrays, but support for both is planned.'
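One way around the missing differentiation support (a swapped-in technique, not something proposed in this thread) is to skip minimize, find the root with a JAX-traceable loop, and attach the derivative with jax.custom_jvp using the implicit function theorem: if f(x*(theta), theta) = 0, then dx*/dtheta = -(df/dtheta) / (df/dx) evaluated at the root, so nothing is differentiated through the loop itself. This is a minimal sketch assuming a scalar parameter theta > e, the hypothetical example function f(x, theta) = exp(exp(x)) - theta (a single positive root that overflows to inf shortly after it), and a hypothetical fixed bracket [0, 8]; f, _bisect, and root are illustrative names only.

```python
import jax
import jax.numpy as jnp


def f(x, theta):
    # Hypothetical example: for theta > e there is one positive root at
    # log(log(theta)), and exp(exp(x)) overflows to inf in float32 soon after.
    return jnp.exp(jnp.exp(x)) - theta


def _bisect(theta, lo=0.0, hi=8.0, n_iter=60):
    # Fixed-iteration bisection; the bracket [0, 8] is an assumption for this f.
    def body(_, bounds):
        a, b = bounds
        m = 0.5 * (a + b)
        fm = f(m, theta)
        # Treat overflow (inf/nan) as "past the root", same as a positive value.
        past = (~jnp.isfinite(fm)) | (fm > 0)
        # jnp.where selects the half-interval without a Python bool on a tracer.
        return jnp.where(past, a, m), jnp.where(past, m, b)

    a, b = jax.lax.fori_loop(0, n_iter, body, (jnp.float32(lo), jnp.float32(hi)))
    return 0.5 * (a + b)


@jax.custom_jvp
def root(theta):
    return _bisect(theta)


@root.defjvp
def root_jvp(primals, tangents):
    (theta,), (theta_dot,) = primals, tangents
    x = _bisect(theta)
    # Implicit function theorem: f(x*(theta), theta) = 0
    # => dx*/dtheta = -(df/dtheta) / (df/dx), evaluated at the root.
    df_dx = jax.grad(f, argnums=0)(x, theta)
    df_dtheta = jax.grad(f, argnums=1)(x, theta)
    # The tangent is linear in theta_dot, so reverse mode (jax.grad) can transpose it.
    return x, -(df_dtheta / df_dx) * theta_dot


theta = jnp.float32(10.0)
x_star = root(theta)
dx_dtheta = jax.grad(root)(theta)
# Analytic values for comparison: log(log(theta)) and 1 / (theta * log(theta)).
```

Because the custom rule only uses the converged root, the derivative does not depend on how many bisection steps were taken, and jax.vmap(jax.grad(root)) should compose as usual.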
Looks like there was a misunderstanding here. Jumping to the conclusion:
A bit late, but you might find Optimistix interesting. It's our latest scientific computing library in JAX, and it includes root finding.
Hi,
I've written several versions of a code block that all find roots correctly, but I keep running into trouble when I apply vmap or grad to them. I see that jaxopt has a bisection method, but it is limited. In this case the function is reasonably well behaved: of all its roots, exactly one is positive and real. However, shortly after that root the function rises quickly to numerical infinity, so I can set the lower bound at 0 but the upper bound has to be adaptive. I'll share one of the many versions; it works in Python on a single data point, but under vmap it fails with an abstract tracer value error because of the boolean check:
Any suggestions for a strategy that's better than writing a solver from the ground up would be appreciated.
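The abstract tracer error usually comes from a Python if or while on a value that depends on a traced input; under vmap such conditions have to be expressed with jax.lax.while_loop, jax.lax.fori_loop, and jnp.where instead. A minimal sketch of an adaptive-bracket bisection written that way, assuming that the sign of f at lo differs from its sign just past the root, that inf or nan only appears beyond the root, and that doubling the upper bound is an acceptable expansion schedule. example_f, bracket_and_bisect, and its parameters are hypothetical.

```python
import jax
import jax.numpy as jnp


def example_f(x, theta):
    # Hypothetical stand-in: single positive root at log(log(theta)) for theta > e;
    # exp(exp(x)) overflows to inf in float32 once exp(x) exceeds about 88.7.
    return jnp.exp(jnp.exp(x)) - theta


def _past_root(f, x, theta, sign_lo):
    # Assumption: a non-finite value means we already stepped past the root,
    # because the function blows up shortly after it.
    fx = f(x, theta)
    return (~jnp.isfinite(fx)) | (jnp.sign(fx) != sign_lo)


def bracket_and_bisect(f, theta, lo=0.0, hi0=1.0, max_expand=60, n_bisect=80):
    lo = jnp.asarray(lo, dtype=jnp.float32)
    sign_lo = jnp.sign(f(lo, theta))

    # Phase 1: double the upper bound until it lies past the root.
    # lax.while_loop replaces a Python `while` on a traced boolean.
    def expand_cond(state):
        hi, i = state
        return (~_past_root(f, hi, theta, sign_lo)) & (i < max_expand)

    def expand_body(state):
        hi, i = state
        return 2.0 * hi, i + 1

    init = (jnp.asarray(hi0, dtype=jnp.float32), jnp.asarray(0))
    hi, _ = jax.lax.while_loop(expand_cond, expand_body, init)

    # Phase 2: fixed number of bisection steps; jnp.where picks the half-interval
    # without converting a tracer to a Python bool.
    def bisect_body(_, bounds):
        a, b = bounds
        m = 0.5 * (a + b)
        past = _past_root(f, m, theta, sign_lo)
        return jnp.where(past, a, m), jnp.where(past, m, b)

    a, b = jax.lax.fori_loop(0, n_bisect, bisect_body, (lo, hi))
    return 0.5 * (a + b)


thetas = jnp.array([10.0, 1e10, 1e30], dtype=jnp.float32)
roots = jax.vmap(lambda t: bracket_and_bisect(example_f, t))(thetas)
# Analytic roots for comparison: jnp.log(jnp.log(thetas))
```

Note that reverse-mode differentiation is not supported through lax.while_loop, so for jax.grad this solver would be combined with an implicit-differentiation rule (for example via jax.custom_jvp) rather than differentiated through directly.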