Skip to content

Commit e682fa8

Browse files
committed
small simplification to asymptotic complexity of make_jaxpr_effects
1 parent 026f309 commit e682fa8

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

jax/_src/interpreters/partial_eval.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1727,7 +1727,7 @@ def get_referent(self):
17271727

17281728
def make_jaxpr_effects(constvars, invars, outvars, eqns) -> effects.Effects:
17291729
jaxpr_effects = set()
1730-
all_vars = [*constvars, *invars]
1730+
all_vars = {v: i for i, v in enumerate(it.chain(constvars, invars))}
17311731
for eqn in eqns:
17321732
for eff in eqn.effects:
17331733
if isinstance(eff, effects.JaxprInputEffect):
@@ -1738,14 +1738,14 @@ def make_jaxpr_effects(constvars, invars, outvars, eqns) -> effects.Effects:
17381738
"\n Jaxpr: "
17391739
f"{core.Jaxpr(constvars, invars, outvars, eqns, set())}")
17401740
invar = eqn.invars[eff.input_index]
1741-
if invar not in all_vars:
1741+
if (input_index := all_vars.get(invar)) is None:
17421742
raise ValueError(
17431743
f"`JaxprInputEffect` {eff} does not have "
17441744
f"corresponding input: {invar}."
17451745
f"\n Equation: {eqn}\n"
17461746
"\n Jaxpr: "
17471747
f"{core.Jaxpr(constvars, invars, outvars, eqns, set())}")
1748-
eff = eff.replace(input_index=all_vars.index(invar))
1748+
eff = eff.replace(input_index=input_index)
17491749
jaxpr_effects.add(eff)
17501750
return jaxpr_effects
17511751

0 commit comments

Comments
 (0)