Skip to content

Commit c7e923d

Browse files
fix: remove CSE hack
1 parent b20fc7f commit c7e923d

File tree

2 files changed

+12
-81
lines changed

2 files changed

+12
-81
lines changed

src/structural_transformation/symbolics_tearing.jl

Lines changed: 10 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -708,7 +708,7 @@ Update the system equations, unknowns, and observables after simplification.
708708
"""
709709
function update_simplified_system!(
710710
state::TearingState, neweqs, solved_eqs, dummy_sub, var_eq_matching, extra_unknowns;
711-
cse_hack = true, array_hack = true)
711+
array_hack = true)
712712
@unpack solvable_graph, var_to_diff, eq_to_diff, graph = state.structure
713713
diff_to_var = invview(var_to_diff)
714714

@@ -732,8 +732,8 @@ function update_simplified_system!(
732732
unknowns = [unknowns; extra_unknowns]
733733
@set! sys.unknowns = unknowns
734734

735-
obs, subeqs, deps = cse_and_array_hacks(
736-
sys, obs, solved_eqs, unknowns, neweqs; cse = cse_hack, array = array_hack)
735+
obs, subeqs, deps = array_var_hack(
736+
sys, obs, solved_eqs, unknowns, neweqs; array = array_hack)
737737

738738
@set! sys.eqs = neweqs
739739
@set! sys.observed = obs
@@ -775,7 +775,7 @@ appear in the system. Algebraic variables are variables that are not
775775
differential variables.
776776
"""
777777
function tearing_reassemble(state::TearingState, var_eq_matching,
778-
full_var_eq_matching = nothing; simplify = false, mm = nothing, cse_hack = true, array_hack = true)
778+
full_var_eq_matching = nothing; simplify = false, mm = nothing, array_hack = true)
779779
extra_vars = Int[]
780780
if full_var_eq_matching !== nothing
781781
for v in 𝑑vertices(state.structure.graph)
@@ -811,21 +811,14 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
811811
state, var_eq_matching, eq_ordering, var_ordering, nelim_eq, nelim_var)
812812

813813
sys = update_simplified_system!(state, neweqs, solved_eqs, dummy_sub, var_eq_matching,
814-
extra_unknowns; cse_hack, array_hack)
814+
extra_unknowns; array_hack)
815815

816816
@set! state.sys = sys
817817
@set! sys.tearing_state = state
818818
return invalidate_cache!(sys)
819819
end
820820

