Question about Hessians for custom operations from C++ #17550
-
Hi, I'm interested in using JAX for finite element simulation. Some of the building blocks of finite elements (e.g., interpolation and integration) are expensive linear operators. I have an existing C++/CUDA library with optimized implementations of these operations that I was hoping to call from JAX rather than rewrite everything from scratch, but the guides I've found on C++ extensions and custom derivatives only describe how to register a VJP and a JVP for these custom operations. Are custom operations from C++ limited to first derivatives, or is there a way to tell JAX that certain operations are linear (so that all of their higher-order derivatives are identically zero)? I love how easily JAX computes Hessians, and I'd like to still be able to take second derivatives of code that uses the custom C++ interpolation/integration implementations. Thanks!
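The linearity idea in the question can be expressed through an ordinary JVP rule. The minimal sketch below rests on a few assumptions: the external kernel is replaced by a hypothetical host function `_host_apply` that multiplies by a fixed matrix `A`, and it is wrapped with `jax.pure_callback` (a CPU host callback, used here only to keep the example self-contained). For a linear map x ↦ Ax, the JVP is t ↦ At, so the custom rule just calls the wrapped operator again on the tangent. Because the rule reuses the same differentiable operator, nested forward-mode derivatives resolve through the same rule and second derivatives come out right. One caveat to verify for your setup: as written, this variant is meant for forward mode only, since reverse mode would also need a transpose rule for the callback that appears in the tangent path.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the C++/CUDA kernel: a fixed 3x2 matrix applied on the host.
A = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32)

def _host_apply(x):
    # Runs outside of JAX tracing, like an external library call would.
    return (A @ np.asarray(x)).astype(np.float32)

@jax.custom_jvp
def apply_A(x):
    out_type = jax.ShapeDtypeStruct((A.shape[0],), jnp.float32)
    return jax.pure_callback(_host_apply, out_type, x)

@apply_A.defjvp
def apply_A_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    # Linearity: the JVP of x -> A x is t -> A t, so the rule reuses the operator itself.
    # Since apply_A is called again, nested forward-mode derivatives hit this same rule.
    return apply_A(x), apply_A(t)

def loss(x):
    return jnp.sum(apply_A(x) ** 2)

def second_directional(x, u, v):
    # Forward-over-forward: derivative along u of the derivative along v of loss.
    # For this loss it equals 2 * (A u) . (A v).
    inner = lambda y: jax.jvp(loss, (y,), (v,))[1]
    return jax.jvp(inner, (x,), (u,))[1]

x = jnp.array([1.0, 2.0])
u = jnp.array([1.0, 0.0])
v = jnp.array([0.0, 1.0])
print(second_directional(x, u, v))  # expected 88.0
```

With these inputs the result is the off-diagonal Hessian entry 2 (Au)·(Av) = 2 × 44 = 88.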
-
If you want your custom kernels to be compatible with JAX autodiff, you should register
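If the library also exposes the adjoint (transpose) of each operator, the sketch below shows reverse-mode rules that still work when differentiated a second time. Assumptions: `_host_apply` and `_host_apply_T` are hypothetical stand-ins for a C++ kernel and its transpose, and `jax.pure_callback` is used only to keep the example self-contained. Each operator's VJP is the other operator, so differentiating the backward pass again simply alternates between them. Because `jax.custom_vjp` functions do not support forward mode, the Hessian is assembled from reverse-over-reverse Hessian-vector products instead of `jax.hessian` (which uses forward-over-reverse). The Python loop over basis vectors also avoids batching the callback under `vmap`.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Hypothetical stand-in for an external linear kernel A and its adjoint A^T.
A = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32)

def _host_apply(x):
    return (A @ np.asarray(x)).astype(np.float32)

def _host_apply_T(y):
    return (A.T @ np.asarray(y)).astype(np.float32)

@jax.custom_vjp
def apply_A(x):
    return jax.pure_callback(_host_apply, jax.ShapeDtypeStruct((3,), jnp.float32), x)

@jax.custom_vjp
def apply_AT(y):
    return jax.pure_callback(_host_apply_T, jax.ShapeDtypeStruct((2,), jnp.float32), y)

# A linear map needs no residuals: the VJP of A is A^T, and the VJP of A^T is A.
# The backward rules call the other wrapped operator, so they are differentiable too.
apply_A.defvjp(lambda x: (apply_A(x), None), lambda _, g: (apply_AT(g),))
apply_AT.defvjp(lambda y: (apply_AT(y), None), lambda _, g: (apply_A(g),))

def loss(x):
    return jnp.sum(apply_A(x) ** 2)

def hvp(x, v):
    # Reverse-over-reverse Hessian-vector product; no forward mode or vmap involved.
    return jax.grad(lambda y: jnp.vdot(jax.grad(loss)(y), v))(x)

x = jnp.array([1.0, 2.0])
# Hessian built row by row from HVPs with the standard basis vectors.
H = jnp.stack([hvp(x, e) for e in jnp.eye(2)])
print(H)  # expected 2 * A.T @ A, i.e. [[70, 88], [88, 112]]
```

For this loss, sum((Ax)²), the exact Hessian is 2AᵀA, so the printed matrix can be checked directly against the stand-in matrix.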
I didn't realize you were trying to call GPU code from JAX. In that case, these docs would be more relevant than `pure_callback`, which is designed for callbacks to the CPU host: https://jax.readthedocs.io/en/latest/Custom_Operation_for_GPUs.html