
How to preserve sparsity when using jacfwd with custom JVP? #30174

Answered by guy-singer
itk22 asked this question in Q&A

Unfortunately, JAX currently has no built-in way to make jacfwd preserve BCOO sparsity. The limitation follows from how jacfwd works: it applies forward-mode automatic differentiation by vmapping the JVP over the standard basis vectors (the columns of an identity matrix), and stacking those results produces a dense Jacobian. BCOO sparse arrays are designed to be compatible with JAX transforms, including jax.grad(), but that does not change what jacfwd returns: even when a custom JVP returns sparse arrays, vmapping over all basis vectors densifies the result. Sparse autodiff is fundamentally incompati…
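To make the mechanism concrete, here is a minimal sketch of a jacfwd-style computation written by hand with jax.jvp and jax.vmap. The function `f` and the helper `jacfwd_by_hand` are hypothetical names chosen for illustration; this is not JAX's actual implementation, only the same idea in a few lines. The test function is elementwise, so its Jacobian is diagonal, yet the result is still a dense array.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Elementwise function, so the true Jacobian is diagonal.
    return jnp.sin(x) * x


def jacfwd_by_hand(fun, x):
    # Hypothetical helper: push every standard basis vector through the JVP.
    basis = jnp.eye(x.size, dtype=x.dtype)

    def push(v):
        # jax.jvp returns (primal_out, tangent_out); the tangent is J @ v.
        return jax.jvp(fun, (x,), (v,))[1]

    # Row i of the vmapped result is J @ e_i, i.e. column i of J,
    # so transpose to recover the usual (outputs, inputs) layout.
    return jax.vmap(push)(basis).T


x = jnp.arange(1.0, 5.0)
J_manual = jacfwd_by_hand(f, x)
J_builtin = jax.jacfwd(f)(x)

print(J_manual.shape)                    # full (n, n) array
print(jnp.allclose(J_manual, J_builtin))
```

Both results are ordinary dense arrays of shape (n, n): the off-diagonal zeros are stored explicitly, because stacking one tangent output per basis vector leaves nowhere for a sparse structure to survive.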

Answer selected by itk22