203
203
mutable struct TearingState{T <: AbstractSystem } <: AbstractTearingState{T}
204
204
""" The system of equations."""
205
205
sys:: T
206
+ original_eqs:: Vector{Equation}
206
207
""" The set of variables of the system."""
207
208
fullvars:: Vector{BasicSymbolic}
208
209
structure:: SystemStructure
213
214
TransformationState (sys:: AbstractSystem ) = TearingState (sys)
214
215
function system_subset (ts:: TearingState , ieqs:: Vector{Int} )
215
216
eqs = equations (ts)
217
+ @set! ts. original_eqs = ts. original_eqs[ieqs]
216
218
@set! ts. sys. eqs = eqs[ieqs]
217
219
@set! ts. structure = system_subset (ts. structure, ieqs)
218
220
ts
@@ -274,8 +276,9 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
274
276
sys = process_parameter_equations (sys)
275
277
ivs = independent_variables (sys)
276
278
iv = length (ivs) == 1 ? ivs[1 ] : nothing
277
- # flatten array equations
278
- eqs = flatten_equations (equations (sys))
279
+ # scalarize array equations, without scalarizing arguments to registered functions
280
+ original_eqs = flatten_equations (copy (equations (sys)))
281
+ eqs = copy (original_eqs)
279
282
neqs = length (eqs)
280
283
param_derivative_map = Dict {BasicSymbolic, Any} ()
281
284
# * Scalarize unknowns
@@ -415,6 +418,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
415
418
end
416
419
end
417
420
eqs = eqs[eqs_to_retain]
421
+ original_eqs = original_eqs[eqs_to_retain]
418
422
neqs = length (eqs)
419
423
symbolic_incidence = symbolic_incidence[eqs_to_retain]
420
424
@@ -423,6 +427,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
423
427
# depending on order due to NP-completeness of tearing.
424
428
sortidxs = Base. sortperm (eqs, by = string)
425
429
eqs = eqs[sortidxs]
430
+ original_eqs = original_eqs[sortidxs]
426
431
symbolic_incidence = symbolic_incidence[sortidxs]
427
432
end
428
433
@@ -513,7 +518,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
513
518
514
519
eq_to_diff = DiffGraph (nsrcs (graph))
515
520
516
- ts = TearingState (sys, fullvars,
521
+ ts = TearingState (sys, original_eqs, fullvars,
517
522
SystemStructure (complete (var_to_diff), complete (eq_to_diff),
518
523
complete (graph), nothing , var_types, false ),
519
524
Any[], param_derivative_map)
@@ -696,6 +701,22 @@ function Base.show(io::IO, mime::MIME"text/plain", ms::MatchedSystemStructure)
696
701
printstyled (io, " SelectedState" )
697
702
end
698
703
704
+ function make_eqs_zero_equals! (ts:: TearingState )
705
+ neweqs = map (enumerate (get_eqs (ts. sys))) do kvp
706
+ i, eq = kvp
707
+ isalgeq = true
708
+ for j in 𝑠neighbors (ts. structure. graph, i)
709
+ isalgeq &= invview (ts. structure. var_to_diff)[j] === nothing
710
+ end
711
+ if isalgeq
712
+ return 0 ~ eq. rhs - eq. lhs
713
+ else
714
+ return eq
715
+ end
716
+ end
717
+ copyto! (get_eqs (ts. sys), neweqs)
718
+ end
719
+
699
720
function mtkcompile! (state:: TearingState ; simplify = false ,
700
721
check_consistency = true , fully_determined = true , warn_initialize_determined = true ,
701
722
inputs = Any[], outputs = Any[],
@@ -722,6 +743,7 @@ function mtkcompile!(state::TearingState; simplify = false,
722
743
""" ))
723
744
end
724
745
if length (tss) > 1
746
+ make_eqs_zero_equals! (tss[continuous_id])
725
747
# simplify as normal
726
748
sys = _mtkcompile! (tss[continuous_id]; simplify,
727
749
inputs = [inputs; clocked_inputs[continuous_id]], outputs, disturbance_inputs,
0 commit comments