I was analyzing the jaxpr for the following piece of code and noticed something strange. Despite knowing that `b` will be `False` from the start, JAX is still performing calculations as if `b` were `True`. Specifically, the line `b:bool[] = eq 2.0 -1.0` already determines that `b` will be `False`, yet JAX continues to execute calculations that are only relevant if `b` were `True`. I'm curious to understand why this is happening and how I can avoid these redundant calculations.

Code:
```python
from jax import lax, make_jaxpr
from jax import numpy as jnp
from numpyro import distributions as dist
from numpyro.distributions.util import promote_shapes, validate_sample


class TruncatedPowerLaw(dist.Distribution):
    arg_constraints = {
        "alpha": dist.constraints.real,
        "low": dist.constraints.positive,
        "high": dist.constraints.positive,
    }
    reparametrized_params = ["low", "high", "alpha"]

    def __init__(
        self,
        alpha,
        low=0.0,
        high=1.0,
        validate_args=None,
    ):
        self.low, self.high, self.alpha = promote_shapes(low, high, alpha)
        self._support = dist.constraints.interval(low, high)
        batch_shape = lax.broadcast_shapes(
            jnp.shape(low),
            jnp.shape(high),
            jnp.shape(alpha),
        )
        super(TruncatedPowerLaw, self).__init__(
            batch_shape=batch_shape, validate_args=validate_args
        )

    @dist.constraints.dependent_property(is_discrete=False, event_dim=0)
    def support(self):
        return self._support

    @validate_sample
    def log_prob(self, value):
        def logp_neg1(value):
            return -jnp.log(value) - jnp.log(self.high) + jnp.log(self.low)

        def logp(value):
            log_value = jnp.log(value)
            logp = self.alpha * log_value
            beta = 1.0 + self.alpha
            logp = logp + jnp.log(
                beta / (jnp.power(self.high, beta) - jnp.power(self.low, beta))
            )
            return logp

        return jnp.where(jnp.equal(self.alpha, -1.0), logp_neg1(value), logp(value))


if __name__ == "__main__":
    model = TruncatedPowerLaw(alpha=2.0, low=1.0, high=10.0)
    xx = jnp.linspace(0.1, 20.0, 100)
    print(make_jaxpr(model.log_prob)(xx))
```

Output:

```
{ lambda ; a:f32[100]. let
    b:bool[] = eq 2.0 -1.0
    c:f32[100] = log a
    d:f32[100] = neg c
    e:f32[] = log 10.0
    f:f32[] = convert_element_type[new_dtype=float32 weak_type=False] e
    g:f32[100] = sub d f
    h:f32[] = log 1.0
    i:f32[] = convert_element_type[new_dtype=float32 weak_type=False] h
    j:f32[100] = add g i
    k:f32[100] = log a
    l:f32[100] = mul 2.0 k
    m:f32[] = pow 10.0 3.0
    n:f32[] = pow 1.0 3.0
    o:f32[] = sub m n
    p:f32[] = div 3.0 o
    q:f32[] = log p
    r:f32[] = convert_element_type[new_dtype=float32 weak_type=False] q
    s:f32[100] = add l r
    t:f32[100] = pjit[
      name=_where
      jaxpr={ lambda ; u:bool[] v:f32[100] w:f32[100]. let
          x:bool[100] = broadcast_in_dim[broadcast_dimensions=() shape=(100,)] u
          y:f32[100] = select_n x w v
        in (y,) }
    ] b j s
  in (t,) }
```
Replies: 1 comment 2 replies
The jaxpr is a representation of the program's logic, and does not necessarily reflect the precise operations that will be done after the program is compiled. If you want to see the operations that the compiled program lowers to, you can do this using the ahead-of-time lowering and compilation APIs. For your case it would look like this:

```python
print(jax.jit(model.log_prob).lower(xx).compile().as_text())
```

I'm not able to run this because there are a number of undefined names in the code you shared, but printing this should show you the actual operations used in the compiled code.
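For a self-contained way to see the difference, here is a minimal sketch that drops the numpyro dependency; the function `f` below is a hypothetical stand-in for `log_prob` with the same statically-known condition. The traced jaxpr records both branches, while XLA's constant folding and dead-code elimination generally simplify a `select` whose predicate is a compile-time constant:

```python
import jax
from jax import numpy as jnp


def f(x):
    # The condition is built from Python constants, so it is a literal at
    # trace time -- yet jnp.where still evaluates both branch expressions,
    # which is why both show up in the jaxpr.
    cond = jnp.equal(2.0, -1.0)
    return jnp.where(cond, -jnp.log(x), 2.0 * jnp.log(x))


xx = jnp.linspace(0.1, 20.0, 100)

# Traced representation: records both branches.
print(jax.make_jaxpr(f)(xx))

# Ahead-of-time lowering and compilation: shows what actually runs
# after XLA's optimization passes.
compiled = jax.jit(f).lower(xx).compile()
print(compiled.as_text())
```

In the compiled text, the `select` over a constant predicate should be folded away, so only the operations of the branch that can actually be taken are expected to survive.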