Using iterative solver inside custom_jvp rule causes error when jax.grad/jax.jacrev is called #30249
---
Summary: details of a toy implementation. The residual function and initial guess:

```python
import jax
import jax.numpy as jnp

def f(x, params):
    x, y, z = x
    a, b, c = params
    f1 = a*x**2 + b*y**2 + c*z**2 - 3
    f2 = a*x**2 + b*y**2 - c*z - 1
    f3 = a*x + b*y + c*z - 3
    return jnp.array([f1, f2, f3])

x = jnp.array([0.2, 0.3, 0.5])  # initial guess
```
```python
@jax.custom_jvp
def nsolve(xi, params):
    err = 0.1
    tol = 1e-6
    iter = 0

    def Ffunc(x):
        return f(x, params)

    def dFfunc(x, dx):
        r, dr = jax.jvp(Ffunc, (x,), (dx,))
        return dr

    while err > tol:
        rhs = f(xi, params)
        # Jmat = jax.jacfwd(f, argnums=0)(xi, params)
        update = jax.scipy.sparse.linalg.gmres(lambda x: dFfunc(xi, x), -rhs)[0]
        xi = xi + update
        err = jnp.linalg.norm(update)
        iter = iter + 1
        jax.debug.print("err={}", err)
    return xi
```

I define a custom_jvp that uses the implicit function theorem to solve a linear system for the tangent of the converged solution x:

```python
@nsolve.defjvp
def nsolve_jvp(primals, tangents):
    xi, params = primals
    dxi, dparams = tangents
    xf = nsolve(xi, params)

    def Ffunc(x):
        return f(x, params)

    Jmat = jax.jacfwd(Ffunc)(xf)

    def dFfunc(x, dx):
        # r, dr = jax.jvp(Ffunc, (x,), (dx,))
        Jmat = jax.jacfwd(Ffunc)(x)
        dr = Jmat @ dx
        return dr

    def Floc(p):
        return f(xf, p)

    res0, jvp_res0 = jax.jvp(Floc, (params,), (dparams,))
    # tout = jax.scipy.sparse.linalg.gmres(lambda x: dFfunc(xf, x), -jvp_res0)[0]  # error
    tout = jax.scipy.sparse.linalg.gmres(Jmat, -jvp_res0)[0]  # error
    # tout = jnp.linalg.solve(Jmat, -jvp_res0)  # works
    primal_out = xf
    tangent_out = tout
    return (primal_out, tangent_out)
```

Calling `jax.jacfwd(nsolve, argnums=1)(x, p)` works, but `jax.grad`/`jax.jacrev` fails with `Traceback (most recent call last): …` (traceback truncated).
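For reference, the linear system this JVP rule solves comes from differentiating the converged residual, i.e. the standard implicit function theorem (notation mine, not from the thread):

```latex
F(x^{*}(\theta), \theta) = 0
\;\Longrightarrow\;
\frac{\partial F}{\partial x}\,\mathrm{d}x^{*} + \frac{\partial F}{\partial \theta}\,\mathrm{d}\theta = 0
\;\Longrightarrow\;
\mathrm{d}x^{*} = -\Big(\frac{\partial F}{\partial x}\Big)^{-1}\frac{\partial F}{\partial \theta}\,\mathrm{d}\theta .
```

Here ∂F/∂x is `Jmat` and (∂F/∂θ) dθ is `jvp_res0`, which is why the tangent is the solution of `Jmat @ tout = -jvp_res0`.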
---
You're encountering this error because you're using `jax.custom_jvp`, which only defines a forward-mode rule. When you call `jax.grad` or `jax.jacrev`, JAX needs a reverse-mode (VJP) rule. Replace your `custom_jvp` with a `custom_vjp`:

```python
@jax.custom_vjp
def nsolve(xi, params):
    # Your existing solver code
    err = 0.1
    tol = 1e-6
    iter = 0

    def Ffunc(x):
        return f(x, params)

    def dFfunc(x, dx):
        r, dr = jax.jvp(Ffunc, (x,), (dx,))
        return dr

    while err > tol:
        rhs = f(xi, params)
        update = jax.scipy.sparse.linalg.gmres(lambda x: dFfunc(xi, x), -rhs)[0]
        xi = xi + update
        err = jnp.linalg.norm(update)
        iter = iter + 1
    return xi

def nsolve_fwd(xi, params):
    xf = nsolve(xi, params)
    return xf, (xf, params)

def nsolve_bwd(res, g):
    xf, params = res

    def Ffunc(x):
        return f(x, params)

    Jmat = jax.jacfwd(Ffunc)(xf)
    # Solve J^T @ lambda = g for lambda
    lambda_val = jnp.linalg.solve(Jmat.T, g)

    # Compute VJP w.r.t. params
    def Floc(p):
        return f(xf, p)

    _, vjp_fun = jax.vjp(Floc, params)
    g_params = vjp_fun(-lambda_val)[0]
    # VJP w.r.t. the initial guess is zero for a converged solution
    # (note: zeros_like(xf), since xi is not saved in the residuals)
    g_xi = jnp.zeros_like(xf)
    return (g_xi, g_params)

nsolve.defvjp(nsolve_fwd, nsolve_bwd)
```

For the implicit function theorem, if you have F(x*(p), p) = 0 at the converged solution, the adjoint satisfies Jᵀλ = g, and the parameter cotangent is the VJP of F through its second argument. If you want to support both forward and reverse mode efficiently, you could define both a JVP and a VJP rule. The reason your original code worked with `jax.jacfwd` is that forward mode only needs the JVP rule you already defined.
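To illustrate how this `custom_vjp` pattern is used, here is a self-contained sketch on a simpler root-find (my own toy problem `F(x, p) = x**3 - p`, not the thread's `f`), where the exact answer `x* = p**(1/3)` lets the reverse-mode gradient be checked by hand:

```python
import jax
import jax.numpy as jnp

# Toy residual of my own: elementwise root of x**3 - p, so x* = p**(1/3).
def F(x, p):
    return x**3 - p

@jax.custom_vjp
def newton_solve(x0, p):
    x = x0
    for _ in range(50):                  # fixed iteration count keeps it traceable
        x = x - F(x, p) / (3.0 * x**2)   # Newton step; dF/dx = 3 x**2
    return x

def newton_fwd(x0, p):
    xf = newton_solve(x0, p)
    return xf, (xf, p)

def newton_bwd(res, g):
    xf, p = res
    lam = g / (3.0 * xf**2)              # solves J^T lam = g (J is diagonal here)
    _, vjp_p = jax.vjp(lambda q: F(xf, q), p)
    g_p = vjp_p(-lam)[0]                 # cotangent w.r.t. p via the IFT
    g_x0 = jnp.zeros_like(xf)            # converged root is independent of the guess
    return (g_x0, g_p)

newton_solve.defvjp(newton_fwd, newton_bwd)

p = jnp.array([8.0])
x0 = jnp.array([1.5])
xf = newton_solve(x0, p)                                # root: 2.0
dp = jax.grad(lambda q: newton_solve(x0, q).sum())(p)   # d(p**(1/3))/dp at 8 = 1/12
```

With `p = 8` the root is 2 and the analytic derivative is 1/12, so `jax.grad` through `newton_bwd` can be checked directly.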
---
Thanks for getting back to me. The JAX documentation says that custom_jvp works for both forward and reverse mode: "Even though we defined only a JVP rule and no VJP rule, we can use both forward- and reverse-mode differentiation on f. JAX will automatically transpose the linear computation on tangent values from our custom JVP rule, computing the VJP as efficiently as if we had written the rule by hand." https://docs.jax.dev/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html Also, in my example, reverse mode fails only if I compute the tangent vector with an iterative solver; it works if I use a direct matrix inversion.
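That automatic-transposition behaviour can be checked in isolation. A minimal sketch of my own (the names `A` and `solve_fn` are illustrative, not from the thread): a `custom_jvp` whose tangent rule uses a direct solve, differentiated in reverse mode.

```python
import jax
import jax.numpy as jnp

# Illustrative fixed matrix (my example, not from the thread).
A = jnp.array([[3.0, 1.0],
               [1.0, 2.0]])

@jax.custom_jvp
def solve_fn(b):
    return jnp.linalg.solve(A, b)

@solve_fn.defjvp
def solve_fn_jvp(primals, tangents):
    (b,), (db,) = primals, tangents
    # The tangent rule uses a *direct* solve. JAX can transpose this linear
    # map automatically, so reverse mode works with only a JVP defined.
    return jnp.linalg.solve(A, b), jnp.linalg.solve(A, db)

b = jnp.array([1.0, 1.0])
J = jax.jacrev(solve_fn)(b)   # succeeds; equals inv(A) since solve_fn is linear in b
```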
---
Ah, I think you're right about that, because I see now that the issue you're encountering is a known bug in JAX where `custom_jvp` with `gmres` fails in reverse-mode differentiation. This is documented in issue #5309, which shows the exact same error: `TypeError: Value UndefinedPrimal(ShapedArray(float32[3])) with type <class 'jax._src.interpreters.ad.UndefinedPrimal'> is not a valid JAX type`. `custom_jvp` should support both forward and reverse mode through automatic transposition, but it looks like there's a specific bug with iterative solvers like `gmres`. As noted in the issue: "both forward and reverse-mode work if we replace GMRES with something like np.linalg.solve". So I think the key differe…
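One possible workaround worth mentioning (my suggestion, not from the thread): `jax.lax.custom_linear_solve` exists precisely to make linear solves built from iterative methods transposable, by letting you supply a solver for the transposed system. A minimal sketch with an assumed symmetric matrix `A`, so the same GMRES-based solver can serve as `transpose_solve`:

```python
import jax
import jax.numpy as jnp

# Symmetric test matrix (my own example); symmetry lets one solver do both jobs.
A = jnp.array([[3.0, 1.0],
               [1.0, 2.0]])

def matvec(x):
    return A @ x

def solve(mv, b):
    # Any iterative solver can go here; GMRES is what the thread uses.
    return jax.scipy.sparse.linalg.gmres(mv, b)[0]

def linsolve(b):
    # custom_linear_solve registers the solve as a linear operation with a
    # known transpose, so reverse mode uses transpose_solve instead of
    # failing inside gmres.
    return jax.lax.custom_linear_solve(matvec, b, solve, transpose_solve=solve)

b = jnp.array([1.0, 1.0])
x = linsolve(b)               # solves A x = b iteratively
J = jax.jacrev(linsolve)(b)   # reverse mode works; J is inv(A) since linsolve is linear
```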