Understanding transpose rules #18346
I've gone over "How JAX primitives work" and Autodidax, but I still don't have a good intuition for what a transpose rule actually represents mathematically. I understand that a JVP rule provides a recipe to compute a Jacobian-vector product. Now the goal is to compute the VJP, …
Replies: 1 comment 3 replies
The transposition rule for a linear function $f$ computes its transpose $f^T$ at a particular point. Transposition is indeed used by JAX's implementation of VJPs, where the linear function is the Jacobian map $J$ you mention.
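To make this concrete, here is a minimal sketch using the public `jax.linear_transpose` API. The matrix `A`, the functions `f` and `g`, and the point `x0` are illustrative names made up for this example, and the "JVP then transpose" construction is only a small model of the idea, not a description of JAX's internal code path.

```python
import jax
import jax.numpy as jnp

# Illustrative matrix: a linear map from R^3 to R^2.
A = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0]])

def f(x):
    # f is linear in x, so its transpose is ct -> A.T @ ct.
    return A @ x

# Only the shape and dtype of this example input are used.
x_like = jnp.zeros(3)
f_T = jax.linear_transpose(f, x_like)

ct = jnp.array([1.0, -1.0])   # a cotangent living in the output space of f
(ct_x,) = f_T(ct)             # same values as A.T @ ct

# For a nonlinear g, the map v -> J v at a fixed point x0 is linear in v.
def g(x):
    return jnp.sin(A @ x)

x0 = jnp.array([0.1, 0.2, 0.3])

def g_jvp_at_x0(v):
    # Tangent output of the JVP: the Jacobian of g at x0 applied to v.
    return jax.jvp(g, (x0,), (v,))[1]

# Transposing that linear map gives ct -> J^T ct, i.e. the VJP at x0.
(ct_via_transpose,) = jax.linear_transpose(g_jvp_at_x0, x0)(ct)

# Compare against the VJP that JAX computes directly.
_, g_vjp = jax.vjp(g, x0)
(ct_direct,) = g_vjp(ct)
```

Under these assumptions, `ct_via_transpose` and `ct_direct` should agree up to floating-point error, which is the sense in which a VJP is "the transpose of the JVP".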
An extra bit of complexity is that some primitives are only linear in some of their inputs. We sometimes use the term "conditionally linear" to describe this. In this case the transpose rule only transposes the "linear part" of the function.
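As a small illustration of conditional linearity, consider a function that is nonlinear in its first argument but linear in its second. The sketch below uses a made-up function `f(x, y) = sin(x) * y` and the public `jax.linear_transpose` API; fixing `x` with `functools.partial` is just one way to expose the linear part, and is not meant to mirror the internal helper JAX uses.

```python
import functools

import jax
import jax.numpy as jnp

def f(x, y):
    # Nonlinear in x, but linear in y for any fixed value of x.
    return jnp.sin(x) * y

x = jnp.array([0.5, 1.0, 1.5])   # the fixed, non-linear argument
y_like = jnp.zeros(3)            # only shape and dtype are used

# Restrict f to this particular x, leaving a function that is linear in y.
f_at_x = functools.partial(f, x)
f_at_x_T = jax.linear_transpose(f_at_x, y_like)

ct = jnp.array([1.0, 2.0, 3.0])  # a cotangent for the output of f
(ct_y,) = f_at_x_T(ct)           # same values as sin(x) * ct

# Defining property of a transpose: <f(x, y), ct> == <y, f_at_x_T(ct)>.
y = jnp.array([2.0, -1.0, 0.5])
lhs = jnp.vdot(f(x, y), ct)
rhs = jnp.vdot(y, ct_y)
```

Here the transpose depends on the value of `x` that was fixed, but it only produces a cotangent for `y`, the argument in which `f` is linear.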
To be a bit more precise: let's say $f$ takes two arguments, and is linear only in the second, for any value of the first. The transposition rule for $f$ then computes "the transpose of the function $f$, restricted to a particula…