@@ -223,15 +223,70 @@ function wrap_assignments(isscalar, assignments; let_block = false)
223
223
end
224
224
end
225
225
226
- function wrap_array_vars (sys:: AbstractSystem , exprs; dvs = unknowns (sys))
226
+ function wrap_array_vars (
227
+ sys:: AbstractSystem , exprs; dvs = unknowns (sys), ps = parameters (sys))
227
228
isscalar = ! (exprs isa AbstractArray)
228
229
array_vars = Dict {Any, AbstractArray{Int}} ()
229
- for (j, x) in enumerate (dvs)
230
- if iscall (x) && operation (x) == getindex
231
- arg = arguments (x)[1 ]
232
- inds = get! (() -> Int[], array_vars, arg)
233
- push! (inds, j)
230
+ if dvs != = nothing
231
+ for (j, x) in enumerate (dvs)
232
+ if iscall (x) && operation (x) == getindex
233
+ arg = arguments (x)[1 ]
234
+ inds = get! (() -> Int[], array_vars, arg)
235
+ push! (inds, j)
236
+ end
237
+ end
238
+ uind = 1
239
+ else
240
+ uind = 0
241
+ end
242
+ # tunables are scalarized and concatenated, so we need to have assignments
243
+ # for the non-scalarized versions
244
+ array_tunables = Dict{Any, Tuple{AbstractArray{Int}, Tuple{Vararg{Int}}}}()
245
+ # Other parameters may be scalarized arrays but used in the vector form
246
+ other_array_parameters = Dict {Any, Any} ()
247
+
248
+ if ps isa Tuple && eltype (ps) <: AbstractArray
249
+ ps = Iterators. flatten (ps)
250
+ end
251
+ for p in ps
252
+ p = unwrap (p)
253
+ if iscall (p) && operation (p) == getindex
254
+ p = arguments (p)[1 ]
255
+ end
256
+ symtype (p) <: AbstractArray && Symbolics. shape (p) != Symbolics. Unknown () || continue
257
+ scal = collect (p)
258
+ # all scalarized variables are in `ps`
259
+ any (isequal (p), ps) || all (x -> any (isequal (x), ps), scal) || continue
260
+ (haskey (array_tunables, p) || haskey (other_array_parameters, p)) && continue
261
+
262
+ idx = parameter_index (sys, p)
263
+ idx isa Int && continue
264
+ if idx isa ParameterIndex
265
+ if idx. portion != SciMLStructures. Tunable ()
266
+ continue
267
+ end
268
+ idxs = vec (idx. idx)
269
+ sz = size (idx. idx)
270
+ else
271
+ # idx === nothing
272
+ idxs = map (Base. Fix1 (parameter_index, sys), scal)
273
+ if all (x -> x isa ParameterIndex && x. portion isa SciMLStructures. Tunable, idxs)
274
+ idxs = map (x -> x. idx, idxs)
275
+ end
276
+ if ! all (x -> x isa Int, idxs)
277
+ other_array_parameters[p] = scal
278
+ continue
279
+ end
280
+
281
+ sz = size (idxs)
282
+ if vec (idxs) == idxs[begin ]: idxs[end ]
283
+ idxs = idxs[begin ]: idxs[end ]
284
+ elseif vec (idxs) == idxs[begin ]: - 1 : idxs[end ]
285
+ idxs = idxs[begin ]: - 1 : idxs[end ]
286
+ end
287
+ idxs = vec (idxs)
234
288
end
289
+ array_tunables[p] = (idxs, sz)
235
290
end
236
291
for (k, inds) in array_vars
237
292
if inds == (inds′ = inds[1 ]: inds[end ])
@@ -244,7 +299,13 @@ function wrap_array_vars(sys::AbstractSystem, exprs; dvs = unknowns(sys))
244
299
expr. args,
245
300
[],
246
301
Let (
247
- [k ← :(view ($ (expr. args[1 ]. name), $ v)) for (k, v) in array_vars],
302
+ vcat (
303
+ [k ← :(view ($ (expr. args[uind]. name), $ v)) for (k, v) in array_vars],
304
+ [k ← :(reshape (view ($ (expr. args[uind + 1 ]. name), $ idxs), $ sz))
305
+ for (k, (idxs, sz)) in array_tunables],
306
+ [k ← Code. MakeArray (v, symtype (k))
307
+ for (k, v) in other_array_parameters]
308
+ ),
248
309
expr. body,
249
310
false
250
311
)
@@ -256,7 +317,13 @@ function wrap_array_vars(sys::AbstractSystem, exprs; dvs = unknowns(sys))
256
317
expr. args,
257
318
[],
258
319
Let (
259
- [k ← :(view ($ (expr. args[1 ]. name), $ v)) for (k, v) in array_vars],
320
+ vcat (
321
+ [k ← :(view ($ (expr. args[uind]. name), $ v)) for (k, v) in array_vars],
322
+ [k ← :(reshape (view ($ (expr. args[uind + 1 ]. name), $ idxs), $ sz))
323
+ for (k, (idxs, sz)) in array_tunables],
324
+ [k ← Code. MakeArray (v, symtype (k))
325
+ for (k, v) in other_array_parameters]
326
+ ),
260
327
expr. body,
261
328
false
262
329
)
@@ -267,7 +334,14 @@ function wrap_array_vars(sys::AbstractSystem, exprs; dvs = unknowns(sys))
267
334
expr. args,
268
335
[],
269
336
Let (
270
- [k ← :(view ($ (expr. args[2 ]. name), $ v)) for (k, v) in array_vars],
337
+ vcat (
338
+ [k ← :(view ($ (expr. args[uind + 1 ]. name), $ v))
339
+ for (k, v) in array_vars],
340
+ [k ← :(reshape (view ($ (expr. args[uind + 2 ]. name), $ idxs), $ sz))
341
+ for (k, (idxs, sz)) in array_tunables],
342
+ [k ← Code. MakeArray (v, symtype (k))
343
+ for (k, v) in other_array_parameters]
344
+ ),
271
345
expr. body,
272
346
false
273
347
)
@@ -455,15 +529,18 @@ function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym)
455
529
return unwrap (sym) in 1 : length (parameter_symbols (sys))
456
530
end
457
531
return any (isequal (sym), parameter_symbols (sys)) ||
458
- hasname (sym) && is_parameter (sys, getname (sym))
532
+ hasname (sym) && ! (iscall (sym) && operation (sym) == getindex) &&
533
+ is_parameter (sys, getname (sym))
459
534
end
460
535
461
536
function SymbolicIndexingInterface. is_parameter (sys:: AbstractSystem , sym:: Symbol )
462
537
if has_index_cache (sys) && (ic = get_index_cache (sys)) != = nothing
463
538
return is_parameter (ic, sym)
464
539
end
465
540
466
- named_parameters = [getname (sym) for sym in parameter_symbols (sys) if hasname (sym)]
541
+ named_parameters = [getname (x)
542
+ for x in parameter_symbols (sys)
543
+ if hasname (x) && ! (iscall (x) && operation (x) == getindex)]
467
544
return any (isequal (sym), named_parameters) ||
468
545
count (NAMESPACE_SEPARATOR, string (sym)) == 1 &&
469
546
count (isequal (sym),
@@ -499,7 +576,7 @@ function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym)
499
576
return sym
500
577
end
501
578
idx = findfirst (isequal (sym), parameter_symbols (sys))
502
- if idx === nothing && hasname (sym)
579
+ if idx === nothing && hasname (sym) && ! ( iscall (sym) && operation (sym) == getindex)
503
580
idx = parameter_index (sys, getname (sym))
504
581
end
505
582
return idx
@@ -515,13 +592,16 @@ function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym::Sym
515
592
return idx
516
593
end
517
594
end
518
- idx = findfirst (isequal (sym), getname .(parameter_symbols (sys)))
595
+ pnames = [getname (x)
596
+ for x in parameter_symbols (sys)
597
+ if hasname (x) && ! (iscall (x) && operation (x) == getindex)]
598
+ idx = findfirst (isequal (sym), pnames)
519
599
if idx != = nothing
520
600
return idx
521
601
elseif count (NAMESPACE_SEPARATOR, string (sym)) == 1
522
602
return findfirst (isequal (sym),
523
603
Symbol .(
524
- nameof (sys), NAMESPACE_SEPARATOR_SYMBOL, getname .( parameter_symbols (sys)) ))
604
+ nameof (sys), NAMESPACE_SEPARATOR_SYMBOL, pnames ))
525
605
end
526
606
return nothing
527
607
end
0 commit comments