Skip to content

Commit 2022245

Browse files
feat: retain original equations of the system in TearingState
1 parent 56ca2b3 commit 2022245

File tree

2 files changed

+27
-5
lines changed

2 files changed

+27
-5
lines changed

src/systems/systems.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,6 @@ function __mtkcompile(sys::AbstractSystem; simplify = false,
8080

8181
@unpack structure, fullvars = state
8282
@unpack graph, var_to_diff, var_types = structure
83-
eqs = equations(state)
8483
brown_vars = Int[]
8584
new_idxs = zeros(Int, length(var_types))
8685
idx = 0
@@ -98,7 +97,8 @@ function __mtkcompile(sys::AbstractSystem; simplify = false,
9897
Is = Int[]
9998
Js = Int[]
10099
vals = Num[]
101-
new_eqs = copy(eqs)
100+
make_eqs_zero_equals!(state)
101+
new_eqs = copy(equations(state))
102102
dvar2eq = Dict{Any, Int}()
103103
for (v, dv) in enumerate(var_to_diff)
104104
dv === nothing && continue

src/systems/systemstructure.jl

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,7 @@ end
203203
mutable struct TearingState{T <: AbstractSystem} <: AbstractTearingState{T}
204204
"""The system of equations."""
205205
sys::T
206+
original_eqs::Vector{Equation}
206207
"""The set of variables of the system."""
207208
fullvars::Vector{BasicSymbolic}
208209
structure::SystemStructure
@@ -213,6 +214,7 @@ end
213214
TransformationState(sys::AbstractSystem) = TearingState(sys)
214215
function system_subset(ts::TearingState, ieqs::Vector{Int})
215216
eqs = equations(ts)
217+
@set! ts.original_eqs = ts.original_eqs[ieqs]
216218
@set! ts.sys.eqs = eqs[ieqs]
217219
@set! ts.structure = system_subset(ts.structure, ieqs)
218220
ts
@@ -274,8 +276,9 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
274276
sys = process_parameter_equations(sys)
275277
ivs = independent_variables(sys)
276278
iv = length(ivs) == 1 ? ivs[1] : nothing
277-
# flatten array equations
278-
eqs = flatten_equations(equations(sys))
279+
# scalarize array equations, without scalarizing arguments to registered functions
280+
original_eqs = flatten_equations(copy(equations(sys)))
281+
eqs = copy(original_eqs)
279282
neqs = length(eqs)
280283
param_derivative_map = Dict{BasicSymbolic, Any}()
281284
# * Scalarize unknowns
@@ -415,6 +418,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
415418
end
416419
end
417420
eqs = eqs[eqs_to_retain]
421+
original_eqs = original_eqs[eqs_to_retain]
418422
neqs = length(eqs)
419423
symbolic_incidence = symbolic_incidence[eqs_to_retain]
420424

@@ -423,6 +427,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
423427
# depending on order due to NP-completeness of tearing.
424428
sortidxs = Base.sortperm(eqs, by = string)
425429
eqs = eqs[sortidxs]
430+
original_eqs = original_eqs[sortidxs]
426431
symbolic_incidence = symbolic_incidence[sortidxs]
427432
end
428433

@@ -513,7 +518,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
513518

514519
eq_to_diff = DiffGraph(nsrcs(graph))
515520

516-
ts = TearingState(sys, fullvars,
521+
ts = TearingState(sys, original_eqs, fullvars,
517522
SystemStructure(complete(var_to_diff), complete(eq_to_diff),
518523
complete(graph), nothing, var_types, false),
519524
Any[], param_derivative_map)
@@ -696,6 +701,22 @@ function Base.show(io::IO, mime::MIME"text/plain", ms::MatchedSystemStructure)
696701
printstyled(io, " SelectedState")
697702
end
698703

704+
function make_eqs_zero_equals!(ts::TearingState)
705+
neweqs = map(enumerate(get_eqs(ts.sys))) do kvp
706+
i, eq = kvp
707+
isalgeq = true
708+
for j in 𝑠neighbors(ts.structure.graph, i)
709+
isalgeq &= invview(ts.structure.var_to_diff)[j] === nothing
710+
end
711+
if isalgeq
712+
return 0 ~ eq.rhs - eq.lhs
713+
else
714+
return eq
715+
end
716+
end
717+
copyto!(get_eqs(ts.sys), neweqs)
718+
end
719+
699720
function mtkcompile!(state::TearingState; simplify = false,
700721
check_consistency = true, fully_determined = true, warn_initialize_determined = true,
701722
inputs = Any[], outputs = Any[],
@@ -722,6 +743,7 @@ function mtkcompile!(state::TearingState; simplify = false,
722743
"""))
723744
end
724745
if length(tss) > 1
746+
make_eqs_zero_equals!(tss[continuous_id])
725747
# simplify as normal
726748
sys = _mtkcompile!(tss[continuous_id]; simplify,
727749
inputs = [inputs; clocked_inputs[continuous_id]], outputs, disturbance_inputs,

0 commit comments

Comments
 (0)