Here's my understanding of transpose_rules: transpose rules are registered for linear primitives. When a function is linearized with jvp, the resulting linear function can be transposed op by op, and that is how vjp is computed in JAX.
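The "linearize, then transpose" picture can be checked directly with public JAX APIs. This is a minimal sketch: the function f and the input values are made up for illustration. It linearizes f with jax.linearize, transposes the resulting linear map with jax.linear_transpose, and compares the result with jax.vjp, which is conceptually the same construction.

```python
import jax
import jax.numpy as jnp

def f(x):
    # A nonlinear function (illustrative only).
    return jnp.sin(x) * x

x = jnp.array([0.5, 1.0, 2.0])

# 1. Linearize: y = f(x), and f_lin(t) = J(x) @ t is linear in the tangent t.
y, f_lin = jax.linearize(f, x)

# 2. Transpose the linear map: f_lin_t(ct) = J(x)^T @ ct.
#    linear_transpose only uses x for its shape and dtype.
f_lin_t = jax.linear_transpose(f_lin, x)

ct = jnp.ones_like(y)
(g_transpose,) = f_lin_t(ct)

# 3. Compare with reverse mode computed by jax.vjp.
_, f_vjp = jax.vjp(f, x)
(g_vjp,) = f_vjp(ct)

print(bool(jnp.allclose(g_transpose, g_vjp)))  # True
```

Only the linear ops that appear in f_lin (multiplications by constants and additions) need transpose rules here; sin itself is never transposed.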
Now, when it comes to transpose rules, the question is: should we even implement a transpose rule for this primitive, given that it is not guaranteed to be linear? Or should we still give it a transpose rule by linearizing … Thanks for reading my question. This may sound like a weird requirement for a primitive, but it can come up when we try to make the differentiation engine work for some custom calls.
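One way to see why a nonlinear primitive may not need a transpose rule: if its JVP rule evaluates the primitive only on primal values and builds the tangent from ops that are linear in the tangent (ops that already have transpose rules), reverse mode never asks to transpose the primitive itself. The sketch below uses jax.custom_jvp instead of a raw primitive with a registered JVP rule, because custom_jvp is the public, version-stable way to attach a JVP rule. opaque_op is a hypothetical stand-in for a custom call; it just computes x**3 here.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def opaque_op(x):
    # Hypothetical stand-in for a nonlinear custom call.
    return x ** 3

@opaque_op.defjvp
def opaque_op_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    y = opaque_op(x)               # nonlinear op: evaluated on primals only
    tangent_out = 3.0 * x ** 2 * t  # linear in t: only this part gets transposed
    return y, tangent_out

# Reverse mode works without any transpose rule for opaque_op itself.
g = jax.grad(opaque_op)(2.0)
print(g)  # 12.0
```

The same reasoning applies to a real primitive: if its JVP rule keeps it on the primal side, it needs a JVP rule but no transpose rule.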
In addition to VJP, there is also the more direct jax.linear_transpose, which assumes linearity of the input function. Linearity only matters for primitives bound under the tangent computation of a JVP (i.e. during linearization), or for primitives bound under jax.linear_transpose. If you never expect to transpose invoke_p because it is never bound in a tangent computation, and because you don't plan to sup…
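jax.linear_transpose does not linearize first; it traces the function and transposes it op by op, trusting that it is linear. A minimal sketch with an explicitly linear function (the matrix A is made up for illustration), whose transpose should map ct to A^T @ ct:

```python
import jax
import jax.numpy as jnp

A = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0]])

def apply_A(x):
    # Linear in x: maps R^3 -> R^2.
    return A @ x

x = jnp.zeros(3)  # only the shape and dtype matter to linear_transpose
apply_A_t = jax.linear_transpose(apply_A, x)

ct = jnp.array([1.0, 1.0])
(xt,) = apply_A_t(ct)
print(xt)  # [5. 7. 9.]
```

Given a function like jnp.sin instead, the same call would have to transpose sin itself, which has no transpose rule. That is exactly the situation in which a primitive's linearity starts to matter.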