From 209e8bb995a5eb0d16adaa3f5bfd1390466d796d Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 9 Jul 2025 13:34:20 +0530 Subject: [PATCH] fix: maintain order of inputs during `complete` --- src/systems/abstractsystem.jl | 8 ++++ src/systems/index_cache.jl | 78 +++++++++++++++++------------------ 2 files changed, 46 insertions(+), 40 deletions(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index d78c48a857..71d8a4f96e 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -796,6 +796,8 @@ function complete( # Ideally we'd do `get_ps` but if `flatten = false` # we don't get all of them. So we call `parameters`. all_ps = parameters(sys; initial_parameters = true) + # inputs have to be maintained in a specific order + input_vars = inputs(sys) if !isempty(all_ps) # reorder parameters by portions ps_split = reorder_parameters(sys, all_ps) @@ -814,6 +816,12 @@ function complete( end ordered_ps = vcat( ordered_ps, reduce(vcat, ps_split; init = eltype(ordered_ps)[])) + if isscheduled(sys) + # ensure inputs are sorted + input_idxs = findfirst.(isequal.(input_vars), (ordered_ps,)) + @assert all(!isnothing, input_idxs) + @assert issorted(input_idxs) + end @set! sys.ps = ordered_ps end elseif has_index_cache(sys) diff --git a/src/systems/index_cache.jl b/src/systems/index_cache.jl index ae4feec62b..26fb356459 100644 --- a/src/systems/index_cache.jl +++ b/src/systems/index_cache.jl @@ -97,8 +97,8 @@ function IndexCache(sys::AbstractSystem) end end - tunable_buffers = Dict{Any, Set{BasicSymbolic}}() - initial_param_buffers = Dict{Any, Set{BasicSymbolic}}() + tunable_pars = BasicSymbolic[] + initial_pars = BasicSymbolic[] constant_buffers = Dict{Any, Set{BasicSymbolic}}() nonnumeric_buffers = Dict{Any, Set{SymbolicParam}}() @@ -107,6 +107,10 @@ function IndexCache(sys::AbstractSystem) buf = get!(buffers, ctype, S()) push!(buf, sym) end + function insert_by_type!(buffers::Vector{BasicSymbolic}, sym, ctype) + sym = unwrap(sym) + push!(buffers, sym) + end disc_param_callbacks = Dict{SymbolicParam, Set{Int}}() events = vcat(continuous_events(sys), discrete_events(sys)) @@ -210,9 +214,9 @@ function IndexCache(sys::AbstractSystem) ctype <: AbstractArray{Real} || ctype <: AbstractArray{<:AbstractFloat}) if iscall(p) && operation(p) isa Initial - initial_param_buffers + initial_pars else - tunable_buffers + tunable_pars end else constant_buffers @@ -253,47 +257,41 @@ function IndexCache(sys::AbstractSystem) tunable_idxs = TunableIndexMap() tunable_buffer_size = 0 - bufferlist = is_initializesystem(sys) ? (tunable_buffers, initial_param_buffers) : - (tunable_buffers,) - for buffers in bufferlist - for (i, (_, buf)) in enumerate(buffers) - for (j, p) in enumerate(buf) - idx = if size(p) == () - tunable_buffer_size + 1 - else - reshape( - (tunable_buffer_size + 1):(tunable_buffer_size + length(p)), size(p)) - end - tunable_buffer_size += length(p) - tunable_idxs[p] = idx - tunable_idxs[default_toterm(p)] = idx - if hasname(p) && (!iscall(p) || operation(p) !== getindex) - symbol_to_variable[getname(p)] = p - symbol_to_variable[getname(default_toterm(p))] = p - end - end + if is_initializesystem(sys) + append!(tunable_pars, initial_pars) + empty!(initial_pars) + end + for p in tunable_pars + idx = if size(p) == () + tunable_buffer_size + 1 + else + reshape( + (tunable_buffer_size + 1):(tunable_buffer_size + length(p)), size(p)) + end + tunable_buffer_size += length(p) + tunable_idxs[p] = idx + tunable_idxs[default_toterm(p)] = idx + if hasname(p) && (!iscall(p) || operation(p) !== getindex) + symbol_to_variable[getname(p)] = p + symbol_to_variable[getname(default_toterm(p))] = p end end initials_idxs = TunableIndexMap() initials_buffer_size = 0 - if !is_initializesystem(sys) - for (i, (_, buf)) in enumerate(initial_param_buffers) - for (j, p) in enumerate(buf) - idx = if size(p) == () - initials_buffer_size + 1 - else - reshape( - (initials_buffer_size + 1):(initials_buffer_size + length(p)), size(p)) - end - initials_buffer_size += length(p) - initials_idxs[p] = idx - initials_idxs[default_toterm(p)] = idx - if hasname(p) && (!iscall(p) || operation(p) !== getindex) - symbol_to_variable[getname(p)] = p - symbol_to_variable[getname(default_toterm(p))] = p - end - end + for p in initial_pars + idx = if size(p) == () + initials_buffer_size + 1 + else + reshape( + (initials_buffer_size + 1):(initials_buffer_size + length(p)), size(p)) + end + initials_buffer_size += length(p) + initials_idxs[p] = idx + initials_idxs[default_toterm(p)] = idx + if hasname(p) && (!iscall(p) || operation(p) !== getindex) + symbol_to_variable[getname(p)] = p + symbol_to_variable[getname(default_toterm(p))] = p end end