Primitives work & Reverse differentiation #12139
Replies: 1 comment
-
Hi, there are a number of examples of transpose rule outputs near the section you linked to: https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html#transposition I'm not sure what additional information I could give here that would make those examples clearer. Is there a specific point of confusion after reading through that section?
-
I registered a primitive function func_primitive, used as out = func_primitive(arg1, arg2, arg3). When I called grad(func_primitive), it raised "NotImplementedError: Transpose rule (for reverse-mode differentiation) for 'multiply_add' not implemented", so now I need to implement an ad.primitive_transposes rule.
However, I can't follow the instructions in "https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html#reverse-differentiation" and would like to ask:
When implementing an ad.primitive_transposes rule, what should it return?
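A minimal sketch of what a transpose rule returns, assuming a recent JAX version. The primitive `mul_add_p` (computing x * y + z, modeled on the tutorial's multiply_add) and the helper names are hypothetical, and the fallback import of `Primitive` is an assumption about where it lives in your JAX version. The idea: `grad` linearizes your function using the JVP rule, then transposes that linear program, so every primitive appearing in it needs a transpose rule. The rule is called as `rule(ct, *inputs, **params)`. Inputs the computation is linear in arrive as `ad.UndefinedPrimal` (test with `ad.is_undefined_primal`), the others as known values. It should return one entry per input: the cotangent for each linear input and `None` for each constant input.

```python
import jax
import jax.numpy as jnp
from jax.interpreters import ad

try:  # assumption: Primitive's public location differs across JAX versions
    from jax.extend.core import Primitive
except ImportError:
    from jax.core import Primitive

# Hypothetical primitive computing x * y + z (modeled on the tutorial's multiply_add).
mul_add_p = Primitive("mul_add")


def mul_add(x, y, z):
    return mul_add_p.bind(x, y, z)


@mul_add_p.def_impl
def _mul_add_impl(x, y, z):
    return jnp.add(jnp.multiply(x, y), z)


@mul_add_p.def_abstract_eval
def _mul_add_abstract_eval(x, y, z):
    # Assumes x, y and z share shape and dtype, so the output aval equals x's.
    return x


def _instantiate(t):
    # Replace a symbolic zero (ad.Zero) with a concrete zero array.
    return jnp.zeros(t.aval.shape, t.aval.dtype) if type(t) is ad.Zero else t


def _mul_add_jvp(primals, tangents):
    x, y, z = primals
    x_dot, y_dot, z_dot = map(_instantiate, tangents)
    # d(x*y + z) = x_dot*y + x*y_dot + z_dot, written with mul_add itself, so the
    # linearized program contains mul_add and grad needs its transpose rule.
    tangent_out = mul_add(x_dot, y, mul_add(x, y_dot, z_dot))
    return mul_add(x, y, z), tangent_out


ad.primitive_jvps[mul_add_p] = _mul_add_jvp


def _mul_add_transpose(ct, x, y, z):
    # Linear inputs arrive as ad.UndefinedPrimal; the others are known values.
    # Each mul_add in the linear program is linear in x or in y, never both.
    assert not (ad.is_undefined_primal(x) and ad.is_undefined_primal(y))
    ct = _instantiate(ct)
    zeros = jnp.zeros_like(ct)
    # One entry per input: a cotangent for linear inputs, None for constants.
    ct_x = mul_add(ct, y, zeros) if ad.is_undefined_primal(x) else None
    ct_y = mul_add(x, ct, zeros) if ad.is_undefined_primal(y) else None
    ct_z = ct if ad.is_undefined_primal(z) else None
    return ct_x, ct_y, ct_z


ad.primitive_transposes[mul_add_p] = _mul_add_transpose
```

With these registrations, `jax.grad(mul_add, argnums=(0, 1, 2))` at x=2, y=3, z=1 gives 3, 2 and 1, i.e. y, x and 1. The tutorial's version returns `ad.Zero(...)` when `ct` is a symbolic zero; this sketch instantiates concrete zeros instead for simplicity.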