
transpose_rules for nonlinear primitives. #16928

Answered by froystig
mavenlin asked this question in Q&A

Is VJP the only transform that calls transpose rules under the hood? Are there examples of transpose rules being used by other kinds of transforms?

In addition to VJP, there is also the more direct jax.linear_transpose, which assumes linearity of the input function.
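A minimal sketch of both routes, assuming a recent JAX install. `f` and `A` are illustrative names, not from the thread. For a linear map `f(x) = A @ x`, `jax.linear_transpose` and `jax.vjp` should both return `A.T @ ct`. The second part shows why "assumes linearity" matters: handing `jnp.sin` to `jax.linear_transpose` forces JAX to transpose `sin` itself, and since `sin` has no transpose rule this is expected to raise (the exact exception type may vary by version, so the sketch catches `Exception`).

```python
import jax
import jax.numpy as jnp

A = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0]])

def f(x):
    # Illustrative linear map R^3 -> R^2.
    return A @ x

x = jnp.zeros(3)             # primal: only fixes shape/dtype for linear_transpose
ct = jnp.array([1.0, -1.0])  # cotangent for the output of f

# Route 1: transpose f directly (f must be linear).
f_t = jax.linear_transpose(f, x)
(ct_x_lt,) = f_t(ct)

# Route 2: vjp linearizes f at x, then transposes that linearization.
_, f_vjp = jax.vjp(f, x)
(ct_x_vjp,) = f_vjp(ct)

print(ct_x_lt, ct_x_vjp)  # both should equal A.T @ ct

# Nonlinear input: linear_transpose would need a transpose rule for sin itself.
try:
    jax.linear_transpose(jnp.sin, jnp.float32(0.5))(jnp.float32(1.0))
    sin_transposed = True
except Exception as err:
    sin_transposed = False
    print("linear_transpose(sin) failed:", type(err).__name__)
```

Note that `jax.grad(jnp.sin)` still works: reverse mode only transposes the tangent computation `t * cos(x)`, where `cos(x)` is a known constant.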

When it comes to transpose rules, the question is: should we even implement a transpose rule for this primitive, given that it is not guaranteed to be linear?

Linearity only matters for primitives bound under the tangent computation of a JVP (i.e. during linearization), or for primitives bound under jax.linear_transpose. If you never expect to transpose invoke_p because it is never bound in a tangent computation, and because you don't plan to sup…
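A minimal sketch of that situation, assuming a recent JAX. It uses `jax.custom_jvp` as a stand-in for a custom primitive like invoke_p (this is a swapped-in technique, not how invoke_p is defined), and `cube` is a hypothetical name. The nonlinear op is evaluated only on primal values inside its JVP rule, while the tangent output is linear in `t`, so `jax.grad` needs to transpose only `mul`, never `cube`. Printing the jaxpr of the `jax.linearize` output is one way to check whether your op ever appears in the tangent computation.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def cube(x):
    # Hypothetical nonlinear op standing in for a primitive such as invoke_p.
    return x ** 3

@cube.defjvp
def cube_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    # The nonlinear op is applied only to the primal input...
    y = cube(x)
    # ...while the tangent output is linear in t (t scaled by 3 * x**2),
    # so reverse mode only has to transpose a multiplication.
    return y, 3.0 * x ** 2 * t

x = jnp.float32(2.0)
grad_val = jax.grad(cube)(x)  # 3 * 2**2 = 12

# Inspect the tangent (linearized) computation; `cube` should not appear in it.
_, cube_lin = jax.linearize(cube, x)
print(jax.make_jaxpr(cube_lin)(jnp.float32(1.0)))
```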

Answer selected by mavenlin
Category
Q&A
2 participants