737
737
738
738
# ## Common subexprssion evaluation
739
739
740
- @inline newsym (:: Type{T} ) where T = Sym {T} (gensym (" cse" ))
740
+ """
741
+ newsym!(state::CSEState, ::Type{T})
742
+
743
+ Generates new symbol of type `T` with unique name in `state`.
744
+ """
745
+ @inline function newsym! (state, :: Type{T} ) where T
746
+ name = " ##cse#$(state. varid[]) "
747
+ state. varid[] += 1
748
+ Sym {T} (Symbol (name))
749
+ end
741
750
742
751
"""
743
752
$(TYPEDSIGNATURES)
@@ -769,11 +778,16 @@ struct CSEState
769
778
A mapping of symbolic expression to the LHS in `sorted_exprs` that computes it.
770
779
"""
771
780
visited:: IdDict{Any, Any}
781
+ """
782
+ Integer counter, used to generate unique names for intermediate variables.
783
+ """
784
+ varid:: Ref{Int}
772
785
end
773
786
774
- CSEState () = CSEState (Union{Assignment, DestructuredArgs}[], IdDict ())
787
+ CSEState () = CSEState (Union{Assignment, DestructuredArgs}[], IdDict (), Ref ( 1 ) )
775
788
776
- Base. copy (x:: CSEState ) = CSEState (copy (x. sorted_exprs), copy (x. visited))
789
+ # the copy still references the same `varid` Ref to work in nested scopes
790
+ Base. copy (x:: CSEState ) = CSEState (copy (x. sorted_exprs), copy (x. visited), x. varid)
777
791
778
792
"""
779
793
$(TYPEDSIGNATURES)
@@ -861,13 +875,13 @@ function cse!(expr::Symbolic, state::CSEState)
861
875
(_is_array_of_symbolics (arg) || _is_tuple_of_symbolics (arg))
862
876
if arg isa Tuple
863
877
new_arg = cse! (MakeTuple (arg), state)
864
- sym = newsym ( Tuple{symtype .(arg)... })
878
+ sym = newsym! (state, Tuple{symtype .(arg)... })
865
879
elseif issparse (arg)
866
880
new_arg = cse! (MakeSparseArray (arg), state)
867
- sym = newsym ( AbstractSparseArray{symtype (eltype (arg)), indextype (arg), ndims (arg)})
881
+ sym = newsym! (state, AbstractSparseArray{symtype (eltype (arg)), indextype (arg), ndims (arg)})
868
882
else
869
883
new_arg = cse! (MakeArray (arg, typeof (arg)), state)
870
- sym = newsym ( AbstractArray{symtype (eltype (arg)), ndims (arg)})
884
+ sym = newsym! (state, AbstractArray{symtype (eltype (arg)), ndims (arg)})
871
885
end
872
886
push! (state. sorted_exprs, sym ← new_arg)
873
887
state. visited[arg] = sym
@@ -878,7 +892,7 @@ function cse!(expr::Symbolic, state::CSEState)
878
892
# use `term` instead of `maketerm` because we only care about the operation being performed
879
893
# and not the representation. This avoids issues with `newsym` symbols not having sizes, etc.
880
894
new_expr = term (operation (expr), args... ; type = symtype (expr))
881
- sym = newsym ( symtype (new_expr))
895
+ sym = newsym! (state, symtype (new_expr))
882
896
push! (state. sorted_exprs, sym ← new_expr)
883
897
return sym
884
898
end
0 commit comments