Using iterative solver inside custom_jvp rule causes error when jax.grad/jax.jacrev is called #30249

Answered by guy-singer
ubaidali asked this question in Q&A

Ah, I think you're right about that. I see now that the issue you're encountering is a known bug in JAX: custom_jvp combined with gmres fails under reverse-mode differentiation. It is documented in issue #5309, which shows the exact same error: TypeError: Value UndefinedPrimal(ShapedArray(float32[3])) with type <class 'jax._src.interpreters.ad.UndefinedPrimal'> is not a valid JAX type.

custom_jvp should support both forward and reverse mode, since JAX obtains the reverse-mode rule by automatically transposing the JVP, but it looks like there's a specific bug with iterative solvers like gmres. As noted in the issue: "both forward and reverse-mode work if we replace GMRES with something like np.linalg.solve".
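For illustration only, here is one possible way to avoid relying on transposition of a custom_jvp rule: define the reverse-mode rule explicitly with jax.custom_vjp and perform the adjoint solve with gmres in the backward pass. This is a minimal sketch, not necessarily the fix discussed in this thread or in issue #5309. The names solve and _gmres_solve, the dense matrix A, and the small test system are all hypothetical and chosen just for the example.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import gmres


def _gmres_solve(A, b):
    # Hypothetical helper: solve A x = b with GMRES, discarding the info value.
    x, _ = gmres(lambda v: A @ v, b)
    return x


@jax.custom_vjp
def solve(A, b):
    return _gmres_solve(A, b)


def solve_fwd(A, b):
    x = _gmres_solve(A, b)
    # Save A and the solution x for the backward pass.
    return x, (A, x)


def solve_bwd(residuals, x_bar):
    A, x = residuals
    # Adjoint system: A^T lam = x_bar, also solved iteratively.
    lam = _gmres_solve(A.T, x_bar)
    # For x = A^{-1} b: cotangent of A is -lam x^T, cotangent of b is lam.
    return -jnp.outer(lam, x), lam


solve.defvjp(solve_fwd, solve_bwd)

# Small, well-conditioned example system (assumed for illustration).
A = jnp.array([[4.0, 1.0, 0.0],
               [1.0, 3.0, 1.0],
               [0.0, 1.0, 2.0]])
b = jnp.array([1.0, 2.0, 3.0])


def loss(A, b):
    return jnp.sum(solve(A, b) ** 2)


# Reverse mode goes through the explicit VJP rule defined above.
g_A, g_b = jax.grad(loss, argnums=(0, 1))(A, b)
```

The trade-off is that a function wrapped in jax.custom_vjp no longer supports forward mode (jax.jvp, jax.jacfwd), so this only fits cases where jax.grad or jax.jacrev is all that is needed.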

So I think the key differe…

Answer selected by ubaidali