How to know if a function does not depend on its arguments #17846
krzysztofrusek asked this question in Q&A (answered by patrick-kidger)
Hi,

```python
import jax
import jax.numpy as jnp

def constM(q):
    return jnp.asarray([10.])

def varM(q):
    return jnp.asarray([10.]) + q

def mcmc(q, M):
    # autodiff logic here to select the algorithm based on M
    return q

q = jnp.asarray([1., 2., 3.])
jax.jit(mcmc, static_argnums=1)(q, constM)
jax.jit(mcmc, static_argnums=1)(q, varM)
```

TL;DR: inside a compiled function, I want to know whether a function does not depend on its arguments (i.e. its Jacobian is 0), so I can use a different algorithm or reuse an already computed value.
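To make the "Jacobian is 0" framing concrete, here is a small sketch (not from the original thread) that applies `jax.jacfwd` to the two example functions: `constM` has an all-zero Jacobian of shape (1, 3), and `varM` has the 3 by 3 identity. The catch is that these are array values computed at run time. Inside `jit` they are tracers, so they cannot drive a Python-level choice of algorithm; that choice needs a check on the function's structure made at trace time.

```python
import jax
import jax.numpy as jnp


def constM(q):
    return jnp.asarray([10.])


def varM(q):
    return jnp.asarray([10.]) + q


q = jnp.asarray([1., 2., 3.])

# Forward-mode Jacobians: rows index outputs, columns index entries of q.
J_const = jax.jacfwd(constM)(q)  # shape (1, 3), all zeros
J_var = jax.jacfwd(varM)(q)      # shape (3, 3), the identity

# These are array values; under jit they would be tracers, so a Python
# `if` on them cannot select an algorithm at trace time.
print(J_const)
print(J_var)
```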
Answered by patrick-kidger on Sep 29, 2023
You can use `jax.interpreters.partial_eval.dce_jaxpr` for this:

```python
import jax
import jax.interpreters.partial_eval as pe

def is_constant(f, *inputs):
    jaxpr = jax.make_jaxpr(f)(*inputs)
    # Dead-code elimination with every output marked as used:
    # used_inputs[i] is False when input i contributes to no output.
    _, used_inputs = pe.dce_jaxpr(jaxpr.jaxpr, [True] * len(jaxpr.out_avals))
    return all(not x for x in used_inputs)

def f(x, y):
    return 5

def g(x, y):
    return 5 + y

def h(x, y):
    return x**2 + y

print(is_constant(f, 1, 2))  # True
print(is_constant(g, 1, 2))  # False
print(is_constant(h, 1, 2))  # False
```

Fairly sure that `dce_jaxpr` is only semi-public API, but in practice I use it for operations like this as well.

Answer selected by krzysztofrusek.
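A minimal sketch of wiring this check into the `mcmc` function from the question, assuming the installed JAX version still exposes `dce_jaxpr` under `jax.interpreters.partial_eval` (it is semi-public and could move). Because `M` is a static argument, `is_constant(M, ...)` returns a plain Python `bool` while `mcmc` is being traced, so an ordinary `if` can pick the code path; the returned `branch` flag is a hypothetical stand-in for the real algorithm choice. Passing a `jax.ShapeDtypeStruct` rather than the tracer `q` is a choice made here, not something from the thread. Note also that the check is structural rather than numerical: `lambda x: x * 0.0` is reported as depending on `x` even though its Jacobian is zero.

```python
import jax
import jax.numpy as jnp
import jax.interpreters.partial_eval as pe  # semi-public; location may change between JAX versions


def is_constant(f, *inputs):
    # Trace f, dead-code-eliminate with all outputs marked as used,
    # and report whether no input survives.
    jaxpr = jax.make_jaxpr(f)(*inputs)
    _, used_inputs = pe.dce_jaxpr(jaxpr.jaxpr, [True] * len(jaxpr.out_avals))
    return not any(used_inputs)


def constM(q):
    return jnp.asarray([10.])


def varM(q):
    return jnp.asarray([10.]) + q


def mcmc(q, M):
    # M is static, so this check runs once per trace and yields a Python bool.
    # The ShapeDtypeStruct carries q's shape/dtype without passing the tracer.
    spec = jax.ShapeDtypeStruct(q.shape, q.dtype)
    if is_constant(M, spec):
        # Hypothetical path for constant M: M(q) could be computed once and reused.
        branch = 0
    else:
        # Hypothetical path for q-dependent M, e.g. one that needs its Jacobian.
        branch = 1
    return q, jnp.asarray(branch)


run = jax.jit(mcmc, static_argnums=1)
q = jnp.asarray([1., 2., 3.])
print(run(q, constM)[1])
print(run(q, varM)[1])

# Structural check: x * 0.0 still "uses" x, although its Jacobian is 0.
print(is_constant(lambda x: x * 0.0, 1.0))
```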