@@ -708,7 +708,7 @@ Update the system equations, unknowns, and observables after simplification.
708
708
"""
709
709
function update_simplified_system! (
710
710
state:: TearingState , neweqs, solved_eqs, dummy_sub, var_eq_matching, extra_unknowns;
711
- cse_hack = true , array_hack = true )
711
+ array_hack = true )
712
712
@unpack solvable_graph, var_to_diff, eq_to_diff, graph = state. structure
713
713
diff_to_var = invview (var_to_diff)
714
714
@@ -732,8 +732,8 @@ function update_simplified_system!(
732
732
unknowns = [unknowns; extra_unknowns]
733
733
@set! sys. unknowns = unknowns
734
734
735
- obs, subeqs, deps = cse_and_array_hacks (
736
- sys, obs, solved_eqs, unknowns, neweqs; cse = cse_hack, array = array_hack)
735
+ obs, subeqs, deps = array_var_hack (
736
+ sys, obs, solved_eqs, unknowns, neweqs; array = array_hack)
737
737
738
738
@set! sys. eqs = neweqs
739
739
@set! sys. observed = obs
@@ -775,7 +775,7 @@ appear in the system. Algebraic variables are variables that are not
775
775
differential variables.
776
776
"""
777
777
function tearing_reassemble (state:: TearingState , var_eq_matching,
778
- full_var_eq_matching = nothing ; simplify = false , mm = nothing , cse_hack = true , array_hack = true )
778
+ full_var_eq_matching = nothing ; simplify = false , mm = nothing , array_hack = true )
779
779
extra_vars = Int[]
780
780
if full_var_eq_matching != = nothing
781
781
for v in 𝑑vertices (state. structure. graph)
@@ -811,21 +811,14 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
811
811
state, var_eq_matching, eq_ordering, var_ordering, nelim_eq, nelim_var)
812
812
813
813
sys = update_simplified_system! (state, neweqs, solved_eqs, dummy_sub, var_eq_matching,
814
- extra_unknowns; cse_hack, array_hack)
814
+ extra_unknowns; array_hack)
815
815
816
816
@set! state. sys = sys
817
817
@set! sys. tearing_state = state
818
818
return invalidate_cache! (sys)
819
819
end
820
820
821
821
"""
822
- # HACK 1
823
-
824
- Since we don't support array equations, any equation of the sort `x[1:n] ~ f(...)[1:n]`
825
- gets turned into `x[1] ~ f(...)[1], x[2] ~ f(...)[2]`. Repeatedly calling `f` gets
826
- _very_ expensive. this hack performs a limited form of CSE specifically for this case to
827
- avoid the unnecessary cost. This and the below hack are implemented simultaneously
828
-
829
822
# HACK 2
830
823
831
824
Add equations for array observed variables. If `p[i] ~ (...)` are equations, add an
@@ -834,12 +827,7 @@ if all `p[i]` are present and the unscalarized form is used in any equation (obs
834
827
not) we first count the number of times the scalarized form of each observed variable
835
828
occurs in observed equations (and unknowns if it's split).
836
829
"""
837
- function cse_and_array_hacks (sys, obs, subeqs, unknowns, neweqs; cse = true , array = true )
838
- # HACK 1
839
- # mapping of rhs to temporary CSE variable
840
- # `f(...) => tmpvar` in above example
841
- rhs_to_tempvar = Dict ()
842
-
830
+ function array_var_hack (sys, obs, subeqs, unknowns, neweqs; array = true )
843
831
# HACK 2
844
832
# map of array observed variable (unscalarized) to number of its
845
833
# scalarized terms that appear in observed equations
@@ -851,36 +839,6 @@ function cse_and_array_hacks(sys, obs, subeqs, unknowns, neweqs; cse = true, arr
851
839
rhs = eq. rhs
852
840
vars! (all_vars, rhs)
853
841
854
- # HACK 1
855
- if cse && is_getindexed_array (rhs)
856
- rhs_arr = arguments (rhs)[1 ]
857
- iscall (rhs_arr) && operation (rhs_arr) isa Symbolics. Operator && continue
858
- if ! haskey (rhs_to_tempvar, rhs_arr)
859
- tempvar = gensym (Symbol (lhs))
860
- N = length (rhs_arr)
861
- tempvar = unwrap (Symbolics. variable (
862
- tempvar; T = Symbolics. symtype (rhs_arr)))
863
- tempvar = setmetadata (
864
- tempvar, Symbolics. ArrayShapeCtx, Symbolics. shape (rhs_arr))
865
- tempeq = tempvar ~ rhs_arr
866
- rhs_to_tempvar[rhs_arr] = tempvar
867
- push! (obs, tempeq)
868
- push! (subeqs, tempeq)
869
- end
870
-
871
- # getindex_wrapper is used because `observed2graph` treats `x` and `x[i]` as different,
872
- # so it doesn't find a dependency between this equation and `tempvar ~ rhs_arr`
873
- # which fails the topological sort
874
- neweq = lhs ~ getindex_wrapper (
875
- rhs_to_tempvar[rhs_arr], Tuple (arguments (rhs)[2 : end ]))
876
- obs[i] = neweq
877
- subeqi = findfirst (isequal (eq), subeqs)
878
- if subeqi != = nothing
879
- subeqs[subeqi] = neweq
880
- end
881
- end
882
- # end HACK 1
883
-
884
842
array || continue
885
843
iscall (lhs) || continue
886
844
operation (lhs) === getindex || continue
@@ -891,33 +849,6 @@ function cse_and_array_hacks(sys, obs, subeqs, unknowns, neweqs; cse = true, arr
891
849
continue
892
850
end
893
851
894
- # Also do CSE for `equations(sys)`
895
- if cse
896
- for (i, eq) in enumerate (neweqs)
897
- (; lhs, rhs) = eq
898
- is_getindexed_array (rhs) || continue
899
- rhs_arr = arguments (rhs)[1 ]
900
- if ! haskey (rhs_to_tempvar, rhs_arr)
901
- tempvar = gensym (Symbol (lhs))
902
- N = length (rhs_arr)
903
- tempvar = unwrap (Symbolics. variable (
904
- tempvar; T = Symbolics. symtype (rhs_arr)))
905
- tempvar = setmetadata (
906
- tempvar, Symbolics. ArrayShapeCtx, Symbolics. shape (rhs_arr))
907
- vars! (all_vars, rhs_arr)
908
- tempeq = tempvar ~ rhs_arr
909
- rhs_to_tempvar[rhs_arr] = tempvar
910
- push! (obs, tempeq)
911
- push! (subeqs, tempeq)
912
- end
913
- # don't need getindex_wrapper, but do it anyway to know that this
914
- # hack took place
915
- neweq = lhs ~ getindex_wrapper (
916
- rhs_to_tempvar[rhs_arr], Tuple (arguments (rhs)[2 : end ]))
917
- neweqs[i] = neweq
918
- end
919
- end
920
-
921
852
# count variables in unknowns if they are scalarized forms of variables
922
853
# also present as observed. e.g. if `x[1]` is an unknown and `x[2] ~ (..)`
923
854
# is an observed equation.
@@ -1007,10 +938,10 @@ new residual equations after tearing. End users are encouraged to call [`structu
1007
938
instead, which calls this function internally.
1008
939
"""
1009
940
function tearing (sys:: AbstractSystem , state = TearingState (sys); mm = nothing ,
1010
- simplify = false , cse_hack = true , array_hack = true , kwargs... )
941
+ simplify = false , array_hack = true , kwargs... )
1011
942
var_eq_matching, full_var_eq_matching = tearing (state)
1012
943
invalidate_cache! (tearing_reassemble (
1013
- state, var_eq_matching, full_var_eq_matching; mm, simplify, cse_hack, array_hack))
944
+ state, var_eq_matching, full_var_eq_matching; mm, simplify, array_hack))
1014
945
end
1015
946
1016
947
"""
@@ -1032,7 +963,7 @@ Perform index reduction and use the dummy derivative technique to ensure that
1032
963
the system is balanced.
1033
964
"""
1034
965
function dummy_derivative (sys, state = TearingState (sys); simplify = false ,
1035
- mm = nothing , cse_hack = true , array_hack = true , kwargs... )
966
+ mm = nothing , array_hack = true , kwargs... )
1036
967
jac = let state = state
1037
968
(eqs, vars) -> begin
1038
969
symeqs = EquationsView (state)[eqs]
@@ -1056,5 +987,5 @@ function dummy_derivative(sys, state = TearingState(sys); simplify = false,
1056
987
end
1057
988
var_eq_matching = dummy_derivative_graph! (state, jac; state_priority,
1058
989
kwargs... )
1059
- tearing_reassemble (state, var_eq_matching; simplify, mm, cse_hack, array_hack)
990
+ tearing_reassemble (state, var_eq_matching; simplify, mm, array_hack)
1060
991
end
0 commit comments