Newton's method using JVP and an iterative solver #17975
Replies: 1 comment
---
That you're seeing multiple calls to `compile` suggests that something is being re-compiled on every iteration. Almost always the best way to write JAX code is to have `jax.numpy` operations express all of your numerical work, and then wrap the whole thing in a single `jax.jit`. By the way, if you'd like a ready-made implementation of Newton's method (supporting all the varieties of solver you're discussing above) then you might like Optimistix. (Newly released!)
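A minimal sketch of that pattern (illustrative only; the helper name and the use of `lax.fori_loop` are assumptions, not code from this thread):

```python
import jax
import jax.numpy as jnp

def make_newton(f, num_iters=10):
    # Build the solver once, then jit it as a single unit, so XLA compiles
    # one program for the whole loop instead of re-compiling every step.
    @jax.jit
    def newton(x0):
        def body(_, x):
            J = jax.jacobian(f)(x)
            return x + jnp.linalg.solve(J, -f(x))
        return jax.lax.fori_loop(0, num_iters, body, x0)
    return newton
```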
---
Hello everyone,
I'm working on an implementation of Newton's method using an iterative solver. As a recent JAX newbie I have some doubts (explained below) about whether my code follows "the JAX way of thinking", and I would appreciate some insight from the experts.
Side note: I'm using JAX version 0.4.16 and running my code CPU-only on a MacBook Pro with an M2 chip. Any possible GPU concerns remain a problem for future me.
Background
To give some background, I'm writing a solver for the Dyson equation, a non-linear Volterra integral equation. In the end, this boils down to finding the roots of a function whose most computationally expensive step is a series of multiplications of complex-valued matrices.
Simple example
For ease of discussion let's consider a simpler example first. Suppose I want to find the solutions to $f(x)=0$ for the following $f(x)$, with $x$ being a vector of length 3:
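The exact form of $f$ isn't important for the discussion; as a hypothetical stand-in for the code sketches below, assume something like:

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical toy system with a length-3 input; any smooth
    # nonlinear map from R^3 to R^3 would do here.
    return jnp.array([
        x[0] ** 2 + x[1] - 1.0,
        x[0] - x[2] ** 2,
        jnp.sin(x[1]) + x[2],
    ])

x0 = jnp.ones(3)
```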
Full Jacobian + inverse
The most naive approach would be to compute the $(n+1)$-st iterate $x^{(n+1)}$ by applying the inverse of the Jacobian $J_f$ of $f$, evaluated at $x^{(n)}$:

$$x^{(n+1)} = x^{(n)} - J_f^{-1}\big(x^{(n)}\big)\, f\big(x^{(n)}\big),$$

corresponding to the code sketched below, where the number of iterations is fixed for the sake of clarity.
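A minimal sketch along those lines (using the toy `f` above):

```python
def newton_inv(f, x0, num_iters=10):
    # Naive Newton: form the full Jacobian and invert it explicitly.
    x = x0
    for _ in range(num_iters):
        J = jax.jacobian(f)(x)
        x = x - jnp.linalg.inv(J) @ f(x)
    return x

x_root = newton_inv(f, x0)
```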
Full Jacobian + linear solve
Alternatively, we can rewrite the above expression as a linear system

$$J_f\big(x^{(n)}\big)\, \Delta x^{(n)} = -f\big(x^{(n)}\big),$$

where $\Delta x^{(n)} = x^{(n+1)} - x^{(n)}$.
Thus, we obtain corrections to our guess by solving this linear system for $\Delta x^{(n)}$:
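For instance (again a sketch, with a fixed iteration count):

```python
def newton_linsolve(f, x0, num_iters=10):
    # Same iteration, but obtain the correction from a direct linear
    # solve, J @ dx = -f(x), instead of an explicit matrix inverse.
    x = x0
    for _ in range(num_iters):
        J = jax.jacobian(f)(x)
        dx = jnp.linalg.solve(J, -f(x))
        x = x + dx
    return x
```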
JVP + iterative linear solve
Both methods above involve computing the full Jacobian, which gets expensive quickly as the dimension of $x$ increases, as it does in my use case. The way to go then would be to solve the linear system iteratively instead (e.g., using GMRES or BiCGSTAB). In this case, the entire LHS of the linear system is provided matrix-free via `jax.jvp`.
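A sketch of what that might look like with `jax.scipy.sparse.linalg.gmres` (the helper names are assumptions, not the exact code being profiled):

```python
from jax.scipy.sparse.linalg import gmres

def newton_gmres(f, x0, num_iters=10):
    # Matrix-free Newton: the linear system J @ dx = -f(x) is solved
    # iteratively, touching J only through Jacobian-vector products.
    x = x0
    for _ in range(num_iters):
        def matvec(v):
            return jax.jvp(f, (x,), (v,))[1]  # J_f(x) @ v, without forming J
        dx, _ = gmres(matvec, -f(x))
        x = x + dx
    return x
```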
The issue(s)
The JVP + iterative solve code is a couple of orders of magnitude slower than the full-Jacobian versions on the simple example. By timing various parts of the code I noticed that only about 20% of the time is spent computing $f(x)$ and evaluating the Jacobian-vector product (as expected). Running a profiler, I see that the 10 calls to `built-in method jaxlib.xla_extension.compile` take up the vast majority of the time for the simple example, and also a significant portion of the time for my more involved use case with larger, complex-valued matrices. In the latter case, most of the time is actually spent on many calls to `built-in method jaxlib.xla_extension.execute_sharded`.
All this makes me think that I'm doing things in a very un-JAX way, as something gets re-compiled in each iteration, taking up a lot of the time. Whether this is related to the 'sharding' is beyond my knowledge. Ultimately, I'm not even sure whether this can be avoided at all.
EDIT: Upon further inspection, sharding seems to be the main culprit, as the time it takes scales with the dimension of $x$. Any ideas on what to do about that?
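One way to check whether recompilation (rather than sharding itself) dominates, assuming nothing else about the setup: turn on JAX's compile logging and see how often it fires:

```python
import jax

# Log every XLA compilation; one "Compiling ..." message per Newton step
# would confirm that the loop is being re-traced and re-compiled each iteration.
jax.config.update("jax_log_compiles", True)
```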
Thanks in advance for any pointers, suggestions, etc.
Also, on a slightly unrelated note: no matter what the problem parameters are, I always seem to get 5 calls to the `jvp` function per `gmres` call, which is a bit curious.