Skip to content

Commit fbc2f40

Browse files
Merge pull request #2908 from AayushSabharwal/as/better-tunables
refactor: turn tunables portion into a Vector{T}
2 parents e64c479 + 7291dc8 commit fbc2f40

18 files changed

+359
-205
lines changed

src/inputoutput.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ function generate_control_function(sys::AbstractODESystem, inputs = unbound_inpu
243243
end
244244
process = get_postprocess_fbody(sys)
245245
f = build_function(rhss, args...; postprocess_fbody = process,
246-
expression = Val{true}, kwargs...)
246+
expression = Val{true}, wrap_code = wrap_array_vars(sys, rhss; dvs, ps), kwargs...)
247247
f = eval_or_rgf.(f; eval_expression, eval_module)
248248
(; f, dvs, ps, io_sys = sys)
249249
end

src/systems/abstractsystem.jl

Lines changed: 94 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -223,15 +223,70 @@ function wrap_assignments(isscalar, assignments; let_block = false)
223223
end
224224
end
225225

226-
function wrap_array_vars(sys::AbstractSystem, exprs; dvs = unknowns(sys))
226+
function wrap_array_vars(
227+
sys::AbstractSystem, exprs; dvs = unknowns(sys), ps = parameters(sys))
227228
isscalar = !(exprs isa AbstractArray)
228229
array_vars = Dict{Any, AbstractArray{Int}}()
229-
for (j, x) in enumerate(dvs)
230-
if iscall(x) && operation(x) == getindex
231-
arg = arguments(x)[1]
232-
inds = get!(() -> Int[], array_vars, arg)
233-
push!(inds, j)
230+
if dvs !== nothing
231+
for (j, x) in enumerate(dvs)
232+
if iscall(x) && operation(x) == getindex
233+
arg = arguments(x)[1]
234+
inds = get!(() -> Int[], array_vars, arg)
235+
push!(inds, j)
236+
end
237+
end
238+
uind = 1
239+
else
240+
uind = 0
241+
end
242+
# tunables are scalarized and concatenated, so we need to have assignments
243+
# for the non-scalarized versions
244+
array_tunables = Dict{Any, Tuple{AbstractArray{Int}, Tuple{Vararg{Int}}}}()
245+
# Other parameters may be scalarized arrays but used in the vector form
246+
other_array_parameters = Dict{Any, Any}()
247+
248+
if ps isa Tuple && eltype(ps) <: AbstractArray
249+
ps = Iterators.flatten(ps)
250+
end
251+
for p in ps
252+
p = unwrap(p)
253+
if iscall(p) && operation(p) == getindex
254+
p = arguments(p)[1]
255+
end
256+
symtype(p) <: AbstractArray && Symbolics.shape(p) != Symbolics.Unknown() || continue
257+
scal = collect(p)
258+
# all scalarized variables are in `ps`
259+
any(isequal(p), ps) || all(x -> any(isequal(x), ps), scal) || continue
260+
(haskey(array_tunables, p) || haskey(other_array_parameters, p)) && continue
261+
262+
idx = parameter_index(sys, p)
263+
idx isa Int && continue
264+
if idx isa ParameterIndex
265+
if idx.portion != SciMLStructures.Tunable()
266+
continue
267+
end
268+
idxs = vec(idx.idx)
269+
sz = size(idx.idx)
270+
else
271+
# idx === nothing
272+
idxs = map(Base.Fix1(parameter_index, sys), scal)
273+
if all(x -> x isa ParameterIndex && x.portion isa SciMLStructures.Tunable, idxs)
274+
idxs = map(x -> x.idx, idxs)
275+
end
276+
if !all(x -> x isa Int, idxs)
277+
other_array_parameters[p] = scal
278+
continue
279+
end
280+
281+
sz = size(idxs)
282+
if vec(idxs) == idxs[begin]:idxs[end]
283+
idxs = idxs[begin]:idxs[end]
284+
elseif vec(idxs) == idxs[begin]:-1:idxs[end]
285+
idxs = idxs[begin]:-1:idxs[end]
286+
end
287+
idxs = vec(idxs)
234288
end
289+
array_tunables[p] = (idxs, sz)
235290
end
236291
for (k, inds) in array_vars
237292
if inds == (inds′ = inds[1]:inds[end])
@@ -244,7 +299,13 @@ function wrap_array_vars(sys::AbstractSystem, exprs; dvs = unknowns(sys))
244299
expr.args,
245300
[],
246301
Let(
247-
[k :(view($(expr.args[1].name), $v)) for (k, v) in array_vars],
302+
vcat(
303+
[k :(view($(expr.args[uind].name), $v)) for (k, v) in array_vars],
304+
[k :(reshape(view($(expr.args[uind + 1].name), $idxs), $sz))
305+
for (k, (idxs, sz)) in array_tunables],
306+
[k Code.MakeArray(v, symtype(k))
307+
for (k, v) in other_array_parameters]
308+
),
248309
expr.body,
249310
false
250311
)
@@ -256,7 +317,13 @@ function wrap_array_vars(sys::AbstractSystem, exprs; dvs = unknowns(sys))
256317
expr.args,
257318
[],
258319
Let(
259-
[k :(view($(expr.args[1].name), $v)) for (k, v) in array_vars],
320+
vcat(
321+
[k :(view($(expr.args[uind].name), $v)) for (k, v) in array_vars],
322+
[k :(reshape(view($(expr.args[uind + 1].name), $idxs), $sz))
323+
for (k, (idxs, sz)) in array_tunables],
324+
[k Code.MakeArray(v, symtype(k))
325+
for (k, v) in other_array_parameters]
326+
),
260327
expr.body,
261328
false
262329
)
@@ -267,7 +334,14 @@ function wrap_array_vars(sys::AbstractSystem, exprs; dvs = unknowns(sys))
267334
expr.args,
268335
[],
269336
Let(
270-
[k :(view($(expr.args[2].name), $v)) for (k, v) in array_vars],
337+
vcat(
338+
[k :(view($(expr.args[uind + 1].name), $v))
339+
for (k, v) in array_vars],
340+
[k :(reshape(view($(expr.args[uind + 2].name), $idxs), $sz))
341+
for (k, (idxs, sz)) in array_tunables],
342+
[k Code.MakeArray(v, symtype(k))
343+
for (k, v) in other_array_parameters]
344+
),
271345
expr.body,
272346
false
273347
)
@@ -455,15 +529,18 @@ function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym)
455529
return unwrap(sym) in 1:length(parameter_symbols(sys))
456530
end
457531
return any(isequal(sym), parameter_symbols(sys)) ||
458-
hasname(sym) && is_parameter(sys, getname(sym))
532+
hasname(sym) && !(iscall(sym) && operation(sym) == getindex) &&
533+
is_parameter(sys, getname(sym))
459534
end
460535

