Skip to content

Commit 8c4132c

Browse files
Merge pull request #1012 from AayushSabharwal/as/oop-init-and-fixes
feat: allow `update_initializeprob!` to be oop, fix `SCCNonlinearProblem` dropping `u0::SVector` type
2 parents a2c205b + 5ba3d32 commit 8c4132c

File tree

7 files changed

+172
-37
lines changed

7 files changed

+172
-37
lines changed

src/SciMLBase.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import ADTypes: ADTypes, AbstractADType
2525
import Accessors: @set, @reset, @delete, @insert
2626
using Moshi.Data: @data
2727
using Moshi.Match: @match
28+
import StaticArraysCore
2829

2930
using Reexport
3031
using SciMLOperators

src/initialization.jl

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
44
A collection of all the data required for `OverrideInit`.
55
"""
6-
struct OverrideInitData{IProb, UIProb, IProbMap, IProbPmap, M}
6+
struct OverrideInitData{
7+
IProb, UIProb, IProbMap, IProbPmap, M, OOP <: Union{Val{true}, Val{false}}}
78
"""
89
The `AbstractNonlinearProblem` to solve for initialization.
910
"""
@@ -34,19 +35,26 @@ struct OverrideInitData{IProb, UIProb, IProbMap, IProbPmap, M}
3435
Additional metadata required by the creator of the initialization.
3536
"""
3637
metadata::M
38+
"""
39+
If this flag is `Val{true}`, `update_initializeprob!` is treated as an out-of-place
40+
function which returns the updated `initializeprob`.
41+
"""
42+
is_update_oop::OOP
3743

3844
function OverrideInitData(initprob::I, update_initprob!::J, initprobmap::K,
39-
initprobpmap::L, metadata::M) where {I, J, K, L, M}
45+
initprobpmap::L, metadata::M, is_update_oop::O) where {I, J, K, L, M, O}
4046
@assert initprob isa
4147
Union{SCCNonlinearProblem, NonlinearProblem, NonlinearLeastSquaresProblem}
42-
return new{I, J, K, L, M}(
43-
initprob, update_initprob!, initprobmap, initprobpmap, metadata)
48+
return new{I, J, K, L, M, O}(
49+
initprob, update_initprob!, initprobmap, initprobpmap, metadata, is_update_oop)
4450
end
4551
end
4652

4753
function OverrideInitData(
48-
initprob, update_initprob!, initprobmap, initprobpmap; metadata = nothing)
49-
OverrideInitData(initprob, update_initprob!, initprobmap, initprobpmap, metadata)
54+
initprob, update_initprob!, initprobmap, initprobpmap;
55+
metadata = nothing, is_update_oop = Val(false))
56+
OverrideInitData(
57+
initprob, update_initprob!, initprobmap, initprobpmap, metadata, is_update_oop)
5058
end
5159

5260
"""
@@ -244,7 +252,11 @@ function get_initial_values(prob, valp, f, alg::OverrideInit,
244252
initprob = initdata.initializeprob
245253

246254
if initdata.update_initializeprob! !== nothing
247-
initdata.update_initializeprob!(initprob, valp)
255+
if initdata.is_update_oop === Val(true)
256+
initprob = initdata.update_initializeprob!(initprob, valp)
257+
else
258+
initdata.update_initializeprob!(initprob, valp)
259+
end
248260
end
249261

250262
if is_trivial_initialization(initdata)

src/problems/nonlinear_problems.jl

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -484,27 +484,37 @@ Note that this example aliases the parameters together for a memory-reduced repr
484484
* `probs`: the collection of problems to solve
485485
* `explictfuns!`: the explicit functions for mutating the parameter set
486486
"""
487-
mutable struct SCCNonlinearProblem{uType, iip, P, E, F <: NonlinearFunction{iip}, Par} <:
487+
mutable struct SCCNonlinearProblem{uType, iip, P, E, F <: NonlinearFunction{iip},
488+
Par, Palias <: Union{Val{true}, Val{false}}} <:
488489
AbstractNonlinearProblem{uType, iip}
489490
probs::P
490491
explicitfuns!::E
491492
# NonlinearFunction with `f = Returns(nothing)`
492493
f::F
493494
p::Par
494-
parameters_alias::Bool
495+
parameters_alias::Palias
495496

