Skip to content

Commit fbef1a8

Browse files
feat: subset variables appropriately in clock inference
1 parent aeefc8a commit fbef1a8

File tree

2 files changed

+23
-11
lines changed

2 files changed

+23
-11
lines changed

src/systems/clock_inference.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -199,8 +199,8 @@ function split_system(ci::ClockInference{S}) where {S}
199199

200200
# breaks the system up into a continous and 0 or more discrete systems
201201
tss = similar(cid_to_eq, S)
202-
for (id, ieqs) in enumerate(cid_to_eq)
203-
ts_i = system_subset(ts, ieqs)
202+
for (id, (ieqs, ivars)) in enumerate(zip(cid_to_eq, cid_to_var))
203+
ts_i = system_subset(ts, ieqs, ivars)
204204
if id != continuous_id
205205
ts_i = shift_discrete_system(ts_i)
206206
@set! ts_i.structure.only_discrete = true

src/systems/systemstructure.jl

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -218,12 +218,12 @@ mutable struct TearingState{T <: AbstractSystem} <: AbstractTearingState{T}
218218
end
219219

220220
TransformationState(sys::AbstractSystem) = TearingState(sys)
221-
function system_subset(ts::TearingState, ieqs::Vector{Int})
221+
function system_subset(ts::TearingState, ieqs::Vector{Int}, ivars::Vector{Int})
222222
eqs = equations(ts)
223223
@set! ts.original_eqs = ts.original_eqs[ieqs]
224224
@set! ts.sys.eqs = eqs[ieqs]
225225
@set! ts.original_eqs = ts.original_eqs[ieqs]
226-
@set! ts.structure = system_subset(ts.structure, ieqs)
226+
@set! ts.structure = system_subset(ts.structure, ieqs, ivars)
227227
if all(eq -> eq.rhs isa StateMachineOperator, get_eqs(ts.sys))
228228
names = Symbol[]
229229
for eq in get_eqs(ts.sys)
@@ -240,22 +240,33 @@ function system_subset(ts::TearingState, ieqs::Vector{Int})
240240
else
241241
@set! ts.statemachines = eltype(ts.statemachines)[]
242242
end
243+
@set! ts.fullvars = ts.fullvars[ivars]
243244
ts
244245
end
245246

246-
function system_subset(structure::SystemStructure, ieqs::Vector{Int})
247-
@unpack graph, eq_to_diff = structure
247+
function system_subset(structure::SystemStructure, ieqs::Vector{Int}, ivars::Vector{Int})
248+
@unpack graph = structure
248249
fadj = Vector{Int}[]
249250
eq_to_diff = DiffGraph(length(ieqs))
251+
var_to_diff = DiffGraph(length(ivars))
252+
250253
ne = 0
254+
old_to_new_var = zeros(Int, ndsts(graph))
255+
for (i, iv) in enumerate(ivars)
256+
old_to_new_var[iv] = i
257+
structure.var_to_diff[iv] === nothing && continue
258+
var_to_diff[i] = old_to_new_var[structure.var_to_diff[iv]]
259+
end
251260
for (j, eq_i) in enumerate(ieqs)
252-
ivars = copy(graph.fadjlist[eq_i])
253-
ne += length(ivars)
254-
push!(fadj, ivars)
261+
var_adj = [old_to_new_var[i] for i in graph.fadjlist[eq_i]]
262+
@assert all(!iszero, var_adj)
263+
ne += length(var_adj)
264+
push!(fadj, var_adj)
255265
eq_to_diff[j] = structure.eq_to_diff[eq_i]
256266
end
257-
@set! structure.graph = complete(BipartiteGraph(ne, fadj, ndsts(graph)))
267+
@set! structure.graph = complete(BipartiteGraph(ne, fadj, length(ivars)))
258268
@set! structure.eq_to_diff = eq_to_diff
269+
@set! structure.var_to_diff = complete(var_to_diff)
259270
structure
260271
end
261272

@@ -440,7 +451,8 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
440451
isdelay(v, iv) && continue
441452

442453
if !symbolic_contains(v, dvs)
443-
isvalid = iscall(v) && (operation(v) isa Shift || is_transparent_operator(operation(v)))
454+
isvalid = iscall(v) &&
455+
(operation(v) isa Shift || is_transparent_operator(operation(v)))
444456
v′ = v
445457
while !isvalid && iscall(v′) && operation(v′) isa Union{Differential, Shift}
446458
v′ = arguments(v′)[1]

0 commit comments

Comments
 (0)