Skip to content

Commit 56ca2b3

Browse files
BenChungAayushSabharwal
authored andcommitted
Early work on the new discrete backend for MTK
1 parent f807dbc commit 56ca2b3

File tree

3 files changed

+48
-13
lines changed

3 files changed

+48
-13
lines changed

src/systems/clock_inference.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ function infer_clocks!(ci::ClockInference)
100100
c = BitSet(c′)
101101
idxs = intersect(c, inferred)
102102
isempty(idxs) && continue
103-
if !allequal(var_domain[i] for i in idxs)
103+
if !allequal(iscontinuous(var_domain[i]) for i in idxs)
104104
display(fullvars[c′])
105105
throw(ClockInferenceException("Clocks are not consistent in connected component $(fullvars[c′])"))
106106
end
@@ -155,6 +155,9 @@ function split_system(ci::ClockInference{S}) where {S}
155155
cid_to_var = Vector{Int}[]
156156
# cid_counter = number of clocks
157157
cid_counter = Ref(0)
158+
159+
# populates clock_to_id and id_to_clock
160+
# checks if there is a continuous_id (for some reason? clock to id does this too)
158161
for (i, d) in enumerate(eq_domain)
159162
cid = let cid_counter = cid_counter, id_to_clock = id_to_clock,
160163
continuous_id = continuous_id
@@ -174,9 +177,13 @@ function split_system(ci::ClockInference{S}) where {S}
174177
resize_or_push!(cid_to_eq, i, cid)
175178
end
176179
continuous_id = continuous_id[]
180+
# for each clock partition what are the input (indexes/vars)
177181
input_idxs = map(_ -> Int[], 1:cid_counter[])
178182
inputs = map(_ -> Any[], 1:cid_counter[])
183+
# var_domain corresponds to fullvars/all variables in the system
179184
nvv = length(var_domain)
185+
# put variables into the right clock partition
186+
# keep track of inputs to each partition
180187
for i in 1:nvv
181188
d = var_domain[i]
182189
cid = get(clock_to_id, d, 0)
@@ -190,6 +197,7 @@ function split_system(ci::ClockInference{S}) where {S}
190197
resize_or_push!(cid_to_var, i, cid)
191198
end
192199

200+
# breaks the system up into a continous and 0 or more discrete systems
193201
tss = similar(cid_to_eq, S)
194202
for (id, ieqs) in enumerate(cid_to_eq)
195203
ts_i = system_subset(ts, ieqs)
@@ -199,6 +207,7 @@ function split_system(ci::ClockInference{S}) where {S}
199207
end
200208
tss[id] = ts_i
201209
end
210+
# put the continous system at the back
202211
if continuous_id != 0
203212
tss[continuous_id], tss[end] = tss[end], tss[continuous_id]
204213
inputs[continuous_id], inputs[end] = inputs[end], inputs[continuous_id]

src/systems/systems.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ function mtkcompile(
3636
isscheduled(sys) && throw(RepeatedStructuralSimplificationError())
3737
newsys′ = __mtkcompile(sys; simplify,
3838
allow_symbolic, allow_parameter, conservative, fully_determined,
39-
inputs, outputs, disturbance_inputs,
39+
inputs, outputs, disturbance_inputs, additional_passes,
4040
kwargs...)
4141
if newsys′ isa Tuple
4242
@assert length(newsys′) == 2
@@ -291,3 +291,8 @@ function map_variables_to_equations(sys::AbstractSystem; rename_dummy_derivative
291291

292292
return mapping
293293
end
294+
295+
"""
296+
Mark whether an extra pass `p` can support compiling discrete systems.
297+
"""
298+
discrete_compile_pass(p) = false

src/systems/systemstructure.jl

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -706,19 +706,40 @@ function mtkcompile!(state::TearingState; simplify = false,
706706
time_domains = merge(Dict(state.fullvars .=> ci.var_domain),
707707
Dict(default_toterm.(state.fullvars) .=> ci.var_domain))
708708
tss, clocked_inputs, continuous_id, id_to_clock = ModelingToolkit.split_system(ci)
709+
if continuous_id == 0
710+
# do a trait check here - handle fully discrete system
711+
additional_passes = get(kwargs, :additional_passes, nothing)
712+
if !isnothing(additional_passes) && any(discrete_compile_pass, additional_passes)
713+
# take the first discrete compilation pass given for now
714+
discrete_pass_idx = findfirst(discrete_compile_pass, additional_passes)
715+
discrete_compile = additional_passes[discrete_pass_idx]
716+
deleteat!(additional_passes, discrete_pass_idx)
717+
return discrete_compile(tss, clocked_inputs)
718+
end
719+
throw(HybridSystemNotSupportedException("""
720+
Discrete systems with multiple clocks are not supported with the standard \
721+
MTK compiler.
722+
"""))
723+
end
709724
if length(tss) > 1
710-
if continuous_id == 0
711-
throw(HybridSystemNotSupportedException("""
712-
Discrete systems with multiple clocks are not supported with the standard \
713-
MTK compiler.
714-
"""))
715-
else
716-
throw(HybridSystemNotSupportedException("""
717-
Hybrid continuous-discrete systems are currently not supported with \
718-
the standard MTK compiler. This system requires JuliaSimCompiler.jl, \
719-
see https://help.juliahub.com/juliasimcompiler/stable/
720-
"""))
725+
# simplify as normal
726+
sys = _mtkcompile!(tss[continuous_id]; simplify,
727+
inputs = [inputs; clocked_inputs[continuous_id]], outputs, disturbance_inputs,
728+
check_consistency, fully_determined,
729+
kwargs...)
730+
if !isnothing(additional_passes) && any(discrete_compile_pass, additional_passes)
731+
discrete_pass_idx = findfirst(discrete_compile_pass, additional_passes)
732+
discrete_compile = additional_passes[discrete_pass_idx]
733+
deleteat!(additional_passes, discrete_pass_idx)
734+
# in the case of a hybrid system, the discrete_compile pass should take the currents of sys.discrete_subsystems
735+
# and modifies discrete_subsystems to bea tuple of the io and anything else, while adding or manipulating the rest of sys as needed
736+
return discrete_compile(sys, tss[2:end], inputs)
721737
end
738+
throw(HybridSystemNotSupportedException("""
739+
Hybrid continuous-discrete systems are currently not supported with \
740+
the standard MTK compiler. This system requires JuliaSimCompiler.jl, \
741+
see https://help.juliahub.com/juliasimcompiler/stable/
742+
"""))
722743
end
723744
if get_is_discrete(state.sys) ||
724745
continuous_id == 1 && any(Base.Fix2(isoperator, Shift), state.fullvars)

0 commit comments

Comments
 (0)