461536
function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym::Symbol)
462537
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
463538
return is_parameter(ic, sym)
464539
end
465540

466-
named_parameters = [getname(sym) for sym in parameter_symbols(sys) if hasname(sym)]
541+
named_parameters = [getname(x)
542+
for x in parameter_symbols(sys)
543+
if hasname(x) && !(iscall(x) && operation(x) == getindex)]
467544
return any(isequal(sym), named_parameters) ||
468545
count(NAMESPACE_SEPARATOR, string(sym)) == 1 &&
469546
count(isequal(sym),
@@ -499,7 +576,7 @@ function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym)
499576
return sym
500577
end
501578
idx = findfirst(isequal(sym), parameter_symbols(sys))
502-
if idx === nothing && hasname(sym)
579+
if idx === nothing && hasname(sym) && !(iscall(sym) && operation(sym) == getindex)
503580
idx = parameter_index(sys, getname(sym))
504581
end
505582
return idx
@@ -515,13 +592,16 @@ function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym::Sym
515592
return idx
516593
end
517594
end
518-
idx = findfirst(isequal(sym), getname.(parameter_symbols(sys)))
595+
pnames = [getname(x)
596+
for x in parameter_symbols(sys)
597+
if hasname(x) && !(iscall(x) && operation(x) == getindex)]
598+
idx = findfirst(isequal(sym), pnames)
519599
if idx !== nothing
520600
return idx
521601
elseif count(NAMESPACE_SEPARATOR, string(sym)) == 1
522602
return findfirst(isequal(sym),
523603
Symbol.(
524-
nameof(sys), NAMESPACE_SEPARATOR_SYMBOL, getname.(parameter_symbols(sys))))
604+
nameof(sys), NAMESPACE_SEPARATOR_SYMBOL, pnames))
525605
end
526606
return nothing
527607
end

