737
737
738
738
# ## Common subexprssion evaluation
739
739
740
- @inline newsym (:: Type{T} ) where T = Sym {T} (gensym (" cse" ))
740
+ """
741
+ newsym!(state::CSEState, ::Type{T})
742
+
743
+ Generates new symbol of type `T` with unique name in `state`.
744
+ """
745
+ @inline function newsym! (state, :: Type{T} ) where T
746
+ name = " ##cse#$(state. varid[]) "
747
+ state. varid[] += 1
748
+ Sym {T} (Symbol (name))
749
+ end
741
750
742
751
"""
743
752
$(TYPEDSIGNATURES)
@@ -769,11 +778,15 @@ struct CSEState
769
778
A mapping of symbolic expression to the LHS in `sorted_exprs` that computes it.
770
779
"""
771
780
visited:: IdDict{Any, Any}
781
+ """
782
+ Integer counter, used to generate unique names for intermediate variables.
783
+ """
784
+ varid:: Ref{Int}
772
785
end
773
786
774
- CSEState () = CSEState (Union{Assignment, DestructuredArgs}[], IdDict ())
787
+ CSEState () = CSEState (Union{Assignment, DestructuredArgs}[], IdDict (), Ref ( 1 ) )
775
788
776
- Base. copy (x:: CSEState ) = CSEState (copy (x. sorted_exprs), copy (x. visited))
789
+ Base. copy (x:: CSEState ) = CSEState (copy (x. sorted_exprs), copy (x. visited), Ref (x . varid[]) )
777
790
778
791
"""
779
792
$(TYPEDSIGNATURES)
@@ -860,13 +873,13 @@ function cse!(expr::Symbolic, state::CSEState)
860
873
if arg isa Union{Tuple, AbstractArray}
861
874
if arg isa Tuple
862
875
new_arg = cse! (MakeTuple (arg), state)
863
- sym = newsym ( Tuple{symtype .(arg)... })
876
+ sym = newsym! (state, Tuple{symtype .(arg)... })
864
877
elseif issparse (arg)
865
878
new_arg = cse! (MakeSparseArray (arg), state)
866
- sym = newsym ( AbstractSparseArray{symtype (eltype (arg)), indextype (arg), ndims (arg)})
879
+ sym = newsym! (state, AbstractSparseArray{symtype (eltype (arg)), indextype (arg), ndims (arg)})
867
880
else
868
881
new_arg = cse! (MakeArray (arg, typeof (arg)), state)
869
- sym = newsym ( AbstractArray{symtype (eltype (arg)), ndims (arg)})
882
+ sym = newsym! (state, AbstractArray{symtype (eltype (arg)), ndims (arg)})
870
883
end
871
884
push! (state. sorted_exprs, sym ← new_arg)
872
885
state. visited[arg] = sym
@@ -877,7 +890,7 @@ function cse!(expr::Symbolic, state::CSEState)
877
890
# use `term` instead of `maketerm` because we only care about the operation being performed
878
891
# and not the representation. This avoids issues with `newsym` symbols not having sizes, etc.
879
892
new_expr = term (operation (expr), args... ; type = symtype (expr))
880
- sym = newsym ( symtype (new_expr))
893
+ sym = newsym! (state, symtype (new_expr))
881
894
push! (state. sorted_exprs, sym ← new_expr)
882
895
return sym
883
896
end
0 commit comments