
JAX Optimization Issue: Redundant Calculations Despite Unnecessary Conditions #22013

Answered by jakevdp
Qazalbash asked this question in Q&A

The jaxpr is a representation of the program's logic, and it does not necessarily reflect the exact operations that run once the program is compiled. If you want to see the operations the compiled program actually lowers to, you can use the ahead-of-time (AOT) lowering and compilation APIs. For your case it would look like this:

print(jax.jit(model.log_prob).lower(xx).compile().as_text())

I'm not able to run this because there are a number of undefined names in the code you shared, but printing this should show you the actual operations used in the compiled code.
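Since the original model isn't runnable here, the following is a minimal, self-contained sketch of the same workflow. It assumes a hypothetical `log_prob` stand-in (not the model from the question) whose `jnp.where` traces both branches, so both show up in the jaxpr. It prints the jaxpr and the compiled program so you can compare them. Lowering only needs shapes and dtypes, so a `jax.ShapeDtypeStruct` is passed instead of real data; the `(8,)` float32 shape is an arbitrary assumption.

```python
import jax
import jax.numpy as jnp


def log_prob(x):
    # Hypothetical stand-in for model.log_prob (the original model is not shown).
    # jnp.where evaluates both branch expressions during tracing,
    # so both appear in the jaxpr.
    positive = -0.5 * x**2
    negative = jnp.log1p(jnp.exp(x))
    return jnp.where(x > 0, positive, negative).sum()


# 1) The jaxpr: the traced program logic, before any XLA optimization.
jaxpr_text = str(jax.make_jaxpr(log_prob)(jnp.zeros((8,), jnp.float32)))

# 2) Lower with an abstract input (shape/dtype only), then compile.
x_spec = jax.ShapeDtypeStruct((8,), jnp.float32)
lowered = jax.jit(log_prob).lower(x_spec)
compiled = lowered.compile()

# Text of the optimized program that will actually execute.
hlo_text = compiled.as_text()

print(jaxpr_text)
print(hlo_text)

# The compiled object is callable on inputs matching x_spec.
result = compiled(jnp.ones((8,), jnp.float32))
```

Diffing `jaxpr_text` against `hlo_text` shows whether XLA removed or fused any of the work you are worried about. Whether a particular computation is eliminated depends on the backend and on XLA's optimization passes, so it is best checked this way rather than assumed.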

Replies: 1 comment 2 replies

2 replies
@Qazalbash

@jakevdp

Answer selected by Qazalbash
Category
Q&A
Labels
None yet
2 participants