@@ -241,27 +241,52 @@ function wrap_array_vars(
241
241
end
242
242
# tunables are scalarized and concatenated, so we need to have assignments
243
243
# for the non-scalarized versions
244
- array_tunables = Dict {Any, AbstractArray{Int}} ()
245
- for p in ps
246
- idx = parameter_index (sys, p)
247
- idx isa ParameterIndex || continue
248
- idx. portion isa SciMLStructures. Tunable || continue
249
- idx. idx isa AbstractArray || continue
250
- array_tunables[p] = idx. idx
251
- end
244
+ array_tunables = Dict{Any, Tuple{AbstractArray{Int}, Tuple{Vararg{Int}}}}()
252
245
# Other parameters may be scalarized arrays but used in the vector form
253
- other_array_parameters = Assignment[]
246
+ other_array_parameters = Dict {Any, Any} ()
247
+
248
+ if ps isa Tuple && eltype (ps) <: AbstractArray
249
+ ps = Iterators. flatten (ps)
250
+ end
254
251
for p in ps
252
+ p = unwrap (p)
253
+ if iscall (p) && operation (p) == getindex
254
+ p = arguments (p)[1 ]
255
+ end
256
+ symtype (p) <: AbstractArray && Symbolics. shape (p) != Symbolics. Unknown () || continue
257
+ scal = collect (p)
258
+ # all scalarized variables are in `ps`
259
+ any (isequal (p), ps) || all (x -> any (isequal (x), ps), scal) || continue
260
+ (haskey (array_tunables, p) || haskey (other_array_parameters, p)) && continue
261
+
255
262
idx = parameter_index (sys, p)
256
- if Symbolics. isarraysymbolic (p)
257
- idx === nothing || continue
258
- push! (other_array_parameters, p ← collect (p))
259
- elseif iscall (p) && operation (p) == getindex
260
- idx === nothing && continue
261
- # all of the scalarized variables are in `ps`
262
- all (x -> any (isequal (x), ps), collect (p))|| continue
263
- push! (other_array_parameters, p ← collect (p))
263
+ idx isa Int && continue
264
+ if idx isa ParameterIndex
265
+ if idx. portion != SciMLStructures. Tunable ()
266
+ continue
267
+ end
268
+ idxs = vec (idx. idx)
269
+ sz = size (idx. idx)
270
+ else
271
+ # idx === nothing
272
+ idxs = map (Base. Fix1 (parameter_index, sys), scal)
273
+ if all (x -> x isa ParameterIndex && x. portion isa SciMLStructures. Tunable, idxs)
274
+ idxs = map (x -> x. idx, idxs)
275
+ end
276
+ if ! all (x -> x isa Int, idxs)
277
+ other_array_parameters[p] = scal
278
+ continue
279
+ end
280
+
281
+ sz = size (idxs)
282
+ if vec (idxs) == idxs[begin ]: idxs[end ]
283
+ idxs = idxs[begin ]: idxs[end ]
284
+ elseif vec (idxs) == idxs[begin ]: - 1 : idxs[end ]
285
+ idxs = idxs[begin ]: - 1 : idxs[end ]
286
+ end
287
+ idxs = vec (idxs)
264
288
end
289
+ array_tunables[p] = (idxs, sz)
265
290
end
266
291
for (k, inds) in array_vars
267
292
if inds == (inds′ = inds[1 ]: inds[end ])
@@ -276,9 +301,10 @@ function wrap_array_vars(
276
301
Let (
277
302
vcat (
278
303
[k ← :(view ($ (expr. args[uind]. name), $ v)) for (k, v) in array_vars],
279
- [k ← :(view ($ (expr. args[uind + 1 ]. name), $ v))
280
- for (k, v) in array_tunables],
281
- other_array_parameters
304
+ [k ← :(reshape (view ($ (expr. args[uind + 1 ]. name), $ idxs), $ sz))
305
+ for (k, (idxs, sz)) in array_tunables],
306
+ [k ← Code. MakeArray (v, symtype (k))
307
+ for (k, v) in other_array_parameters]
282
308
),
283
309
expr. body,
284
310
false
@@ -293,8 +319,10 @@ function wrap_array_vars(
293
319
Let (
294
320
vcat (
295
321
[k ← :(view ($ (expr. args[uind]. name), $ v)) for (k, v) in array_vars],
296
- [k ← :(view ($ (expr. args[uind + 1 ]. name), $ v))
297
- for (k, v) in array_tunables]
322
+ [k ← :(reshape (view ($ (expr. args[uind + 1 ]. name), $ idxs), $ sz))
323
+ for (k, (idxs, sz)) in array_tunables],
324
+ [k ← Code. MakeArray (v, symtype (k))
325
+ for (k, v) in other_array_parameters]
298
326
),
299
327
expr. body,
300
328
false
@@ -309,8 +337,10 @@ function wrap_array_vars(
309
337
vcat (
310
338
[k ← :(view ($ (expr. args[uind + 1 ]. name), $ v))
311
339
for (k, v) in array_vars],
312
- [k ← :(view ($ (expr. args[uind + 2 ]. name), $ v))
313
- for (k, v) in array_tunables]
340
+ [k ← :(reshape (view ($ (expr. args[uind + 2 ]. name), $ idxs), $ sz))
341
+ for (k, (idxs, sz)) in array_tunables],
342
+ [k ← Code. MakeArray (v, symtype (k))
343
+ for (k, v) in other_array_parameters]
314
344
),
315
345
expr. body,
316
346
false
@@ -499,15 +529,18 @@ function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym)
499
529
return unwrap (sym) in 1 : length (parameter_symbols (sys))
500
530
end
501
531
return any (isequal (sym), parameter_symbols (sys)) ||
502
- hasname (sym) && is_parameter (sys, getname (sym))
532
+ hasname (sym) && ! (iscall (sym) && operation (sym) == getindex) &&
533
+ is_parameter (sys, getname (sym))
503
534
end
504
535
505
536
function SymbolicIndexingInterface. is_parameter (sys:: AbstractSystem , sym:: Symbol )
506
537
if has_index_cache (sys) && (ic = get_index_cache (sys)) != = nothing
507
538
return is_parameter (ic, sym)
508
539
end
509
540
510
- named_parameters = [getname (sym) for sym in parameter_symbols (sys) if hasname (sym)]
541
+ named_parameters = [getname (x)
542
+ for x in parameter_symbols (sys)
543
+ if hasname (x) && ! (iscall (x) && operation (x) == getindex)]
511
544
return any (isequal (sym), named_parameters) ||
512
545
count (NAMESPACE_SEPARATOR, string (sym)) == 1 &&
513
546
count (isequal (sym),
@@ -543,7 +576,7 @@ function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym)
543
576
return sym
544
577
end
545
578
idx = findfirst (isequal (sym), parameter_symbols (sys))
546
- if idx === nothing && hasname (sym)
579
+ if idx === nothing && hasname (sym) && ! ( iscall (sym) && operation (sym) == getindex)
547
580
idx = parameter_index (sys, getname (sym))
548
581
end
549
582
return idx
@@ -559,13 +592,16 @@ function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym::Sym
559
592
return idx
560
593
end
561
594
end
562
- idx = findfirst (isequal (sym), getname .(parameter_symbols (sys)))
595
+ pnames = [getname (x)
596
+ for x in parameter_symbols (sys)
597
+ if hasname (x) && ! (iscall (x) && operation (x) == getindex)]
598
+ idx = findfirst (isequal (sym), pnames)
563
599
if idx != = nothing
564
600
return idx
565
601
elseif count (NAMESPACE_SEPARATOR, string (sym)) == 1
566
602
return findfirst (isequal (sym),
567
603
Symbol .(
568
- nameof (sys), NAMESPACE_SEPARATOR_SYMBOL, getname .( parameter_symbols (sys)) ))
604
+ nameof (sys), NAMESPACE_SEPARATOR_SYMBOL, pnames ))
569
605
end
570
606
return nothing
571
607
end
0 commit comments