Skip to content

Commit 7291dc8

Browse files
fix: better handling of (possibly scalarized) array parameters
1 parent 323380f commit 7291dc8

File tree

3 files changed

+81
-30
lines changed

3 files changed

+81
-30
lines changed

src/systems/abstractsystem.jl

Lines changed: 65 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -241,27 +241,52 @@ function wrap_array_vars(
241241
end
242242
# tunables are scalarized and concatenated, so we need to have assignments
243243
# for the non-scalarized versions
244-
array_tunables = Dict{Any, AbstractArray{Int}}()
245-
for p in ps
246-
idx = parameter_index(sys, p)
247-
idx isa ParameterIndex || continue
248-
idx.portion isa SciMLStructures.Tunable || continue
249-
idx.idx isa AbstractArray || continue
250-
array_tunables[p] = idx.idx
251-
end
244+
array_tunables = Dict{Any, Tuple{AbstractArray{Int}, Tuple{Vararg{Int}}}}()
252245
# Other parameters may be scalarized arrays but used in the vector form
253-
other_array_parameters = Assignment[]
246+
other_array_parameters = Dict{Any, Any}()
247+
248+
if ps isa Tuple && eltype(ps) <: AbstractArray
249+
ps = Iterators.flatten(ps)
250+
end
254251
for p in ps
252+
p = unwrap(p)
253+
if iscall(p) && operation(p) == getindex
254+
p = arguments(p)[1]
255+
end
256+
symtype(p) <: AbstractArray && Symbolics.shape(p) != Symbolics.Unknown() || continue
257+
scal = collect(p)
258+
# all scalarized variables are in `ps`
259+
any(isequal(p), ps) || all(x -> any(isequal(x), ps), scal) || continue
260+
(haskey(array_tunables, p) || haskey(other_array_parameters, p)) && continue
261+
255262
idx = parameter_index(sys, p)
256-
if Symbolics.isarraysymbolic(p)
257-
idx === nothing || continue
258-
push!(other_array_parameters, p collect(p))
259-
elseif iscall(p) && operation(p) == getindex
260-
idx === nothing && continue
261-
# all of the scalarized variables are in `ps`
262-
all(x -> any(isequal(x), ps), collect(p))|| continue
263-
push!(other_array_parameters, p collect(p))
263+
idx isa Int && continue
264+
if idx isa ParameterIndex
265+
if idx.portion != SciMLStructures.Tunable()
266+
continue
267+
end
268+
idxs = vec(idx.idx)
269+
sz = size(idx.idx)
270+
else
271+
# idx === nothing
272+
idxs = map(Base.Fix1(parameter_index, sys), scal)
273+
if all(x -> x isa ParameterIndex && x.portion isa SciMLStructures.Tunable, idxs)
274+
idxs = map(x -> x.idx, idxs)
275+
end
276+
if !all(x -> x isa Int, idxs)
277+
other_array_parameters[p] = scal
278+
continue
279+
end
280+
281+
sz = size(idxs)
282+
if vec(idxs) == idxs[begin]:idxs[end]
283+
idxs = idxs[begin]:idxs[end]
284+
elseif vec(idxs) == idxs[begin]:-1:idxs[end]
285+
idxs = idxs[begin]:-1:idxs[end]
286+
end
287+
idxs = vec(idxs)
264288
end
289+
array_tunables[p] = (idxs, sz)
265290
end
266291
for (k, inds) in array_vars
267292
if inds == (inds′ = inds[1]:inds[end])
@@ -276,9 +301,10 @@ function wrap_array_vars(
276301
Let(
277302
vcat(
278303
[k :(view($(expr.args[uind].name), $v)) for (k, v) in array_vars],
279-
[k :(view($(expr.args[uind + 1].name), $v))
280-
for (k, v) in array_tunables],
281-
other_array_parameters
304+
[k :(reshape(view($(expr.args[uind + 1].name), $idxs), $sz))
305+
for (k, (idxs, sz)) in array_tunables],
306+
[k Code.MakeArray(v, symtype(k))
307+
for (k, v) in other_array_parameters]
282308
),
283309
expr.body,
284310
false
@@ -293,8 +319,10 @@ function wrap_array_vars(
293319
Let(
294320
vcat(
295321
[k :(view($(expr.args[uind].name), $v)) for (k, v) in array_vars],
296-
[k :(view($(expr.args[uind + 1].name), $v))
297-
for (k, v) in array_tunables]
322+
[k :(reshape(view($(expr.args[uind + 1].name), $idxs), $sz))
323+
for (k, (idxs, sz)) in array_tunables],
324+
[k Code.MakeArray(v, symtype(k))
325+
for (k, v) in other_array_parameters]
298326
),
299327
expr.body,
300328
false
@@ -309,8 +337,10 @@ function wrap_array_vars(
309337
vcat(
310338
[k :(view($(expr.args[uind + 1].name), $v))
311339
for (k, v) in array_vars],
312-
[k :(view($(expr.args[uind + 2].name), $v))
313-
for (k, v) in array_tunables]
340+
[k :(reshape(view($(expr.args[uind + 2].name), $idxs), $sz))
341+
for (k, (idxs, sz)) in array_tunables],
342+
[k Code.MakeArray(v, symtype(k))
343+
for (k, v) in other_array_parameters]
314344
),
315345
expr.body,
316346
false
@@ -499,15 +529,18 @@ function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym)
499529
return unwrap(sym) in 1:length(parameter_symbols(sys))
500530
end
501531
return any(isequal(sym), parameter_symbols(sys)) ||
502-
hasname(sym) && is_parameter(sys, getname(sym))
532+
hasname(sym) && !(iscall(sym) && operation(sym) == getindex) &&
533+
is_parameter(sys, getname(sym))
503534
end
504535

505536
function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym::Symbol)
506537
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
507538
return is_parameter(ic, sym)
508539
end
509540

510-
named_parameters = [getname(sym) for sym in parameter_symbols(sys) if hasname(sym)]
541+
named_parameters = [getname(x)
542+
for x in parameter_symbols(sys)
543+
if hasname(x) && !(iscall(x) && operation(x) == getindex)]
511544
return any(isequal(sym), named_parameters) ||
512545
count(NAMESPACE_SEPARATOR, string(sym)) == 1 &&
513546
count(isequal(sym),
@@ -543,7 +576,7 @@ function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym)
543576
return sym
544577
end
545578
idx = findfirst(isequal(sym), parameter_symbols(sys))
546-
if idx === nothing && hasname(sym)
579+
if idx === nothing && hasname(sym) && !(iscall(sym) && operation(sym) == getindex)
547580
idx = parameter_index(sys, getname(sym))
548581
end
549582
return idx
@@ -559,13 +592,16 @@ function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym::Sym
559592
return idx
560593
end
561594
end
562-
idx = findfirst(isequal(sym), getname.(parameter_symbols(sys)))
595+
pnames = [getname(x)
596+
for x in parameter_symbols(sys)
597+
if hasname(x) && !(iscall(x) && operation(x) == getindex)]
598+
idx = findfirst(isequal(sym), pnames)
563599
if idx !== nothing
564600
return idx
565601
elseif count(NAMESPACE_SEPARATOR, string(sym)) == 1
566602
return findfirst(isequal(sym),
567603
Symbol.(
568-
nameof(sys), NAMESPACE_SEPARATOR_SYMBOL, getname.(parameter_symbols(sys))))
604+
nameof(sys), NAMESPACE_SEPARATOR_SYMBOL, pnames))
569605
end
570606
return nothing
571607
end

src/systems/callbacks.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,7 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin
427427
update_inds = outputidxs
428428
end
429429

430+
_ps = ps
430431
ps = reorder_parameters(sys, ps)
431432
if checkvars
432433
u = map(x -> time_varying_as_func(value(x), sys), dvs)
@@ -440,7 +441,7 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin
440441
pre = get_preprocess_constants(rhss)
441442
rf_oop, rf_ip = build_function(rhss, u, p..., t; expression = Val{true},
442443
wrap_code = add_integrator_header(sys, integ, outvar) .∘
443-
wrap_array_vars(sys, rhss; dvs, ps),
444+
wrap_array_vars(sys, rhss; dvs, ps = _ps),
444445
outputidxs = update_inds,
445446
postprocess_fbody = pre,
446447
kwargs...)

test/odesystem.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1250,3 +1250,17 @@ end
12501250
prob = ODEProblem(ssys, [], (0.0, 1.0), [])
12511251
@test prob[x] == prob[y] == prob[z] == 1.0
12521252
end
1253+
1254+
@testset "Scalarized parameters in array functions" begin
1255+
@variables u(t)[1:2] x(t)[1:2] o(t)[1:2]
1256+
@parameters p[1:2, 1:2] [tunable = false]
1257+
@named sys = ODESystem(
1258+
[D(u) ~ (sum(u) + sum(x) + sum(p) + sum(o)) * x, o ~ prod(u) * x],
1259+
t, [u..., x..., o...], [p...])
1260+
sys1, = structural_simplify(sys, ([x...], []))
1261+
fn1, = ModelingToolkit.generate_function(sys1; expression = Val{false})
1262+
@test_nowarn fn1(ones(4), 2ones(2), 3ones(2, 2), 4.0)
1263+
sys2, = structural_simplify(sys, ([x...], []); split = false)
1264+
fn2, = ModelingToolkit.generate_function(sys2; expression = Val{false})
1265+
@test_nowarn fn2(ones(4), 2ones(6), 4.0)
1266+
end

0 commit comments

Comments
 (0)