@@ -1640,8 +1640,8 @@ class DynamicJaxprTracer(core.Tracer):
1640
1640
def __init__ (self , trace : DynamicJaxprTrace ,
1641
1641
aval : core .AbstractValue | core .AvalQDD ,
1642
1642
val : Atom ,
1643
- parent : TracingEqn | None = None ,
1644
- line_info : source_info_util . SourceInfo | None = None ):
1643
+ line_info : source_info_util . SourceInfo | None = None ,
1644
+ parent : TracingEqn | None = None ):
1645
1645
# TODO(dougalm): Remove aval. It's redundant now that we have val.
1646
1646
if isinstance (aval , core .AvalQDD ):
1647
1647
assert aval .qdd is not None
@@ -1767,7 +1767,7 @@ class JaxprStackFrame:
1767
1767
gensym : Callable [[AbstractValue ], Var ]
1768
1768
constid_to_tracer : WeakValueDictionary [ConstId , DynamicJaxprTracer ]
1769
1769
constvar_to_val : dict [Var , Any ]
1770
- tracing_eqns : list [ReferenceType ( TracingEqn ) ]
1770
+ tracing_eqns : list [ReferenceType [ TracingEqn ] ]
1771
1771
invars : list [Var ]
1772
1772
effects : core .Effects
1773
1773
debug_info : core .DebugInfo
@@ -1944,6 +1944,7 @@ def __init__(self, debug_info: core.DebugInfo, parent_trace=None, lower=False,
1944
1944
self .frame = JaxprStackFrame (debug_info , auto_dce )
1945
1945
self .parent_trace = parent_trace
1946
1946
1947
+ # TODO(dougalm): we might be able to remove this since the refcounting should be doing it for us
1947
1948
def invalidate (self ):
1948
1949
# avoid cyclic refs
1949
1950
self .frame .constid_to_tracer = {}
@@ -1964,7 +1965,7 @@ def var_to_tracer(self, var, source_info, parent=None):
1964
1965
aval = var .aval
1965
1966
if aval .has_qdd :
1966
1967
aval = core .AvalQDD (aval , var .initial_qdd )
1967
- return DynamicJaxprTracer (self , aval , var , parent , source_info )
1968
+ return DynamicJaxprTracer (self , aval , var , source_info , parent )
1968
1969
1969
1970
def new_arg (self , aval , source_info : SourceInfo ):
1970
1971
var = self .frame .newvar (aval )
@@ -2022,7 +2023,7 @@ def _new_const(self, aval, c, source_info: SourceInfo) -> DynamicJaxprTracer:
2022
2023
return tracer
2023
2024
2024
2025
def finalize_const (self , var , constid ):
2025
- self .frame .constvar_to_val .pop (var )
2026
+ self .frame .constvar_to_val .pop (var , None )
2026
2027
2027
2028
def get_const (self , tracer ) -> Any :
2028
2029
atom = tracer .val
0 commit comments