src/systems/callbacks.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,8 @@ function compile_condition(cb::SymbolicDiscreteCallback, sys, dvs, ps;
354354
condit = substitute(condit, cmap)
355355
end
356356
expr = build_function(
357-
condit, u, t, p...; expression = Val{true}, wrap_code = condition_header(sys),
357+
condit, u, t, p...; expression = Val{true},
358+
wrap_code = condition_header(sys) .∘ wrap_array_vars(sys, condit; dvs, ps),
358359
kwargs...)
359360
if expression == Val{true}
360361
return expr
@@ -411,10 +412,8 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin
411412
update_inds = map(sym -> unknownind[sym], update_vars)
412413
elseif isparameter(first(lhss)) && alleq
413414
if has_index_cache(sys) && get_index_cache(sys) !== nothing
414-
ic = get_index_cache(sys)
415415
update_inds = map(update_vars) do sym
416-
pind = parameter_index(sys, sym)
417-
discrete_linear_index(ic, pind)
416+
return parameter_index(sys, sym)
418417
end
419418
else
420419
psind = Dict(reverse(en) for en in enumerate(ps))
@@ -428,6 +427,7 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin
428427
update_inds = outputidxs
429428
end
430429

430+
_ps = ps
431431
ps = reorder_parameters(sys, ps)
432432
if checkvars
433433
u = map(x -> time_varying_as_func(value(x), sys), dvs)
@@ -440,7 +440,8 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin
440440
integ = gensym(:MTKIntegrator)
441441
pre = get_preprocess_constants(rhss)
442442
rf_oop, rf_ip = build_function(rhss, u, p..., t; expression = Val{true},
443-
wrap_code = add_integrator_header(sys, integ, outvar),
443+
wrap_code = add_integrator_header(sys, integ, outvar) .∘
444+
wrap_array_vars(sys, rhss; dvs, ps = _ps),
444445
outputidxs = update_inds,
445446
postprocess_fbody = pre,
446447
kwargs...)

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ end
8787

8888
function generate_tgrad(
8989
sys::AbstractODESystem, dvs = unknowns(sys), ps = full_parameters(sys);
90-
simplify = false, kwargs...)
90+
simplify = false, wrap_code = identity, kwargs...)
9191
tgrad = calculate_tgrad(sys, simplify = simplify)
9292
pre = get_preprocess_constants(tgrad)
9393
p = if has_index_cache(sys) && get_index_cache(sys) !== nothing
@@ -97,29 +97,33 @@ function generate_tgrad(
9797
else
9898
(ps,)
9999
end
100+
wrap_code = wrap_code .∘ wrap_array_vars(sys, tgrad; dvs, ps)
100101
return build_function(tgrad,
101102
dvs,
102103
p...,
103104
get_iv(sys);
104105
postprocess_fbody = pre,
106+
wrap_code,
105107
kwargs...)
106108
end
107109

108110
function generate_jacobian(sys::AbstractODESystem, dvs = unknowns(sys),
109111
ps = full_parameters(sys);
110-
simplify = false, sparse = false, kwargs...)
112+
simplify = false, sparse = false, wrap_code = identity, kwargs...)
111113
jac = calculate_jacobian(sys; simplify = simplify, sparse = sparse)
112114
pre = get_preprocess_constants(jac)
113115
p = if has_index_cache(sys) && get_index_cache(sys) !== nothing
114116
reorder_parameters(get_index_cache(sys), ps)
115117
else
116118
(ps,)
117119
end
120+
wrap_code = wrap_code .∘ wrap_array_vars(sys, jac; dvs, ps)
118121
return build_function(jac,
119122
dvs,
120123
p...,
121124
get_iv(sys);
122125
postprocess_fbody = pre,
126+
wrap_code,
123127
kwargs...)
124128
end
125129

