Skip to content

Commit 0c9593b

Browse files
refactor: improve type-stability of OverrideInitData
1 parent 7df67f7 commit 0c9593b

File tree

2 files changed

+8
-7
lines changed

2 files changed

+8
-7
lines changed

src/initialization.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
44
A collection of all the data required for `OverrideInit`.
55
"""
6-
struct OverrideInitData{IProb, UIProb, IProbMap, IProbPmap, M}
6+
struct OverrideInitData{
7+
IProb, UIProb, IProbMap, IProbPmap, M, OOP <: Union{Val{true}, Val{false}}}
78
"""
89
The `AbstractNonlinearProblem` to solve for initialization.
910
"""
@@ -38,20 +39,20 @@ struct OverrideInitData{IProb, UIProb, IProbMap, IProbPmap, M}
3839
If this flag is `Val{true}`, `update_initializeprob!` is treated as an out-of-place
3940
function which returns the updated `initializeprob`.
4041
"""
41-
is_update_oop::Union{Type{Val{true}}, Type{Val{false}}}
42+
is_update_oop::OOP
4243

4344
function OverrideInitData(initprob::I, update_initprob!::J, initprobmap::K,
44-
initprobpmap::L, metadata::M, is_update_oop) where {I, J, K, L, M}
45+
initprobpmap::L, metadata::M, is_update_oop::O) where {I, J, K, L, M, O}
4546
@assert initprob isa
4647
Union{SCCNonlinearProblem, NonlinearProblem, NonlinearLeastSquaresProblem}
47-
return new{I, J, K, L, M}(
48+
return new{I, J, K, L, M, O}(
4849
initprob, update_initprob!, initprobmap, initprobpmap, metadata, is_update_oop)
4950
end
5051
end
5152

5253
function OverrideInitData(
5354
initprob, update_initprob!, initprobmap, initprobpmap;
54-
metadata = nothing, is_update_oop = Val{false})
55+
metadata = nothing, is_update_oop = Val(false))
5556
OverrideInitData(
5657
initprob, update_initprob!, initprobmap, initprobpmap, metadata, is_update_oop)
5758
end
@@ -251,7 +252,7 @@ function get_initial_values(prob, valp, f, alg::OverrideInit,
251252
initprob = initdata.initializeprob
252253

253254
if initdata.update_initializeprob! !== nothing
254-
if initdata.is_update_oop == Val{true}
255+
if initdata.is_update_oop === Val(true)
255256
initprob = initdata.update_initializeprob!(initprob, valp)
256257
else
257258
initdata.update_initializeprob!(initprob, valp)

test/initialization.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ end
246246
return remake(initprob; p = [valp.u[1]])
247247
end
248248
initdata = SciMLBase.OverrideInitData(initprob, update_initializeprob, initprobmap,
249-
initprobpmap; is_update_oop = Val{true})
249+
initprobpmap; is_update_oop = Val(true))
250250
fn = ODEFunction(rhs2; initialization_data = initdata)
251251
prob = ODEProblem(fn, [2.0, 0.0], (0.0, 1.0), 0.0)
252252
integ = init(prob; initializealg = NoInit())

0 commit comments

Comments
 (0)