Skip to content

Commit eccc470

Browse files
feat: allow update_initializeprob! to be out-of-place in OverrideInitData
1 parent a2c205b commit eccc470

File tree

2 files changed

+34
-5
lines changed

2 files changed

+34
-5
lines changed

src/initialization.jl

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,19 +34,26 @@ struct OverrideInitData{IProb, UIProb, IProbMap, IProbPmap, M}
3434
Additional metadata required by the creator of the initialization.
3535
"""
3636
metadata::M
37+
"""
38+
If this flag is `Val{true}`, `update_initializeprob!` is treated as an out-of-place
39+
function which returns the updated `initializeprob`.
40+
"""
41+
is_update_oop::Union{Type{Val{true}}, Type{Val{false}}}
3742

3843
function OverrideInitData(initprob::I, update_initprob!::J, initprobmap::K,
39-
initprobpmap::L, metadata::M) where {I, J, K, L, M}
44+
initprobpmap::L, metadata::M, is_update_oop) where {I, J, K, L, M}
4045
@assert initprob isa
4146
Union{SCCNonlinearProblem, NonlinearProblem, NonlinearLeastSquaresProblem}
4247
return new{I, J, K, L, M}(
43-
initprob, update_initprob!, initprobmap, initprobpmap, metadata)
48+
initprob, update_initprob!, initprobmap, initprobpmap, metadata, is_update_oop)
4449
end
4550
end
4651

4752
function OverrideInitData(
48-
initprob, update_initprob!, initprobmap, initprobpmap; metadata = nothing)
49-
OverrideInitData(initprob, update_initprob!, initprobmap, initprobpmap, metadata)
53+
initprob, update_initprob!, initprobmap, initprobpmap;
54+
metadata = nothing, is_update_oop = Val{false})
55+
OverrideInitData(
56+
initprob, update_initprob!, initprobmap, initprobpmap, metadata, is_update_oop)
5057
end
5158

5259
"""
@@ -244,7 +251,11 @@ function get_initial_values(prob, valp, f, alg::OverrideInit,
244251
initprob = initdata.initializeprob
245252

246253
if initdata.update_initializeprob! !== nothing
247-
initdata.update_initializeprob!(initprob, valp)
254+
if initdata.is_update_oop == Val{true}
255+
initprob = initdata.update_initializeprob!(initprob, valp)
256+
else
257+
initdata.update_initializeprob!(initprob, valp)
258+
end
248259
end
249260

250261
if is_trivial_initialization(initdata)

test/initialization.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,24 @@ end
240240
@test success
241241
end
242242

243+
@testset "`is_update_oop` flag" begin
244+
initprob = remake(initprob; u0 = ones(2), p = ones(1))
245+
update_initializeprob = function (initprob, valp)
246+
return remake(initprob; p = [valp.u[1]])
247+
end
248+
initdata = SciMLBase.OverrideInitData(initprob, update_initializeprob, initprobmap,
249+
initprobpmap; is_update_oop = Val{true})
250+
fn = ODEFunction(rhs2; initialization_data = initdata)
251+
prob = ODEProblem(fn, [2.0, 0.0], (0.0, 1.0), 0.0)
252+
integ = init(prob; initializealg = NoInit())
253+
u0, p, success = SciMLBase.get_initial_values(
254+
prob, integ, fn, SciMLBase.OverrideInit(), Val(false);
255+
nlsolve_alg = NewtonRaphson(), abstol, reltol)
256+
@test u0 [2.0, 2.0]
257+
@test p 1.0
258+
@test success
259+
end
260+
243261
@testset "Solves without `initializeprobmap`" begin
244262
initdata = SciMLBase.@set initialization_data.initializeprobmap = nothing
245263
fn = ODEFunction(rhs2; initialization_data = initdata)

0 commit comments

Comments
 (0)