Skip to content

Commit bbc1cbc

Browse files
Merge pull request #3804 from AayushSabharwal/as/control-and-error
fix: fix bad input ordering, make downstream tests pass
2 parents 14623bc + 173fc49 commit bbc1cbc

File tree

6 files changed

+102
-53
lines changed

6 files changed

+102
-53
lines changed

src/deprecations.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@ end
1212
const ODESystem = IntermediateDeprecationSystem
1313

1414
function IntermediateDeprecationSystem(args...; kwargs...)
15-
Base.depwarn("`ODESystem(args...; kwargs...)` is deprecated. Use `System(args...; kwargs...) instead`.", :ODESystem)
15+
Base.depwarn(
16+
"`ODESystem(args...; kwargs...)` is deprecated. Use `System(args...; kwargs...) instead`.",
17+
:ODESystem)
1618

1719
return System(args...; kwargs...)
1820
end

src/inputoutput.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ function generate_control_function(sys::AbstractSystem, inputs = unbound_inputs(
200200
eval_module = @__MODULE__,
201201
kwargs...)
202202
isempty(inputs) && @warn("No unbound inputs were found in system.")
203-
if !iscomplete(sys)
203+
if !isscheduled(sys)
204204
sys = mtkcompile(sys; inputs, disturbance_inputs)
205205
end
206206
if disturbance_inputs !== nothing

src/linearization.jl

Lines changed: 50 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -565,6 +565,50 @@ function linearize_symbolic(sys::AbstractSystem, inputs,
565565
(; A, B, C, D, f_x, f_z, g_x, g_z, f_u, g_u, h_x, h_z, h_u), sys
566566
end
567567

568+
struct IONotFoundError <: Exception
569+
variant::String
570+
sysname::Symbol
571+
not_found::Any
572+
end
573+
574+
function Base.showerror(io::IO, err::IONotFoundError)
575+
println(io,
576+
"The following $(err.variant) provided to `mtkcompile` were not found in the system:")
577+
maybe_namespace_issue = false
578+
for var in err.not_found
579+
println(io, " ", var)
580+
if hasname(var) && startswith(string(getname(var)), string(err.sysname))
581+
maybe_namespace_issue = true
582+
end
583+
end
584+
if maybe_namespace_issue
585+
println(io, """
586+
Some of the missing variables are namespaced with the name of the system \
587+
`$(err.sysname)` passed to `mtkcompile`. This may be indicative of a namespacing \
588+
issue. `mtkcompile` requires that the $(err.variant) provided are not namespaced \
589+
with the name of the root system. This issue can occur when using `getproperty` \
590+
to access the variables passed as $(err.variant). For example:
591+
592+
```julia
593+
@named sys = MyModel()
594+
inputs = [sys.input_var]
595+
mtkcompile(sys; inputs)
596+
```
597+
598+
Here, `mtkcompile` expects the input to be named `input_var`, but since `sys`
599+
performs namespacing, it will be named `sys$(NAMESPACE_SEPARATOR)input_var`. To \
600+
fix this issue, namespacing can be temporarily disabled:
601+
602+
```julia
603+
@named sys = MyModel()
604+
sys_nns = toggle_namespacing(sys, false)
605+
inputs = [sys_nns.input_var]
606+
mtkcompile(sys; inputs)
607+
```
608+
""")
609+
end
610+
end
611+
568612
"""
569613
Modify the variable metadata of system variables to indicate which ones are inputs, outputs, and disturbances. Needed for `inputs`, `outputs`, `disturbances`, `unbound_inputs`, `unbound_outputs` to return the proper subsets.
570614
"""
@@ -605,19 +649,16 @@ function markio!(state, orig_inputs, inputs, outputs, disturbances; check = true
605649
if check
606650
ikeys = keys(filter(!last, inputset))
607651
if !isempty(ikeys)
608-
error(
609-
"Some specified inputs were not found in system. The following variables were not found ",
610-
ikeys)
652+
throw(IONotFoundError("inputs", nameof(state.sys), ikeys))
611653
end
612654
dkeys = keys(filter(!last, disturbanceset))
613655
if !isempty(dkeys)
614-
error(
615-
"Specified disturbance inputs were not found in system. The following variables were not found ",
616-
ikeys)
656+
throw(IONotFoundError("disturbance inputs", nameof(state.sys), ikeys))
657+
end
658+
okeys = keys(filter(!last, outputset))
659+
if !isempty(okeys)
660+
throw(IONotFoundError("outputs", nameof(state.sys), okeys))
617661
end
618-
(all(values(outputset)) || error(
619-
"Some specified outputs were not found in system. The following Dict indicates the found variables ",
620-
outputset))
621662
end
622663
state, orig_inputs
623664
end

src/systems/abstractsystem.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -652,6 +652,8 @@ function complete(
652652
# Ideally we'd do `get_ps` but if `flatten = false`
653653
# we don't get all of them. So we call `parameters`.
654654
all_ps = parameters(sys; initial_parameters = true)
655+
# inputs have to be maintained in a specific order
656+
input_vars = inputs(sys)
655657
if !isempty(all_ps)
656658
# reorder parameters by portions
657659
ps_split = reorder_parameters(sys, all_ps)
@@ -670,6 +672,12 @@ function complete(
670672
end
671673
ordered_ps = vcat(
672674
ordered_ps, reduce(vcat, ps_split; init = eltype(ordered_ps)[]))
675+
if isscheduled(sys)
676+
# ensure inputs are sorted
677+
input_idxs = findfirst.(isequal.(input_vars), (ordered_ps,))
678+
@assert all(!isnothing, input_idxs)
679+
@assert issorted(input_idxs)
680+
end
673681
@set! sys.ps = ordered_ps
674682
end
675683
elseif has_index_cache(sys)

src/systems/index_cache.jl

Lines changed: 38 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,8 @@ function IndexCache(sys::AbstractSystem)
9797
end
9898
end
9999

100-
tunable_buffers = Dict{Any, Set{BasicSymbolic}}()
101-
initial_param_buffers = Dict{Any, Set{BasicSymbolic}}()
100+
tunable_pars = BasicSymbolic[]
101+
initial_pars = BasicSymbolic[]
102102
constant_buffers = Dict{Any, Set{BasicSymbolic}}()
103103
nonnumeric_buffers = Dict{Any, Set{SymbolicParam}}()
104104

@@ -107,6 +107,10 @@ function IndexCache(sys::AbstractSystem)
107107
buf = get!(buffers, ctype, S())
108108
push!(buf, sym)
109109
end
110+
function insert_by_type!(buffers::Vector{BasicSymbolic}, sym, ctype)
111+
sym = unwrap(sym)
112+
push!(buffers, sym)
113+
end
110114

111115
disc_param_callbacks = Dict{SymbolicParam, Set{Int}}()
112116
events = vcat(continuous_events(sys), discrete_events(sys))
@@ -210,9 +214,9 @@ function IndexCache(sys::AbstractSystem)
210214
ctype <: AbstractArray{Real} ||
211215
ctype <: AbstractArray{<:AbstractFloat})
212216
if iscall(p) && operation(p) isa Initial
213-
initial_param_buffers
217+
initial_pars
214218
else
215-
tunable_buffers
219+
tunable_pars
216220
end
217221
else
218222
constant_buffers
@@ -255,47 +259,41 @@ function IndexCache(sys::AbstractSystem)
255259

256260
tunable_idxs = TunableIndexMap()
257261
tunable_buffer_size = 0
258-
bufferlist = is_initializesystem(sys) ? (tunable_buffers, initial_param_buffers) :
259-
(tunable_buffers,)
260-
for buffers in bufferlist
261-
for (i, (_, buf)) in enumerate(buffers)
262-
for (j, p) in enumerate(buf)
263-
idx = if size(p) == ()
264-
tunable_buffer_size + 1
265-
else
266-
reshape(
267-
(tunable_buffer_size + 1):(tunable_buffer_size + length(p)), size(p))
268-
end
269-
tunable_buffer_size += length(p)
270-
tunable_idxs[p] = idx
271-
tunable_idxs[default_toterm(p)] = idx
272-
if hasname(p) && (!iscall(p) || operation(p) !== getindex)
273-
symbol_to_variable[getname(p)] = p
274-
symbol_to_variable[getname(default_toterm(p))] = p
275-
end
276-
end
262+
if is_initializesystem(sys)
263+
append!(tunable_pars, initial_pars)
264+
empty!(initial_pars)
265+
end
266+
for p in tunable_pars
267+
idx = if size(p) == ()
268+
tunable_buffer_size + 1
269+
else
270+
reshape(
271+
(tunable_buffer_size + 1):(tunable_buffer_size + length(p)), size(p))
272+
end
273+
tunable_buffer_size += length(p)
274+
tunable_idxs[p] = idx
275+
tunable_idxs[default_toterm(p)] = idx
276+
if hasname(p) && (!iscall(p) || operation(p) !== getindex)
277+
symbol_to_variable[getname(p)] = p
278+
symbol_to_variable[getname(default_toterm(p))] = p
277279
end
278280
end
279281

280282
initials_idxs = TunableIndexMap()
281283
initials_buffer_size = 0
282-
if !is_initializesystem(sys)
283-
for (i, (_, buf)) in enumerate(initial_param_buffers)
284-
for (j, p) in enumerate(buf)
285-
idx = if size(p) == ()
286-
initials_buffer_size + 1
287-
else
288-
reshape(
289-
(initials_buffer_size + 1):(initials_buffer_size + length(p)), size(p))
290-
end
291-
initials_buffer_size += length(p)
292-
initials_idxs[p] = idx
293-
initials_idxs[default_toterm(p)] = idx
294-
if hasname(p) && (!iscall(p) || operation(p) !== getindex)
295-
symbol_to_variable[getname(p)] = p
296-
symbol_to_variable[getname(default_toterm(p))] = p
297-
end
298-
end
284+
for p in initial_pars
285+
idx = if size(p) == ()
286+
initials_buffer_size + 1
287+
else
288+
reshape(
289+
(initials_buffer_size + 1):(initials_buffer_size + length(p)), size(p))
290+
end
291+
initials_buffer_size += length(p)
292+
initials_idxs[p] = idx
293+
initials_idxs[default_toterm(p)] = idx
294+
if hasname(p) && (!iscall(p) || operation(p) !== getindex)
295+
symbol_to_variable[getname(p)] = p
296+
symbol_to_variable[getname(default_toterm(p))] = p
299297
end
300298
end
301299

test/linearize.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,12 +151,12 @@ lsys = ModelingToolkit.reorder_unknowns(lsys, desired_order, reverse(desired_ord
151151
@test lsys.D == [4400 -4400]
152152

153153
## Test that there is a warning when input is misspecified
154-
@test_throws "Some specified inputs were not found" linearize(pid,
154+
@test_throws ["inputs provided to `mtkcompile`", "not found"] linearize(pid,
155155
[
156156
pid.reference.u,
157157
pid.measurement.u
158158
], [ctr_output.u])
159-
@test_throws "Some specified outputs were not found" linearize(pid,
159+
@test_throws ["outputs provided to `mtkcompile`", "not found"] linearize(pid,
160160
[
161161
reference.u,
162162
measurement.u

0 commit comments

Comments
 (0)