496497
function SCCNonlinearProblem{P, E, F, Par}(probs::P, funs::E, f::F, pobj::Par,
497-
alias::Bool) where {P, E, F <: NonlinearFunction, Par}
498+
alias::Palias) where {P, E, F <: NonlinearFunction, Par, Palias}
499+
init = state_values(first(probs))
500+
if ArrayInterface.ismutable(init)
501+
init = similar(init, 0)
502+
else
503+
init = StaticArraysCore.similar_type(init, StaticArraysCore.Size(0))()
504+
end
498505
u0 = mapreduce(
499-
state_values, vcat, probs; init = similar(state_values(first(probs)), 0))
506+
state_values, vcat, probs; init = init)
500507
uType = typeof(u0)
501-
new{uType, false, P, E, F, Par}(probs, funs, f, pobj, alias)
508+
new{uType, false, P, E, F, Par, Palias}(probs, funs, f, pobj, alias)
502509
end
503510
end
504511

505512
function SCCNonlinearProblem(probs, explicitfuns!, parameter_object = nothing,
506-
parameters_alias = false; kwargs...)
513+
parameters_alias::Union{Bool, Val{true}, Val{false}} = Val(false); kwargs...)
507514
f = NonlinearFunction{false}(Returns(nothing); kwargs...)
515+
if parameters_alias isa Bool
516+
parameters_alias = Val(parameters_alias)
517+
end
508518
return SCCNonlinearProblem{typeof(probs), typeof(explicitfuns!),
509519
typeof(f), typeof(parameter_object)}(
510520
probs, explicitfuns!, f, parameter_object, parameters_alias)
@@ -526,8 +536,13 @@ function SymbolicIndexingInterface.parameter_values(prob::SCCNonlinearProblem)
526536
prob.p
527537
end
528538
function SymbolicIndexingInterface.state_values(prob::SCCNonlinearProblem)
529-
mapreduce(
530-
state_values, vcat, prob.probs; init = similar(state_values(first(prob.probs)), 0))
539+
init = state_values(first(prob.probs))
540+
if ArrayInterface.ismutable(init)
541+
init = similar(init, 0)
542+
else
543+
init = StaticArraysCore.similar_type(init, StaticArraysCore.Size(0))()
544+
end
545+
mapreduce(state_values, vcat, prob.probs; init)
531546
end
532547

533548
function SymbolicIndexingInterface.set_state!(prob::SCCNonlinearProblem, val, idx)
@@ -542,7 +557,7 @@ end
542557
function SymbolicIndexingInterface.set_parameter!(prob::SCCNonlinearProblem, val, idx)
543558
if prob.p !== nothing
544559
set_parameter!(prob.p, val, idx)
545-
prob.parameters_alias && return
560+
prob.parameters_alias === Val(true) && return
546561
end
547562
for scc in prob.probs
548563
is_parameter(scc, idx) || continue

src/remake.jl

