
What do pure, lift, and sublift mean in JAX core? #18270

Answered by mattjj
leiteg asked this question in Q&A

Thanks for the question!

It may be tricky to answer in full detail, so instead I hope this summary answer is still useful:

  • trace.pure is called on non-Tracer values, e.g. numpy.ndarrays or instances of Python's built-in number types, and produces a Tracer for that trace representing the same value with trivial context added (e.g. a zero tangent for JVPTrace, a None batch dim for BatchTrace, etc.);
  • trace.lift is similar but is always called on Tracer values, specifically ones that don't belong to the trace (e.g. if we're doing jvp-of-jvp, or vmap-of-jvp, then the inner one's trace.lift might be applied to a Tracer from the outer one), and also produces a Tracer for the trace meant to represent trivial co…
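To make the two bullets concrete, here is a toy Python model of the distinction. It is a hedged sketch, not JAX's actual internals, which have changed across versions. Only the ideas of pure, lift, JVPTrace, and BatchTrace come from the answer above. The nesting `level`, the `full_raise` dispatch helper, the `mul` method, and the class bodies are hypothetical illustrations, and a plain `0.0` stands in for a zero tangent. `full_raise` calls `pure` on a non-Tracer constant and `lift` on a Tracer from an outer trace. It refuses a Tracer from a more deeply nested trace.

```python
class Tracer:
    trace = None  # the trace this tracer belongs to


class Trace:
    """Toy trace: one per active transformation, ordered by nesting level."""

    def __init__(self, level):
        self.level = level  # larger level = more deeply nested transformation

    def pure(self, val):
        raise NotImplementedError

    def lift(self, tracer):
        raise NotImplementedError

    def full_raise(self, val):
        # Hypothetical helper: bring an arbitrary value into this trace.
        if not isinstance(val, Tracer):
            return self.pure(val)   # plain constant: add trivial context
        if val.trace is self:
            return val              # already belongs to this trace
        if val.trace.level < self.level:
            return self.lift(val)   # tracer from an outer trace: add trivial context
        raise ValueError("tracer from a more deeply nested trace leaked out")


class JVPTracer(Tracer):
    def __init__(self, trace, primal, tangent):
        self.trace, self.primal, self.tangent = trace, primal, tangent


class JVPTrace(Trace):
    def pure(self, val):
        # Trivial JVP context: a zero tangent.
        return JVPTracer(self, val, 0.0)

    def lift(self, tracer):
        # The outer tracer becomes the primal; it is constant w.r.t. this JVP.
        return JVPTracer(self, tracer, 0.0)

    def mul(self, x, y):
        # Product rule; in this toy it only works when primals are plain numbers.
        x, y = self.full_raise(x), self.full_raise(y)
        return JVPTracer(self, x.primal * y.primal,
                         x.primal * y.tangent + x.tangent * y.primal)


class BatchTracer(Tracer):
    def __init__(self, trace, val, batch_dim):
        self.trace, self.val, self.batch_dim = trace, val, batch_dim


class BatchTrace(Trace):
    def pure(self, val):
        # Trivial batching context: no batch dimension.
        return BatchTracer(self, val, None)

    def lift(self, tracer):
        return BatchTracer(self, tracer, None)


# JVP of f(x) = 3.0 * x at x = 2.0 with tangent 1.0:
# the constant 3.0 enters the trace via pure, so it gets a zero tangent.
jvp_trace = JVPTrace(level=1)
x = JVPTracer(jvp_trace, 2.0, 1.0)
y = jvp_trace.mul(3.0, x)
print(y.primal, y.tangent)  # 6.0 3.0

# vmap-of-jvp: the inner JVP trace lifts a tracer from the outer batch trace.
outer = BatchTrace(level=1)
inner = JVPTrace(level=2)
xs = BatchTracer(outer, [1.0, 2.0, 3.0], batch_dim=0)
lifted = inner.full_raise(xs)  # primal is the outer tracer, tangent is 0.0
```

In this sketch `pure` and `lift` do the same wrapping. The only difference is what they receive: a raw constant versus a Tracer owned by an enclosing transformation.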

Answer selected by leiteg