821821
"""
822-
# HACK 1
823-
824-
Since we don't support array equations, any equation of the sort `x[1:n] ~ f(...)[1:n]`
825-
gets turned into `x[1] ~ f(...)[1], x[2] ~ f(...)[2]`. Repeatedly calling `f` gets
826-
_very_ expensive. this hack performs a limited form of CSE specifically for this case to
827-
avoid the unnecessary cost. This and the below hack are implemented simultaneously
828-
829822
# HACK 2
830823
831824
Add equations for array observed variables. If `p[i] ~ (...)` are equations, add an
@@ -834,12 +827,7 @@ if all `p[i]` are present and the unscalarized form is used in any equation (obs
834827
not) we first count the number of times the scalarized form of each observed variable
835828
occurs in observed equations (and unknowns if it's split).
836829
"""
837-
function cse_and_array_hacks(sys, obs, subeqs, unknowns, neweqs; cse = true, array = true)
838-
# HACK 1
839-
# mapping of rhs to temporary CSE variable
840-
# `f(...) => tmpvar` in above example
841-
rhs_to_tempvar = Dict()
842-
830+
function array_var_hack(sys, obs, subeqs, unknowns, neweqs; array = true)
843831
# HACK 2
844832
# map of array observed variable (unscalarized) to number of its
845833
# scalarized terms that appear in observed equations
@@ -851,36 +839,6 @@ function cse_and_array_hacks(sys, obs, subeqs, unknowns, neweqs; cse = true, arr
851839
rhs = eq.rhs
852840
vars!(all_vars, rhs)
853841

854-
# HACK 1
855-
if cse && is_getindexed_array(rhs)
856-
rhs_arr = arguments(rhs)[1]
857-
iscall(rhs_arr) && operation(rhs_arr) isa Symbolics.Operator && continue
858-
if !haskey(rhs_to_tempvar, rhs_arr)
859-
tempvar = gensym(Symbol(lhs))
860-
N = length(rhs_arr)
861-
tempvar = unwrap(Symbolics.variable(
862-
tempvar; T = Symbolics.symtype(rhs_arr)))
863-
tempvar = setmetadata(
864-
tempvar, Symbolics.ArrayShapeCtx, Symbolics.shape(rhs_arr))
865-
tempeq = tempvar ~ rhs_arr
866-
rhs_to_tempvar[rhs_arr] = tempvar
867-
push!(obs, tempeq)
868-
push!(subeqs, tempeq)
869-
end
870-
871-
# getindex_wrapper is used because `observed2graph` treats `x` and `x[i]` as different,
872-
# so it doesn't find a dependency between this equation and `tempvar ~ rhs_arr`
873-
# which fails the topological sort
874-
neweq = lhs ~ getindex_wrapper(
875-
rhs_to_tempvar[rhs_arr], Tuple(arguments(rhs)[2:end]))
876-
obs[i] = neweq
877-
subeqi = findfirst(isequal(eq), subeqs)
878-
if subeqi !== nothing
879-
subeqs[subeqi] = neweq
880-
end
881-
end
882-
# end HACK 1
883-
884842
array || continue
885843
iscall(lhs) || continue
886844
operation(lhs) === getindex || continue
@@ -891,33 +849,6 @@ function cse_and_array_hacks(sys, obs, subeqs, unknowns, neweqs; cse = true, arr
891849
continue
892850
end
893851

894-
# Also do CSE for `equations(sys)`
895-
if cse
896-
for (i, eq) in enumerate(neweqs)
897-
(; lhs, rhs) = eq
898-
is_getindexed_array(rhs) || continue
899-
rhs_arr = arguments(rhs)[1]
900-
if !haskey(rhs_to_tempvar, rhs_arr)
901-
tempvar = gensym(Symbol(lhs))
902-
N = length(rhs_arr)
903-
tempvar = unwrap(Symbolics.variable(
904-
tempvar; T = Symbolics.symtype(rhs_arr)))
905-
tempvar = setmetadata(
906-
tempvar, Symbolics.ArrayShapeCtx, Symbolics.shape(rhs_arr))
907-
vars!(all_vars, rhs_arr)
908-
tempeq = tempvar ~ rhs_arr
909-
rhs_to_tempvar[rhs_arr] = tempvar
910-
push!(obs, tempeq)
911-
push!(subeqs, tempeq)
912-
end
913-
# don't need getindex_wrapper, but do it anyway to know that this
914-
# hack took place
915-
neweq = lhs ~ getindex_wrapper(
916-
rhs_to_tempvar[rhs_arr], Tuple(arguments(rhs)[2:end]))
917-
neweqs[i] = neweq
918-
end
919-
end
920-
921852
# count variables in unknowns if they are scalarized forms of variables
922853
# also present as observed. e.g. if `x[1]` is an unknown and `x[2] ~ (..)`
923854
# is an observed equation.
@@ -1007,10 +938,10 @@ new residual equations after tearing. End users are encouraged to call [`structu
1007938
instead, which calls this function internally.
1008939
"""
1009940
function tearing(sys::AbstractSystem, state = TearingState(sys); mm = nothing,
1010-
simplify = false, cse_hack = true, array_hack = true, kwargs...)
941+
simplify = false, array_hack = true, kwargs...)
1011942
var_eq_matching, full_var_eq_matching = tearing(state)
1012943
invalidate_cache!(tearing_reassemble(
1013-
state, var_eq_matching, full_var_eq_matching; mm, simplify, cse_hack, array_hack))
944+
state, var_eq_matching, full_var_eq_matching; mm, simplify, array_hack))
1014945
end
1015946

1016947
"""
@@ -1032,7 +963,7 @@ Perform index reduction and use the dummy derivative technique to ensure that
1032963
the system is balanced.
1033964
"""
1034965
function dummy_derivative(sys, state = TearingState(sys); simplify = false,
1035-
mm = nothing, cse_hack = true, array_hack = true, kwargs...)
966+
mm = nothing, array_hack = true, kwargs...)
1036967
jac = let state = state
1037968
(eqs, vars) -> begin
1038969
symeqs = EquationsView(state)[eqs]
@@ -1056,5 +987,5 @@ function dummy_derivative(sys, state = TearingState(sys); simplify = false,
1056987
end
1057988
var_eq_matching = dummy_derivative_graph!(state, jac; state_priority,
1058989
kwargs...)
1059-
tearing_reassemble(state, var_eq_matching; simplify, mm, cse_hack, array_hack)
990+
tearing_reassemble(state, var_eq_matching; simplify, mm, array_hack)
1060991
end

test/structural_transformation/utils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ end
121121
@named sys = ODESystem(
122122
[D(x) ~ z[1] + z[2] + foo(z)[1], y[1] ~ 2t, y[2] ~ 3t, z ~ foo(y)], t)
123123

124-
sys1 = structural_simplify(sys; cse_hack = false)
124+
sys1 = structural_simplify(sys)
125125
@test length(observed(sys1)) == 6
126126
@test !any(observed(sys1)) do eq
127127
iscall(eq.rhs) &&
@@ -142,7 +142,7 @@ end
142142
@named sys = ODESystem(
143143
[D(x) ~ z[1] + z[2] + foo(z)[1] + w, y[1] ~ 2t, y[2] ~ 3t, z ~ foo(y)], t)
144144

145-
sys1 = structural_simplify(sys; cse_hack = false, fully_determined = false)
145+
sys1 = structural_simplify(sys; fully_determined = false)
146146
@test length(observed(sys1)) == 6
147147
@test !any(observed(sys1)) do eq
148148
iscall(eq.rhs) &&

0 commit comments

Comments
 (0)