Skip to content

Commit 6af4c99

Browse files
fix: better handling of (possibly scalarized) array parameters
1 parent a50f143 commit 6af4c99

File tree

2 files changed

+66
-24
lines changed

2 files changed

+66
-24
lines changed

src/systems/abstractsystem.jl

Lines changed: 52 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -242,25 +242,42 @@ function wrap_array_vars(
242242
# tunables are scalarized and concatenated, so we need to have assignments
243243
# for the non-scalarized versions
244244
array_tunables = Dict{Any, AbstractArray{Int}}()
245-
for p in ps
246-
idx = parameter_index(sys, p)
247-
idx isa ParameterIndex || continue
248-
idx.portion isa SciMLStructures.Tunable || continue
249-
idx.idx isa AbstractArray || continue
250-
array_tunables[p] = idx.idx
251-
end
252245
# Other parameters may be scalarized arrays but used in the vector form
253-
other_array_parameters = Assignment[]
246+
other_array_parameters = Dict{Any, Any}()
247+
254248
for p in ps
249+
p = unwrap(p)
250+
if iscall(p) && operation(p) == getindex
251+
p = arguments(p)[1]
252+
end
253+
symtype(p) <: AbstractArray && Symbolics.shape(p) != Symbolics.Unknown() || continue
254+
scal = collect(p)
255+
# all scalarized variables are in `ps`
256+
all(x -> any(isequal(x), ps), scal) || continue
257+
(haskey(array_tunables, p) || haskey(other_array_parameters, p)) && continue
258+
255259
idx = parameter_index(sys, p)
256-
if Symbolics.isarraysymbolic(p)
257-
idx === nothing || continue
258-
push!(other_array_parameters, p collect(p))
259-
elseif iscall(p) && operation(p) == getindex
260-
idx === nothing && continue
261-
# all of the scalarized variables are in `ps`
262-
all(x -> any(isequal(x), ps), collect(p))|| continue
263-
push!(other_array_parameters, p collect(p))
260+
if idx === nothing
261+
idxs = map(Base.Fix1(parameter_index, sys), scal)
262+
if all(x -> x isa ParameterIndex && x.portion isa SciMLStructures.Tunable, idxs)
263+
idxs = map(x -> x.idx, idxs)
264+
end
265+
if all(x -> x isa Int, idxs)
266+
if vec(idxs) == idxs[begin]:idxs[end]
267+
idxs = reshape(idxs[begin]:idxs[end], size(idxs))
268+
elseif vec(idxs) == idxs[begin]:-1:idxs[end]
269+
idxs = reshape(idxs[begin]:-1:idxs[end], size(idxs))
270+
end
271+
array_tunables[p] = idxs
272+
else
273+
other_array_parameters[p] = scal
274+
end
275+
elseif idx isa Int
276+
continue
277+
elseif idx.portion != SciMLStructures.Tunable()
278+
other_array_parameters[p] = scal
279+
else
280+
array_tunables[p] = idx.idx
264281
end
265282
end
266283
for (k, inds) in array_vars
@@ -278,7 +295,8 @@ function wrap_array_vars(
278295
[k :(view($(expr.args[uind].name), $v)) for (k, v) in array_vars],
279296
[k :(view($(expr.args[uind + 1].name), $v))
280297
for (k, v) in array_tunables],
281-
other_array_parameters
298+
[k Code.MakeArray(v, typeof(v))
299+
for (k, v) in other_array_parameters]
282300
),
283301
expr.body,
284302
false
@@ -294,7 +312,9 @@ function wrap_array_vars(
294312
vcat(
295313
[k :(view($(expr.args[uind].name), $v)) for (k, v) in array_vars],
296314
[k :(view($(expr.args[uind + 1].name), $v))
297-
for (k, v) in array_tunables]
315+
for (k, v) in array_tunables],
316+
[k Code.MakeArray(v, typeof(v))
317+
for (k, v) in other_array_parameters]
298318
),
299319
expr.body,
300320
false
@@ -310,7 +330,9 @@ function wrap_array_vars(
310330
[k :(view($(expr.args[uind + 1].name), $v))
311331
for (k, v) in array_vars],
312332
[k :(view($(expr.args[uind + 2].name), $v))
313-
for (k, v) in array_tunables]
333+
for (k, v) in array_tunables],
334+
[k Code.MakeArray(v, typeof(v))
335+
for (k, v) in other_array_parameters]
314336
),
315337
expr.body,
316338
false
@@ -499,15 +521,18 @@ function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym)
499521
return unwrap(sym) in 1:length(parameter_symbols(sys))
500522
end
501523
return any(isequal(sym), parameter_symbols(sys)) ||
502-
hasname(sym) && is_parameter(sys, getname(sym))
524+
hasname(sym) && !(iscall(sym) && operation(sym) == getindex) &&
525+
is_parameter(sys, getname(sym))
503526
end
504527