Lines changed: 63 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -803,6 +803,61 @@ function remake(prob::NonlinearLeastSquaresProblem; f = missing, u0 = missing, p
803803
return prob
804804
end
805805

806+
function scc_update_subproblems(probs::Vector, newu0, newp, parameters_alias)
807+
offset = Ref(0)
808+
return map(probs) do subprob
809+
# N should be inferred if `prob` is type-stable and `subprob.u0 isa StaticArray`
810+
N = length(state_values(subprob))
811+
if ArrayInterface.ismutable(newu0)
812+
_u0 = newu0[(offset[] + 1):(offset[] + N)]
813+
else
814+
_u0 = StaticArraysCore.similar_type(
815+
newu0, StaticArraysCore.Size(N))(newu0[(offset[] + 1):(offset[] + N)])
816+
end
817+
subprob = if parameters_alias === Val(true)
818+
remake(subprob; u0 = _u0, p = newp)
819+
else
820+
remake(subprob; u0 = _u0)
821+
end
822+
offset[] += length(state_values(subprob))
823+
return subprob
824+
end
825+
end
826+
827+
@generated function scc_update_subproblems(probs::Tuple, newu0, newp, parameters_alias)
828+
function get_expr(i::Int)
829+
subprob_name = Symbol(:subprob, i)
830+
quote
831+
$subprob_name = probs[$i]
832+
# N should be inferred if `prob` is type-stable and `subprob.u0 isa StaticArray`
833+
N = length(state_values($subprob_name))
834+
if ArrayInterface.ismutable(newu0)
835+
_u0 = newu0[(offset + 1):(offset + N)]
836+
else
837+
_u0 = StaticArraysCore.similar_type(
838+
newu0, StaticArraysCore.Size(N))(newu0[(offset + 1):(offset + N)])
839+
end
840+
$subprob_name = if parameters_alias === Val(true)
841+
remake($subprob_name; u0 = _u0, p = newp)
842+
else
843+
remake($subprob_name; u0 = _u0)
844+
end
845+
offset += N
846+
end, subprob_name
847+
end
848+
expr = quote
849+
offset = 0
850+
end
851+
subprob_names = []
852+
for i in 1:fieldcount(probs)
853+
subexpr, spname = get_expr(i)
854+
push!(expr.args, subexpr)
855+
push!(subprob_names, spname)
856+
end
857+
push!(expr.args, Expr(:tuple, subprob_names...))
858+
return expr
859+
end
860+
806861
"""
807862
remake(prob::SCCNonlinearProblem; u0 = missing, p = missing, probs = missing,
808863
parameters_alias = prob.parameters_alias, sys = missing, explicitfuns! = missing)
@@ -818,7 +873,10 @@ override the values in `probs`. `sys` is the index provider for the full system.
818873
function remake(prob::SCCNonlinearProblem; u0 = missing, p = missing, probs = missing,
819874
parameters_alias = prob.parameters_alias, f = missing, sys = missing,
820875
interpret_symbolicmap = true, use_defaults = false, explicitfuns! = missing)
821-
if p !== missing && !parameters_alias && probs === missing
876+
if parameters_alias isa Bool
877+
parameters_alias = Val(parameters_alias)
878+
end
879+
if p !== missing && parameters_alias === Val(false) && probs === missing
822880
throw(ArgumentError("`parameters_alias` is `false` for the given `SCCNonlinearProblem`. Please provide the subproblems using the keyword `probs` with the parameters updated appropriately in each."))
823881
end
824882
newu0, newp = updated_u0_p(prob, u0, p; interpret_symbolicmap, use_defaults)
@@ -831,28 +889,14 @@ function remake(prob::SCCNonlinearProblem; u0 = missing, p = missing, probs = mi
831889
if sys === missing
832890
sys = prob.f.sys
833891
end
834-
offset = 0
835-
if u0 !== missing || p !== missing && parameters_alias
836-
probs = map(probs) do subprob
837-
subprob = if parameters_alias
838-
remake(subprob;
839-
u0 = newu0[(offset + 1):(offset + length(state_values(subprob)))],
840-
p = newp)
841-
else
842-
remake(subprob;
843-
u0 = newu0[(offset + 1):(offset + length(state_values(subprob)))])
844-
end
845-
offset += length(state_values(subprob))
846-
return subprob
847-
end
892+
if u0 !== missing || p !== missing && parameters_alias === Val(true)
893+
probs = scc_update_subproblems(probs, newu0, newp, parameters_alias)
848894
end
849895
f = coalesce(f, prob.f)
850896
f = remake(f; sys)
851-
props = getproperties(f)
852-
props = @delete props.f
853897

854-
return SCCNonlinearProblem(
855-
probs, explicitfuns!, newp, parameters_alias; props...)
898+
return SCCNonlinearProblem{typeof(probs), typeof(explicitfuns!), typeof(f), typeof(newp)}(
899+
probs, explicitfuns!, f, newp, parameters_alias)
856900
end
857901

858902
function varmap_has_var(varmap, var)

test/downstream/modelingtoolkit_remake.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,7 @@ end
330330
newp = remake_buffer(sccprob.f.sys, sccprob.p, [σ], [3.0])
331331
sccprob4 = remake(sccprob; parameters_alias = false, p = newp,
332332
probs = [remake(sccprob.probs[1]; p = deepcopy(newp)), sccprob.probs[2]])
333-
@test !sccprob4.parameters_alias
333+
@test sccprob4.parameters_alias === Val(false)
334334
@test sccprob4.p !== sccprob4.probs[1].p
335335
@test sccprob4.p !== sccprob4.probs[2].p
336336
end

test/initialization.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,24 @@ end
240240
@test success
241241
end
242242

243+
@testset "`is_update_oop` flag" begin
244+
initprob = remake(initprob; u0 = ones(2), p = ones(1))
245+
update_initializeprob = function (initprob, valp)
246+
return remake(initprob; p = [valp.u[1]])
247+
end
248+
initdata = SciMLBase.OverrideInitData(initprob, update_initializeprob, initprobmap,
249+
initprobpmap; is_update_oop = Val(true))
250+
fn = ODEFunction(rhs2; initialization_data = initdata)
251+
prob = ODEProblem(fn, [2.0, 0.0], (0.0, 1.0), 0.0)
252+
integ = init(prob; initializealg = NoInit())
253+
u0, p, success = SciMLBase.get_initial_values(
254+
prob, integ, fn, SciMLBase.OverrideInit(), Val(false);
255+
nlsolve_alg = NewtonRaphson(), abstol, reltol)
256+
@test u0 [2.0, 2.0]
257+
@test p 1.0
258+
@test success
259+
end
260+
243261
@testset "Solves without `initializeprobmap`" begin
244262
initdata = SciMLBase.@set initialization_data.initializeprobmap = nothing
245263
fn = ODEFunction(rhs2; initialization_data = initdata)

test/problem_building_test.jl

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using Test, SciMLBase, SymbolicIndexingInterface, Accessors
1+
using Test, SciMLBase, SymbolicIndexingInterface, Accessors, StaticArrays
22

33
function simplependulum!(du, u, p, t)
44
θ = u[1]
@@ -98,3 +98,48 @@ let
9898
@test dprob8.u0 === u₀
9999
@test dprob8.tspan === tspan
100100
end
101+
102+
@testset "SCCNonlinearProblem with static arrays" begin
103+
function f1(u, p)
104+
y = u[1]
105+
x = p[1]
106+
return SA[1 - y^2 - x^2]
107+
end
108+
109+
function f2(u, p)
110+
yt = u[1]
111+
x, xt, y = p
112+
return SA[-2y * yt - 2x * xt]
113+
end
114+
115+
function f3(u, p)
116+
lam = u[1]
117+
x, xt, y, yt = p
118+
return SA[-2xt^2 - 2yt^2 - 2y * (-1 + y * lam) - 2x^2 * lam]
119+
end
120+
121+
explicit1 = Returns(nothing)
122+
function explicit2(p, sols)
123+
p[3] = sols[1].u[1]
124+
end
125+
function explicit3(p, sols)
126+
p[4] = sols[2].u[1]
127+
end
128+
129+
p = [1.0, 0.0, NaN, NaN]
130+
prob1 = NonlinearProblem(f1, SA[1.0], p)
131+
prob2 = NonlinearProblem(f2, SA[0.0], p)
132+
prob3 = NonlinearProblem(f3, SA[1.0], p)
133+
sccprob = SCCNonlinearProblem(
134+
(prob1, prob2, prob3), (explicit1, explicit2, explicit3), p, true)
135+
136+
@test !SciMLBase.isinplace(sccprob)
137+
@test sccprob isa SCCNonlinearProblem{SVector{3, Float64}}
138+
@test state_values(sccprob) isa SVector{3, Float64}
139+
@test sccprob.p === prob1.p === prob2.p === prob3.p
140+
141+
sccprob2 = @inferred remake(sccprob; u0 = SA[2.0, 1.0, 2.0])
142+
@test !SciMLBase.isinplace(sccprob2)
143+
@test sccprob2 isa SCCNonlinearProblem{SVector{3, Float64}}
144+
@test state_values(sccprob2) isa SVector{3, Float64}
145+
end

0 commit comments

Comments
 (0)