How to preserve sparsity when using jacfwd with custom JVP? #30174
I'm working on finite element analysis where the Jacobian matrices are naturally sparse (e.g., only ~1% nonzeros). Using dense Jacobians causes OOM errors for relatively small problems, so sparse matrices are essential for feasibility. When I use sparse matrices in custom JVP rules, as below, `jacfwd` still returns a dense Jacobian:

```python
import jax
import jax.numpy as jnp
import jax.experimental.sparse as jsp
from jax import custom_jvp

# sparse array with 199 nonzero elements
A_sparse = jsp.BCOO.fromdense(jnp.diag(jnp.ones(100)) + jnp.diag(jnp.ones(99), 1))

@custom_jvp
def f(x):
    return A_sparse @ x

@f.defjvp
def f_jvp(primals, tangents):
    x, = primals
    x_dot, = tangents
    return f(x), A_sparse @ x_dot

jac = jax.jacfwd(f)(jnp.ones(100))  # dense array with 10000 elements (50x overhead)
```

Is there a way to make autodiff preserve the BCOO format? I want to be able to use …
Unfortunately, there's currently no built-in way to make `jacfwd` preserve BCOO sparsity in JAX. Here are some options you can use:

Solution 1: Use sparsejac

The most mature solution is the `sparsejac` library, which computes a sparse Jacobian for a known sparsity pattern:

```python
import jax
import jax.numpy as jnp
import jax.experimental.sparse as jsp
import sparsejac

# Define your sparsity pattern (same as your matrix structure)
A_sparse = jsp.BCOO.fromdense(jnp.diag(jnp.ones(100)) + jnp.diag(jnp.ones(99), 1))

@jax.custom_jvp
def f(x):
    return A_sparse @ x

@f.defjvp
def f_jvp(primals, tangents):
    x, = primals
    x_dot, = tangents
    return f(x), A_sparse @ x_dot

# Create a sparse Jacobian function with the known sparsity pattern
@jax.jit
def sparse_jacfwd_fn(x):
    with jax.ensure_compile_time_eval():
        sparsity = jsp.BCOO.fromdense(jnp.diag(jnp.ones(100)) + jnp.diag(jnp.ones(99), 1))
        jacfwd_fn = sparsejac.jacfwd(f, sparsity=sparsity)
    return jacfwd_fn(x)

# This returns a BCOO sparse matrix instead of a dense array
sparse_jac = sparse_jacfwd_fn(jnp.ones(100))
```

Performance benchmarks show sparsejac can provide meaningful speedups for sparse problems, with execution times often 40-50% faster than dense implementations.

Solution 2: Manual Sparse Jacobian Implementation

You can also build the BCOO Jacobian manually from your known sparsity pattern, running JVPs only for columns that contain at least one nonzero. Be aware that in a square FEM Jacobian nearly every column has a nonzero, so this still runs one JVP per column and holds those columns densely in memory; what it buys you is a sparse result:

```python
import jax
import jax.numpy as jnp
import jax.experimental.sparse as jsp

def sparse_jacfwd_manual(f, x, sparsity_pattern):
    """
    Manually compute a BCOO Jacobian of f at x for a known sparsity pattern.
    sparsity_pattern: BCOO matrix whose stored indices mark the nonzero entries.
    Not jit-compatible as written: jnp.unique needs a static `size` under jit.
    """
    # Indices of columns that contain at least one nonzero (sorted, unique)
    unique_cols = jnp.unique(sparsity_pattern.indices[:, 1])

    def single_column_jvp(col_idx):
        # Basis vector e_col: the JVP J @ e_col is column `col_idx` of J
        v = jnp.zeros_like(x).at[col_idx].set(1.0)
        _, jvp_result = jax.jvp(f, (x,), (v,))
        return jvp_result

    # Shape (num_cols, m): row k holds Jacobian column unique_cols[k] (still dense)
    relevant_columns = jax.vmap(single_column_jvp)(unique_cols)
    return construct_sparse_jacobian(relevant_columns, unique_cols, sparsity_pattern)

def construct_sparse_jacobian(columns, unique_cols, sparsity_pattern):
    """Gather the entries listed in sparsity_pattern.indices into a BCOO matrix."""
    rows = sparsity_pattern.indices[:, 0]
    cols = sparsity_pattern.indices[:, 1]
    # Position of each column index within the sorted unique_cols array
    col_pos = jnp.searchsorted(unique_cols, cols)
    data = columns[col_pos, rows]
    return jsp.BCOO((data, sparsity_pattern.indices), shape=sparsity_pattern.shape)
```

Solution 3: Use JAX's Experimental Sparse Functions

JAX provides experimental support for … Note that this route differentiates with respect to the stored values (`data`) of a sparse *input*, which is a different problem from the Jacobian of `f` with respect to a dense `x`. You'll need to handle the BCOO decomposition manually:

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse

def f_raw(data, indices, shape=(100, 100)):
    # Rebuild the BCOO from its raw buffers so jacfwd can differentiate w.r.t. `data`
    x_bcoo = sparse.BCOO((data, indices), shape=shape)
    # Example computation on the sparse input, producing a dense vector
    return x_bcoo @ jnp.ones(shape[1])

x_bcoo = A_sparse  # your sparse input (100 x 100, 199 stored values)
# Jacobian w.r.t. the stored values only: shape (100, nse) rather than (100, 100, 100)
jac_data = jax.jacfwd(f_raw, argnums=0)(x_bcoo.data, x_bcoo.indices)
```

This is much more complex and requires careful handling.

Solution 4: Domain-Specific Implementation for FEM

Since you're working with finite element analysis, you likely know the exact sparsity structure of your Jacobian: two DOFs can only couple if they share an element. You can leverage this knowledge to build the sparsity pattern directly from the element connectivity, without ever forming a dense matrix, and then pass it as `sparsity=` to `sparsejac.jacfwd` as in Solution 1:

```python
import jax.numpy as jnp
import jax.experimental.sparse as jsp

def fem_sparsity_pattern(element_connectivity, n_dofs):
    """
    Build the Jacobian sparsity pattern of a FEM problem as a BCOO matrix.
    element_connectivity: iterable of elements, each a sequence of DOF indices;
    every pair of DOFs that share an element may have a nonzero entry.
    """
    # Collect each coupled (i, j) pair once, even if several elements share it
    pairs = set()
    for element in element_connectivity:
        for i in element:
            for j in element:
                pairs.add((int(i), int(j)))
    indices = jnp.array(sorted(pairs), dtype=jnp.int32)
    data = jnp.ones(len(pairs))
    return jsp.BCOO((data, indices), shape=(n_dofs, n_dofs))

# Example: a 1D mesh of 2-node elements, assumed here for illustration
pattern = fem_sparsity_pattern([(e, e + 1) for e in range(99)], n_dofs=100)
```

Note: the methods in `jax.experimental.sparse` are experimental reference implementations, and not recommended for use in performance-critical applications. The JAX team is still working on improving sparse support.
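The usual way to get a sparse Jacobian without one JVP per column is column compression via graph coloring: columns that never share a nonzero row can be pushed forward together in a single JVP, and each entry is then read back from the compressed result. To my understanding this is also the idea behind sparsejac, but the sketch below is not its implementation. It assumes the sparsity pattern is known and concrete (the coloring runs in NumPy outside `jit`), and `greedy_column_coloring` and `colored_sparse_jacfwd` are hypothetical helper names. For the bidiagonal matrix from the question, greedy coloring finds 2 colors, so the BCOO Jacobian comes from 2 JVPs instead of 100:

```python
import numpy as np
import jax
import jax.numpy as jnp
import jax.experimental.sparse as jsp

def greedy_column_coloring(indices, n_cols):
    """Hypothetical helper: color columns so that no two columns sharing a
    nonzero row get the same color. indices: (nse, 2) array of (row, col)."""
    rows_of_col = [set() for _ in range(n_cols)]
    for r, c in np.asarray(indices):
        rows_of_col[int(c)].add(int(r))
    colors = np.full(n_cols, -1, dtype=np.int32)
    rows_used_by_color = []  # rows already occupied by each color
    for j in range(n_cols):
        for k, used in enumerate(rows_used_by_color):
            if used.isdisjoint(rows_of_col[j]):
                colors[j] = k
                used.update(rows_of_col[j])
                break
        else:
            # No existing color fits: open a new one
            colors[j] = len(rows_used_by_color)
            rows_used_by_color.append(set(rows_of_col[j]))
    return colors, len(rows_used_by_color)

def colored_sparse_jacfwd(f, x, sparsity):
    """Hypothetical helper: BCOO Jacobian of f at x using one JVP per color."""
    colors, num_colors = greedy_column_coloring(sparsity.indices, sparsity.shape[1])
    colors = jnp.asarray(colors)
    # Seed matrix (num_colors, n): row c is the sum of e_j over columns j with color c
    seeds = (colors[None, :] == jnp.arange(num_colors)[:, None]).astype(x.dtype)
    # compressed[c] = J @ seeds[c]; only num_colors dense vectors are materialized
    compressed = jax.vmap(lambda v: jax.jvp(f, (x,), (v,))[1])(seeds)
    rows, cols = sparsity.indices[:, 0], sparsity.indices[:, 1]
    # Same-colored columns never share a row, so J[i, j] == compressed[color(j), i]
    data = compressed[colors[cols], rows]
    return jsp.BCOO((data, sparsity.indices), shape=sparsity.shape)

# The question's setup: a custom_jvp function with a sparse matrix inside
n = 100
A_sparse = jsp.BCOO.fromdense(jnp.diag(jnp.ones(n)) + jnp.diag(jnp.ones(n - 1), 1))

@jax.custom_jvp
def f(x):
    return A_sparse @ x

@f.defjvp
def f_jvp(primals, tangents):
    x, = primals
    x_dot, = tangents
    return f(x), A_sparse @ x_dot

x = jnp.ones(n)
jac = colored_sparse_jacfwd(f, x, A_sparse)  # BCOO with 199 stored entries
```

Greedy coloring uses at most one more color than the largest number of other columns that share a row with any single column, so for FEM meshes the number of JVPs depends on local connectivity rather than on the number of DOFs.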
There is no built-in way to make `jacfwd` preserve BCOO sparsity because of how it works: it uses forward-mode automatic differentiation by vmapping a JVP over the standard basis vectors (the columns of the identity matrix), which naturally produces dense output. JAX's BCOO sparse arrays are designed to be compatible with JAX transforms, including `jax.grad()`, but `jacfwd` specifically creates dense Jacobians because it pushes forward the entire standard basis at once using `vmap`. Even when a custom JVP rule uses a sparse matrix internally (as `A_sparse @ x_dot` does), each pushed-forward tangent is a dense vector, so stacking one per basis vector yields a dense n×n result. Sparse autodiff is fundamentally incompati…
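The basis-pushforward mechanism can be reproduced by hand to see where the density comes from. This is a conceptual sketch, not JAX's internal code (`jacfwd` arranges axes differently internally), and it drops the `custom_jvp` wrapper from the question since the sparse matmul differentiates the same way. Note that the identity `jnp.eye(n)` that is pushed forward is itself an n×n dense array:

```python
import jax
import jax.numpy as jnp
import jax.experimental.sparse as jsp

n = 100
A_sparse = jsp.BCOO.fromdense(jnp.diag(jnp.ones(n)) + jnp.diag(jnp.ones(n - 1), 1))

def f(x):
    return A_sparse @ x

x = jnp.ones(n)

def pushforward(v):
    # One JVP: returns J @ v as a dense vector of length n
    return jax.jvp(f, (x,), (v,))[1]

# Push forward every standard basis vector e_k (the rows of the identity).
# Row k of the stacked result is J @ e_k, i.e. column k of J, hence the transpose.
jac_manual = jax.vmap(pushforward)(jnp.eye(n)).T

jac_dense = jax.jacfwd(f)(x)
print(jac_manual.shape, bool(jnp.allclose(jac_manual, jac_dense)))  # (100, 100) True
```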