Question about defining new JAX primitives #12730

Answered by mattjj
gg2uah asked this question in General

Thanks for the questions!

Q1) Right, custom_jvp/vjp aren't the right tool for this job. You want to define your own Primitive. This tutorial is the best documentation on how to do that. A few more comments about this below.

Q2) An opaque-to-JAX C++ JVP rule for your primitive can make sense; it's just another primitive, where some inputs and outputs are linear. See the example below.

Q3) Yes. If you want to JIT it, you can set up a CustomCall, as described in the tutorial linked above.

Here's an example which may make sense. In it, I'll use ordinary NumPy, imported as onp, to perform some operations which are totally opaque to JAX; as far as JAX is concerned, those are just calls into opaque f…
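
A minimal sketch of the pattern described in Q1 and Q2 might look like the following. It is an illustration under stated assumptions, not the code from the original answer: the names `foo_p`, `foo_jvp_p`, `foo`, and `foo_jvp` are hypothetical, `onp.sin` and `onp.cos` stand in for opaque external (for example C++) routines, and only eager forward-mode differentiation via `jax.jvp` is covered. Using `jax.jit` or `jax.grad` would additionally need abstract evaluation, lowering, and transpose rules, which are omitted here.

```python
import numpy as onp  # ordinary NumPy, opaque to JAX
import jax
import jax.numpy as jnp
from jax.interpreters import ad

try:
    from jax.extend.core import Primitive  # newer JAX releases
except ImportError:
    from jax.core import Primitive  # older JAX releases

# Hypothetical primitive wrapping an opaque forward computation.
foo_p = Primitive("foo")

def foo(x):
    return foo_p.bind(x)

def _foo_impl(x):
    # Stand-in for an opaque call; JAX cannot see inside it.
    return jnp.asarray(onp.sin(onp.asarray(x)))

foo_p.def_impl(_foo_impl)

# A second hypothetical primitive for the opaque JVP computation.
# It is linear in the tangent input t.
foo_jvp_p = Primitive("foo_jvp")

def foo_jvp(x, t):
    return foo_jvp_p.bind(x, t)

def _foo_jvp_impl(x, t):
    # Stand-in for an opaque derivative routine.
    return jnp.asarray(onp.cos(onp.asarray(x)) * onp.asarray(t))

foo_jvp_p.def_impl(_foo_jvp_impl)

def _foo_jvp_rule(primals, tangents):
    # JVP rule for foo_p: delegate the tangent computation to foo_jvp_p.
    # Assumption: the tangent is a concrete array (symbolic zeros are not handled).
    (x,), (t,) = primals, tangents
    return foo(x), foo_jvp(x, t)

ad.primitive_jvps[foo_p] = _foo_jvp_rule

x = jnp.array([0.0, 1.0, 2.0], dtype=jnp.float32)
t = jnp.ones_like(x)
y, y_dot = jax.jvp(foo, (x,), (t,))
```

Here `y` should match `sin(x)` and `y_dot` should match `cos(x) * t`, with both values produced by the opaque NumPy calls rather than by anything JAX traced.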

Answer selected by gg2uah