After reading the JAX core from scratch tutorial and some parts of the JAX source code, I was confused as to what …
Answered by mattjj, Oct 25, 2023
Thanks for the question! It may be tricky to answer in full detail, so instead I hope this summary answer is still useful:
The last one has no analogue in Autodidax (the JAX core from scratch tutorial). Does that help?
Answer selected by leiteg
`trace.pure` is called on non-`Tracer` values, e.g. `numpy.ndarray`s or instances of Python builtin number types, and produces a `Tracer` for the `trace` meant to represent having trivial context added (e.g. a zero tangent for `JVPTrace`, a `None` batch dim for `BatchTrace`, etc.); `trace.lift` is similar but is always called on `Tracer` values, specifically ones that don't belong to the `trace` (e.g. if we're doing `jvp`-of-`jvp`, or `vmap`-of-`jvp`, then the inner one's `trace.lift` might be applied to a `Tracer` from the outer one), and also produces a `Tracer` for the `trace` meant to represent trivial co…
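To make the `pure`/`lift` distinction concrete, here is a minimal toy sketch in the spirit of Autodidax. It is not JAX's actual code. The `Tracer`/`Trace` base classes, the `level` attribute, and the `SymbolicZero` placeholder tangent are hypothetical simplifications. In this sketch `pure` asserts that it receives a non-`Tracer` value and `lift` asserts that it receives a `Tracer` owned by a different trace. Both attach the trivial context for their trace: a zero tangent for `JVPTrace` and a `None` batch dimension for `BatchTrace`.

```python
from dataclasses import dataclass
from typing import Any, Optional


class Tracer:
    """Hypothetical base class: every tracer is owned by exactly one trace."""
    trace: "Trace"


class Trace:
    """Hypothetical base class for one interpreter (one per jvp/vmap level)."""

    def __init__(self, level: int):
        self.level = level  # hypothetical nesting depth: outer < inner


@dataclass(frozen=True)
class SymbolicZero:
    """Hypothetical placeholder for an all-zeros tangent."""


@dataclass
class JVPTracer(Tracer):
    trace: "JVPTrace"
    primal: Any
    tangent: Any


class JVPTrace(Trace):
    def pure(self, val):
        # Only for non-Tracer values (Python numbers, arrays, ...).
        assert not isinstance(val, Tracer)
        # Trivial JVP context: a constant for this jvp, so its tangent is zero.
        return JVPTracer(self, val, SymbolicZero())

    def lift(self, tracer):
        # Only for Tracers owned by a *different* trace (e.g. an outer jvp or vmap).
        assert isinstance(tracer, Tracer) and tracer.trace is not self
        # The foreign Tracer becomes the primal; its tangent at this level is zero.
        return JVPTracer(self, tracer, SymbolicZero())


@dataclass
class BatchTracer(Tracer):
    trace: "BatchTrace"
    val: Any
    batch_dim: Optional[int]


class BatchTrace(Trace):
    def pure(self, val):
        # Only for non-Tracer values.
        assert not isinstance(val, Tracer)
        # Trivial batching context: the value is not mapped, so batch_dim is None.
        return BatchTracer(self, val, None)

    def lift(self, tracer):
        # Only for Tracers owned by a different trace.
        assert isinstance(tracer, Tracer) and tracer.trace is not self
        return BatchTracer(self, tracer, None)


# vmap-of-jvp: the BatchTrace is the outer trace, the JVPTrace the inner one.
batch = BatchTrace(level=1)
jvp = JVPTrace(level=2)
xs = BatchTracer(batch, [1.0, 2.0, 3.0], 0)  # a Tracer of the outer trace
c = jvp.pure(2.0)  # plain constant: wrapped with a zero tangent
t = jvp.lift(xs)   # outer Tracer: also a zero tangent, with xs as the primal
```

In this toy version the two methods build the same kind of tracer; they differ only in which inputs they accept, matching the split between non-`Tracer` values and `Tracer`s from another trace.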