Skip to content

Commit 12ca5e8

Browse files
authored
Merge pull request #731 from hexaeder/hw/cse_varnames
name CSE states in a reproducible way
2 parents 5bfae77 + 100601e commit 12ca5e8

File tree

1 file changed

+21
-7
lines changed

1 file changed

+21
-7
lines changed

src/code.jl

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -737,7 +737,16 @@ end
737737

738738
### Common subexprssion evaluation
739739

740-
@inline newsym(::Type{T}) where T = Sym{T}(gensym("cse"))
740+
"""
741+
newsym!(state::CSEState, ::Type{T})
742+
743+
Generates new symbol of type `T` with unique name in `state`.
744+
"""
745+
@inline function newsym!(state, ::Type{T}) where T
746+
name = "##cse#$(state.varid[])"
747+
state.varid[] += 1
748+
Sym{T}(Symbol(name))
749+
end
741750

742751
"""
743752
$(TYPEDSIGNATURES)
@@ -769,11 +778,16 @@ struct CSEState
769778
A mapping of symbolic expression to the LHS in `sorted_exprs` that computes it.
770779
"""
771780
visited::IdDict{Any, Any}
781+
"""
782+
Integer counter, used to generate unique names for intermediate variables.
783+
"""
784+
varid::Ref{Int}
772785
end
773786

774-
CSEState() = CSEState(Union{Assignment, DestructuredArgs}[], IdDict())
787+
CSEState() = CSEState(Union{Assignment, DestructuredArgs}[], IdDict(), Ref(1))
775788

776-
Base.copy(x::CSEState) = CSEState(copy(x.sorted_exprs), copy(x.visited))
789+
# the copy still references the same `varid` Ref to work in nested scopes
790+
Base.copy(x::CSEState) = CSEState(copy(x.sorted_exprs), copy(x.visited), x.varid)
777791

778792
"""
779793
$(TYPEDSIGNATURES)
@@ -861,13 +875,13 @@ function cse!(expr::Symbolic, state::CSEState)
861875
(_is_array_of_symbolics(arg) || _is_tuple_of_symbolics(arg))
862876
if arg isa Tuple
863877
new_arg = cse!(MakeTuple(arg), state)
864-
sym = newsym(Tuple{symtype.(arg)...})
878+
sym = newsym!(state, Tuple{symtype.(arg)...})
865879
elseif issparse(arg)
866880
new_arg = cse!(MakeSparseArray(arg), state)
867-
sym = newsym(AbstractSparseArray{symtype(eltype(arg)), indextype(arg), ndims(arg)})
881+
sym = newsym!(state, AbstractSparseArray{symtype(eltype(arg)), indextype(arg), ndims(arg)})
868882
else
869883
new_arg = cse!(MakeArray(arg, typeof(arg)), state)
870-
sym = newsym(AbstractArray{symtype(eltype(arg)), ndims(arg)})
884+
sym = newsym!(state, AbstractArray{symtype(eltype(arg)), ndims(arg)})
871885
end
872886
push!(state.sorted_exprs, sym new_arg)
873887
state.visited[arg] = sym
@@ -878,7 +892,7 @@ function cse!(expr::Symbolic, state::CSEState)
878892
# use `term` instead of `maketerm` because we only care about the operation being performed
879893
# and not the representation. This avoids issues with `newsym` symbols not having sizes, etc.
880894
new_expr = term(operation(expr), args...; type = symtype(expr))
881-
sym = newsym(symtype(new_expr))
895+
sym = newsym!(state, symtype(new_expr))
882896
push!(state.sorted_exprs, sym new_expr)
883897
return sym
884898
end

0 commit comments

Comments
 (0)