Skip to content

Commit 449d084

Browse files
committed
small fixes
1 parent d7a28fa commit 449d084

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

jax/_src/interpreters/partial_eval.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1640,8 +1640,8 @@ class DynamicJaxprTracer(core.Tracer):
16401640
def __init__(self, trace: DynamicJaxprTrace,
16411641
aval: core.AbstractValue | core.AvalQDD,
16421642
val : Atom,
1643-
parent : TracingEqn | None = None,
1644-
line_info: source_info_util.SourceInfo | None = None):
1643+
line_info: source_info_util.SourceInfo | None = None,
1644+
parent : TracingEqn | None = None):
16451645
# TODO(dougalm): Remove aval. It's redundant now that we have val.
16461646
if isinstance(aval, core.AvalQDD):
16471647
assert aval.qdd is not None
@@ -1767,7 +1767,7 @@ class JaxprStackFrame:
17671767
gensym: Callable[[AbstractValue], Var]
17681768
constid_to_tracer: WeakValueDictionary[ConstId, DynamicJaxprTracer]
17691769
constvar_to_val: dict[Var, Any]
1770-
tracing_eqns: list[ReferenceType(TracingEqn)]
1770+
tracing_eqns: list[ReferenceType[TracingEqn]]
17711771
invars: list[Var]
17721772
effects: core.Effects
17731773
debug_info: core.DebugInfo
@@ -1944,6 +1944,7 @@ def __init__(self, debug_info: core.DebugInfo, parent_trace=None, lower=False,
19441944
self.frame = JaxprStackFrame(debug_info, auto_dce)
19451945
self.parent_trace = parent_trace
19461946

1947+
# TODO(dougalm): we might be able to remove this since the refcounting should be doing it for us
19471948
def invalidate(self):
19481949
# avoid cyclic refs
19491950
self.frame.constid_to_tracer = {}
@@ -1964,7 +1965,7 @@ def var_to_tracer(self, var, source_info, parent=None):
19641965
aval = var.aval
19651966
if aval.has_qdd:
19661967
aval = core.AvalQDD(aval, var.initial_qdd)
1967-
return DynamicJaxprTracer(self, aval, var, parent, source_info)
1968+
return DynamicJaxprTracer(self, aval, var, source_info, parent)
19681969

19691970
def new_arg(self, aval, source_info: SourceInfo):
19701971
var = self.frame.newvar(aval)
@@ -2022,7 +2023,7 @@ def _new_const(self, aval, c, source_info: SourceInfo) -> DynamicJaxprTracer:
20222023
return tracer
20232024

20242025
def finalize_const(self, var, constid):
2025-
self.frame.constvar_to_val.pop(var)
2026+
self.frame.constvar_to_val.pop(var, None)
20262027

20272028
def get_const(self, tracer) -> Any:
20282029
atom = tracer.val

tests/api_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5106,6 +5106,7 @@ def g():
51065106

51075107

51085108
def test_implicit_dce(self):
5109+
@api.jit
51095110
def foo(x):
51105111
const = np.zeros((300,))
51115112
r = weakref.ref(const)
@@ -5114,7 +5115,7 @@ def foo(x):
51145115
assert r() is None, "oops, the constant wasn't DCE'd"
51155116
return x + x
51165117

5117-
jax.make_jaxpr(foo, dce=True)(1.0)
5118+
foo(1.0)
51185119

51195120
class RematTest(jtu.JaxTestCase):
51205121

0 commit comments

Comments
 (0)