@@ -188,12 +192,12 @@ function generate_function(sys::AbstractODESystem, dvs = unknowns(sys),
188192
if implicit_dae
189193
build_function(rhss, ddvs, u, p..., t; postprocess_fbody = pre,
190194
states = sol_states,
191-
wrap_code = wrap_code .∘ wrap_array_vars(sys, rhss; dvs),
195+
wrap_code = wrap_code .∘ wrap_array_vars(sys, rhss; dvs, ps),
192196
kwargs...)
193197
else
194198
build_function(rhss, u, p..., t; postprocess_fbody = pre,
195199
states = sol_states,
196-
wrap_code = wrap_code .∘ wrap_array_vars(sys, rhss; dvs),
200+
wrap_code = wrap_code .∘ wrap_array_vars(sys, rhss; dvs, ps),
197201
kwargs...)
198202
end
199203
end

src/systems/diffeqs/odesystem.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,7 @@ function build_explicit_observed_function(sys, ts;
485485
if inputs !== nothing
486486
ps = setdiff(ps, inputs) # Inputs have been converted to parameters by io_preprocessing, remove those from the parameter list
487487
end
488+
_ps = ps
488489
if ps isa Tuple
489490
ps = DestructuredArgs.(ps, inbounds = !checkbounds)
490491
elseif has_index_cache(sys) && get_index_cache(sys) !== nothing
@@ -505,19 +506,24 @@ function build_explicit_observed_function(sys, ts;
505506
end
506507
pre = get_postprocess_fbody(sys)
507508

509+
array_wrapper = if param_only
510+
wrap_array_vars(sys, ts; ps = _ps, dvs = nothing)
511+
else
512+
wrap_array_vars(sys, ts; ps = _ps)
513+
end
508514
# Need to keep old method of building the function since it uses `output_type`,
509515
# which can't be provided to `build_function`
510516
oop_fn = Func(args, [],
511517
pre(Let(obsexprs,
512518
isscalar ? ts[1] : MakeArray(ts, output_type),
513-
false))) |> wrap_array_vars(sys, ts)[1] |> toexpr
519+
false))) |> array_wrapper[1] |> toexpr
514520
oop_fn = expression ? oop_fn : eval_or_rgf(oop_fn; eval_expression, eval_module)
515521

516522
if !isscalar
517523
iip_fn = build_function(ts,
518524
args...;
519525
postprocess_fbody = pre,
520-
wrap_code = wrap_array_vars(sys, ts) .∘ wrap_assignments(isscalar, obsexprs),
526+
wrap_code = array_wrapper .∘ wrap_assignments(isscalar, obsexprs),
521527
expression = Val{true})[2]
522528
if !expression
523529
iip_fn = eval_or_rgf(iip_fn; eval_expression, eval_module)

src/systems/discrete_system/discrete_system.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -218,8 +218,10 @@ function flatten(sys::DiscreteSystem, noeqs = false)
218218
end
219219

220220
function generate_function(
221-
sys::DiscreteSystem, dvs = unknowns(sys), ps = full_parameters(sys); kwargs...)
222-
generate_custom_function(sys, [eq.rhs for eq in equations(sys)], dvs, ps; kwargs...)
221+
sys::DiscreteSystem, dvs = unknowns(sys), ps = full_parameters(sys); wrap_code = identity, kwargs...)
222+
exprs = [eq.rhs for eq in equations(sys)]
223+
wrap_code = wrap_code .∘ wrap_array_vars(sys, exprs)
224+
generate_custom_function(sys, exprs, dvs, ps; wrap_code, kwargs...)
223225
end
224226

225227
function process_DiscreteProblem(constructor, sys::DiscreteSystem, u0map, parammap;

0 commit comments

Comments
 (0)