How to know if a function does not depend on its arguments #17846

You can use jax.interpreters.partial_eval.dce_jaxpr for this:

import jax
import jax.interpreters.partial_eval as pe

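def is_constant(f, *inputs):
    # Trace f to a jaxpr using the example inputs.
    jaxpr = jax.make_jaxpr(f)(*inputs)
    # Dead-code-eliminate while keeping every output; the second return
    # value holds one bool per jaxpr input, True if that input is still needed.
    _, used_inputs = pe.dce_jaxpr(jaxpr.jaxpr, [True] * len(jaxpr.out_avals))
    # f is constant if no input survives dead-code elimination.
    return not any(used_inputs)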

def f(x, y):
    return 5

def g(x, y):
    return 5 + y

def h(x, y):
    return x**2 + y

print(is_constant(f, 1, 2))  # True
print(is_constant(g, 1, 2))  # False
print(is_constant(h, 1, 2))  # False

I'm fairly sure that dce_jaxpr is only semi-public API, so it may change without notice, but in practice I use it for operations like this as well.
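
If you want to know which arguments are unused rather than only whether all of them are, the same call can be reused, since its second return value already has one flag per jaxpr input. The following is a minimal sketch, not an official recipe: `used_arguments` is a hypothetical helper name, and it assumes each positional argument is a single scalar or array (a pytree argument would flatten into several jaxpr inputs). Note also that the check is structural: an expression like `x * 0.0` still counts as depending on `x`, because tracing records the multiplication.

```python
import jax
import jax.interpreters.partial_eval as pe


def used_arguments(f, *inputs):
    """Return one bool per jaxpr input: True if some output depends on it.

    Hypothetical helper; assumes every argument is a single scalar or array.
    """
    closed = jax.make_jaxpr(f)(*inputs)
    # Keep all outputs, then let dead-code elimination report needed inputs.
    _, used = pe.dce_jaxpr(closed.jaxpr, [True] * len(closed.jaxpr.outvars))
    return list(used)


def h(x, y):
    unused = x * 3.0  # dead code: never reaches the output
    return y + 1.0


def k(x, y):
    return x * 0.0 + y  # numerically independent of x, structurally not


flags = used_arguments(h, 1.0, 2.0)
print(flags)                         # [False, True]
print(used_arguments(k, 1.0, 2.0))   # [True, True]
```

So this tells you whether an argument appears in the traced computation, not whether the output is mathematically independent of it.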

Answer selected by krzysztofrusek