```python
def create_sparse_matrix(V, I, J, shape):  # V, I, J are all NumPy arrays (consider this a host call)
    ...

@jax.custom_jvp
...

@creation.defjvp
...
```
Hi guys,
I want to use an external sparse linear solver alongside JAX, which requires a matrix-vector product function and an RHS vector.
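For context, the interface described here (a matrix-vector product callable plus an RHS vector) can be sketched in plain NumPy without ever materializing the matrix. This is a sketch under assumptions: `coo_matvec` and `cg_solve` are hypothetical names, the triplets are assumed to mean `A[I[k], J[k]] += V[k]`, and a textbook conjugate-gradient loop stands in for the external solver, so it only applies to symmetric positive-definite matrices.

```python
import numpy as np


def coo_matvec(V, I, J, x, n_rows):
    """Matrix-free y = A @ x for A stored as COO triplets, A[I[k], J[k]] += V[k]."""
    y = np.zeros(n_rows, dtype=np.result_type(V, x))
    # np.add.at accumulates repeated row indices instead of overwriting them
    np.add.at(y, I, V * x[J])
    return y


def cg_solve(matvec, b, tol=1e-10, maxiter=1000):
    """Hypothetical stand-in for the external solver: textbook conjugate gradient.
    Needs only a matvec callable and the RHS; assumes A is symmetric positive definite."""
    x = np.zeros_like(b)
    r = b - matvec(x)
    p = r.copy()
    rs = r @ r
    for _ in range(maxiter):
        if np.sqrt(rs) < tol:
            break
        Ap = matvec(p)
        alpha = rs / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x


# Usage: A = [[4, 1], [1, 3]] stored as triplets, b = [1, 2]
V = np.array([4.0, 1.0, 1.0, 3.0])
I = np.array([0, 0, 1, 1])
J = np.array([0, 1, 0, 1])
b = np.array([1.0, 2.0])
x = cg_solve(lambda v: coo_matvec(V, I, J, v, n_rows=2), b)  # close to [1/11, 7/11]
```

Only the three triplet arrays are kept in memory, so storage stays proportional to the number of nonzeros.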
I know that I have to use `pure_callback` + `custom_jvp` on this solver to enable higher-order autodiff, and that part can be done. The issue is the sparse-matrix creation: to save memory, I have to create the sparse matrix from ordinary NumPy arrays.
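The tangent rule that `custom_jvp` needs for a linear solve follows from differentiating $A(V)\,x = b$. A short derivation, assuming $A$ is square, invertible, and assembled from the triplets as $A_{I_k J_k} \mathrel{+}= V_k$:

$$
A\,x = b
\;\Longrightarrow\;
\dot{A}\,x + A\,\dot{x} = \dot{b}
\;\Longrightarrow\;
\dot{x} = A^{-1}\left(\dot{b} - \dot{A}\,x\right),
\qquad
(\dot{A}\,x)_i = \sum_{k \,:\, I_k = i} \dot{V}_k \, x_{J_k}.
$$

So the JVP needs only one more solve with the same matrix and a new right-hand side, and $\dot{A}\,x$ can be formed from the triplets with a scatter-add, without assembling $\dot{A}$. If the JVP rule performs that second solve by calling the `custom_jvp`-wrapped function itself, the rule is applied again at every nesting level, which is what gives higher-order forward-mode derivatives.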
The main question is: how do I specify the result shape (`result_shape_dtypes`) for a sparse matrix in `pure_callback`? It would also be nice if you could point out anything wrong with this approach!