505528
function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym::Symbol)
506529
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
507530
return is_parameter(ic, sym)
508531
end
509532

510-
named_parameters = [getname(sym) for sym in parameter_symbols(sys) if hasname(sym)]
533+
named_parameters = [getname(x)
534+
for x in parameter_symbols(sys)
535+
if hasname(x) && !(iscall(x) && operation(x) == getindex)]
511536
return any(isequal(sym), named_parameters) ||
512537
count(NAMESPACE_SEPARATOR, string(sym)) == 1 &&
513538
count(isequal(sym),
@@ -543,7 +568,7 @@ function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym)
543568
return sym
544569
end
545570
idx = findfirst(isequal(sym), parameter_symbols(sys))
546-
if idx === nothing && hasname(sym)
571+
if idx === nothing && hasname(sym) && !(iscall(sym) && operation(sym) == getindex)
547572
idx = parameter_index(sys, getname(sym))
548573
end
549574
return idx
@@ -559,13 +584,16 @@ function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym::Sym
559584
return idx
560585
end
561586
end
562-
idx = findfirst(isequal(sym), getname.(parameter_symbols(sys)))
587+
pnames = [getname(x)
588+
for x in parameter_symbols(sys)
589+
if hasname(x) && !(iscall(x) && operation(x) == getindex)]
590+
idx = findfirst(isequal(sym), pnames)
563591
if idx !== nothing
564592
return idx
565593
elseif count(NAMESPACE_SEPARATOR, string(sym)) == 1
566594
return findfirst(isequal(sym),
567595
Symbol.(
568-
nameof(sys), NAMESPACE_SEPARATOR_SYMBOL, getname.(parameter_symbols(sys))))
596+
nameof(sys), NAMESPACE_SEPARATOR_SYMBOL, pnames))
569597
end
570598
return nothing
571599
end

test/odesystem.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1250,3 +1250,17 @@ end
12501250
prob = ODEProblem(ssys, [], (0.0, 1.0), [])
12511251
@test prob[x] == prob[y] == prob[z] == 1.0
12521252
end
1253+
1254+
@testset "Scalarized parameters in array functions" begin
1255+
@variables u(t)[1:2] x(t)[1:2] o(t)[1:2]
1256+
@parameters p[1:2, 1:2] [tunable = false]
1257+
@named sys = ODESystem(
1258+
[D(u) ~ (sum(u) + sum(x) + sum(p) + sum(o)) * x, o ~ prod(u) * x],
1259+
t, [u..., x..., o...], [p...])
1260+
sys1, = structural_simplify(sys, ([x...], []))
1261+
fn1, = ModelingToolkit.generate_function(sys1; expression = Val{false})
1262+
@test_nowarn fn1(ones(4), 2ones(2), 3ones(2, 2), 4.0)
1263+
sys2, = structural_simplify(sys, ([x...], []); split = false)
1264+
fn2, = ModelingToolkit.generate_function(sys2; expression = Val{false})
1265+
@test_nowarn fn2(ones(4), 2ones(6), 4.0)
1266+
end

0 commit comments

Comments
 (0)