Skip to content

Commit 258063b

Browse files
Merge pull request #857 from AayushSabharwal/as/get-nlsolve
feat: add fields to `OverrideInit`, better `nlsolve_alg` handling
2 parents 169d419 + 7f93fb2 commit 258063b

File tree

8 files changed

+241
-62
lines changed

8 files changed

+241
-62
lines changed

.github/workflows/Downstream.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,4 +85,4 @@ jobs:
8585
with:
8686
file: lcov.info
8787
token: ${{ secrets.CODECOV_TOKEN }}
88-
fail_ci_if_error: true
88+
fail_ci_if_error: false

src/ODE_nlsolve.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,4 +43,3 @@ struct ODE_NLProbData{NLProb, UNLProb, NLProbMap, NLProbPmap}
4343
"""
4444
nlprobpmap::NLProbPmap
4545
end
46-

src/SciMLBase.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import CommonSolve: solve, init, step!, solve!
2121
import FunctionWrappersWrappers
2222
import RuntimeGeneratedFunctions
2323
import EnumX
24-
import ADTypes: AbstractADType
24+
import ADTypes: ADTypes, AbstractADType
2525
import Accessors: @set, @reset
2626
using Expronicon.ADT: @match
2727

@@ -351,7 +351,16 @@ struct CheckInit <: DAEInitializationAlgorithm end
351351
"""
352352
$(TYPEDEF)
353353
"""
354-
struct OverrideInit <: DAEInitializationAlgorithm end
354+
struct OverrideInit{T1, T2, F} <: DAEInitializationAlgorithm
355+
abstol::T1
356+
reltol::T2
357+
nlsolve::F
358+
end
359+
360+
function OverrideInit(; abstol = nothing, reltol = nothing, nlsolve = nothing)
361+
OverrideInit(abstol, reltol, nlsolve)
362+
end
363+
OverrideInit(abstol) = OverrideInit(; abstol = abstol, nlsolve = nothing)
355364

356365
# PDE Discretizations
357366

src/initialization.jl

Lines changed: 64 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -68,17 +68,26 @@ function Base.showerror(io::IO, e::OverrideInitMissingAlgorithm)
6868
"OverrideInit specified but no NonlinearSolve.jl algorithm provided. Provide an algorithm via the `nlsolve_alg` keyword argument to `get_initial_values`.")
6969
end
7070

71+
struct OverrideInitNoTolerance <: Exception
72+
tolerance::Symbol
73+
end
74+
75+
function Base.showerror(io::IO, e::OverrideInitNoTolerance)
76+
print(io,
77+
"Tolerances were not provided to `OverrideInit`. `$(e.tolerance)` must be provided as a keyword argument to `get_initial_values` or as a keyword argument to the `OverrideInit` constructor.")
78+
end
79+
7180
"""
72-
Utility function to evaluate the RHS of the ODE, using the integrator's `tmp_cache` if
81+
Utility function to evaluate the RHS, using the integrator's `tmp_cache` if
7382
it is in-place or simply calling the function if not.
7483
"""
75-
function _evaluate_f_ode(integrator, f, isinplace::Val{true}, args...)
84+
function _evaluate_f(integrator, f, isinplace::Val{true}, args...)
7685
tmp = first(get_tmp_cache(integrator))
7786
f(tmp, args...)
7887
return tmp
7988
end
8089

81-
function _evaluate_f_ode(integrator, f, isinplace::Val{false}, args...)
90+
function _evaluate_f(integrator, f, isinplace::Val{false}, args...)
8291
return f(args...)
8392
end
8493

@@ -98,53 +107,49 @@ _vec(v::AbstractVector) = v
98107
99108
Check if the algebraic constraints are satisfied, and error if they aren't. Returns
100109
the `u0` and `p` as-is, and is always successful if it returns. Valid only for
101-
`ODEProblem` and `DAEProblem`. Requires a `DEIntegrator` as its second argument.
110+
`AbstractDEProblem` and `AbstractDAEProblem`. Requires a `DEIntegrator` as its second argument.
111+
112+
Keyword arguments:
113+
- `abstol`: The absolute value below which the norm of the residual of algebraic equations
114+
should lie. The norm function used is `integrator.opts.internalnorm` if present, and
115+
`LinearAlgebra.norm` if not.
102116
"""
103-
function get_initial_values(prob::AbstractODEProblem, integrator, f, alg::CheckInit,
104-
isinplace::Union{Val{true}, Val{false}}; kwargs...)
117+
function get_initial_values(
118+
prob::AbstractDEProblem, integrator::DEIntegrator, f, alg::CheckInit,
119+
isinplace::Union{Val{true}, Val{false}}; abstol, kwargs...)
105120
u0 = state_values(integrator)
106121
p = parameter_values(integrator)
107122
t = current_time(integrator)
108123
M = f.mass_matrix
109124

110125
algebraic_vars = [all(iszero, x) for x in eachcol(M)]
111126
algebraic_eqs = [all(iszero, x) for x in eachrow(M)]
112-
(iszero(algebraic_vars) || iszero(algebraic_eqs)) && return
127+
(iszero(algebraic_vars) || iszero(algebraic_eqs)) && return u0, p, true
113128
update_coefficients!(M, u0, p, t)
114-
tmp = _evaluate_f_ode(integrator, f, isinplace, u0, p, t)
129+
tmp = _evaluate_f(integrator, f, isinplace, u0, p, t)
115130
tmp .= ArrayInterface.restructure(tmp, algebraic_eqs .* _vec(tmp))
116131

117-
normresid = integrator.opts.internalnorm(tmp, t)
118-
if normresid > integrator.opts.abstol
119-
throw(CheckInitFailureError(normresid, integrator.opts.abstol))
132+
normresid = isdefined(integrator.opts, :internalnorm) ?
133+
integrator.opts.internalnorm(tmp, t) : norm(tmp)
134+
if normresid > abstol
135+
throw(CheckInitFailureError(normresid, abstol))
120136
end
121137
return u0, p, true
122138
end
123139

124-
"""
125-
Utility function to evaluate the RHS of the DAE, using the integrator's `tmp_cache` if
126-
it is in-place or simply calling the function if not.
127-
"""
128-
function _evaluate_f_dae(integrator, f, isinplace::Val{true}, args...)
129-
tmp = get_tmp_cache(integrator)[2]
130-
f(tmp, args...)
131-
return tmp
132-
end
133-
134-
function _evaluate_f_dae(integrator, f, isinplace::Val{false}, args...)
135-
return f(args...)
136-
end
137-
138-
function get_initial_values(prob::AbstractDAEProblem, integrator, f, alg::CheckInit,
139-
isinplace::Union{Val{true}, Val{false}}; kwargs...)
140+
function get_initial_values(
141+
prob::AbstractDAEProblem, integrator::DEIntegrator, f, alg::CheckInit,
142+
isinplace::Union{Val{true}, Val{false}}; abstol, kwargs...)
140143
u0 = state_values(integrator)
141144
p = parameter_values(integrator)
142145
t = current_time(integrator)
143146

144-
resid = _evaluate_f_dae(integrator, f, isinplace, integrator.du, u0, p, t)
145-
normresid = integrator.opts.internalnorm(resid, t)
146-
if normresid > integrator.opts.abstol
147-
throw(CheckInitFailureError(normresid, integrator.opts.abstol))
147+
resid = _evaluate_f(integrator, f, isinplace, integrator.du, u0, p, t)
148+
normresid = isdefined(integrator.opts, :internalnorm) ?
149+
integrator.opts.internalnorm(resid, t) : norm(resid)
150+
151+
if normresid > abstol
152+
throw(CheckInitFailureError(normresid, abstol))
148153
end
149154
return u0, p, true
150155
end
@@ -155,12 +160,19 @@ end
155160
Solve a `NonlinearProblem`/`NonlinearLeastSquaresProblem` to obtain the initial `u0` and
156161
`p`. Requires that `f` have the field `initialization_data` which is an `OverrideInitData`.
157162
If the field is absent or the value is `nothing`, return `u0` and `p` successfully as-is.
158-
The NonlinearSolve.jl algorithm to use must be specified through the `nlsolve_alg` keyword
159-
argument, failing which this function will throw an error. The success value returned
160-
depends on the success of the nonlinear solve.
163+
164+
The success value returned depends on the success of the nonlinear solve.
165+
166+
Keyword arguments:
167+
- `nlsolve_alg`: The NonlinearSolve.jl algorithm to use. If not provided, this function will
168+
throw an error.
169+
- `abstol`, `reltol`: The `abstol` (`reltol`) to use for the nonlinear solve. The value
170+
provided to the `OverrideInit` constructor takes priority over this keyword argument.
171+
If the former is `nothing`, this keyword argument will be used. If it is also not provided,
172+
an error will be thrown.
161173
"""
162174
function get_initial_values(prob, valp, f, alg::OverrideInit,
163-
isinplace::Union{Val{true}, Val{false}}; nlsolve_alg = nothing, kwargs...)
175+
iip::Union{Val{true}, Val{false}}; nlsolve_alg = nothing, abstol = nothing, reltol = nothing, kwargs...)
164176
u0 = state_values(valp)
165177
p = parameter_values(valp)
166178

@@ -171,15 +183,30 @@ function get_initial_values(prob, valp, f, alg::OverrideInit,
171183
initdata::OverrideInitData = f.initialization_data
172184
initprob = initdata.initializeprob
173185

174-
if nlsolve_alg === nothing
186+
nlsolve_alg = something(nlsolve_alg, alg.nlsolve, Some(nothing))
187+
if nlsolve_alg === nothing && state_values(initprob) !== nothing
175188
throw(OverrideInitMissingAlgorithm())
176189
end
177190

178191
if initdata.update_initializeprob! !== nothing
179192
initdata.update_initializeprob!(initprob, valp)
180193
end
181194

182-
nlsol = solve(initprob, nlsolve_alg)
195+
if alg.abstol !== nothing
196+
_abstol = alg.abstol
197+
elseif abstol !== nothing
198+
_abstol = abstol
199+
else
200+
throw(OverrideInitNoTolerance(:abstol))
201+
end
202+
if alg.reltol !== nothing
203+
_reltol = alg.reltol
204+
elseif reltol !== nothing
205+
_reltol = reltol
206+
else
207+
throw(OverrideInitNoTolerance(:reltol))
208+
end
209+
nlsol = solve(initprob, nlsolve_alg; abstol = _abstol, reltol = _reltol)
183210

184211
u0 = initdata.initializeprobmap(nlsol)
185212
if initdata.initializeprobpmap !== nothing

src/scimlfunctions.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -401,7 +401,8 @@ numerically-defined functions.
401401
"""
402402
struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, WP, TPJ,
403403
O, TCV,
404-
SYS, ID<:Union{Nothing, OverrideInitData}, NLP<:Union{Nothing, ODE_NLProbData}} <: AbstractODEFunction{iip}
404+
SYS, ID <: Union{Nothing, OverrideInitData}, NLP <: Union{Nothing, ODE_NLProbData}} <:
405+
AbstractODEFunction{iip}
405406
f::F
406407
mass_matrix::TMM
407408
analytic::Ta
@@ -522,7 +523,8 @@ information on generating the SplitFunction from this symbolic engine.
522523
"""
523524
struct SplitFunction{
524525
iip, specialize, F1, F2, TMM, C, Ta, Tt, TJ, JVP, VJP, JP, WP, SP, TW, TWt,
525-
TPJ, O, TCV, SYS, ID<:Union{Nothing, OverrideInitData}, NLP<:Union{Nothing, ODE_NLProbData}} <: AbstractODEFunction{iip}
526+
TPJ, O, TCV, SYS, ID <: Union{Nothing, OverrideInitData},
527+
NLP <: Union{Nothing, ODE_NLProbData}} <: AbstractODEFunction{iip}
526528
f1::F1
527529
f2::F2
528530
mass_matrix::TMM
@@ -2442,7 +2444,7 @@ function ODEFunction{iip, specialize}(f;
24422444
initializeprobpmap = __has_initializeprobpmap(f) ? f.initializeprobpmap : nothing,
24432445
initialization_data = __has_initialization_data(f) ? f.initialization_data :
24442446
nothing,
2445-
nlprob_data = __has_nlprob_data(f) ? f.nlprob_data : nothing,
2447+
nlprob_data = __has_nlprob_data(f) ? f.nlprob_data : nothing
24462448
) where {iip,
24472449
specialize
24482450
}
@@ -2500,7 +2502,8 @@ function ODEFunction{iip, specialize}(f;
25002502
typeof(sparsity), Any, Any, typeof(W_prototype), Any,
25012503
Any,
25022504
typeof(_colorvec),
2503-
typeof(sys), Union{Nothing, OverrideInitData}, Union{Nothing, ODE_NLProbData}}(_f, mass_matrix, analytic, tgrad, jac,
2505+
typeof(sys), Union{Nothing, OverrideInitData}, Union{Nothing, ODE_NLProbData}}(
2506+
_f, mass_matrix, analytic, tgrad, jac,
25042507
jvp, vjp, jac_prototype, sparsity, Wfact,
25052508
Wfact_t, W_prototype, paramjac,
25062509
observed, _colorvec, sys, initdata, nlprob_data)
@@ -2770,7 +2773,8 @@ function SplitFunction{iip, specialize}(f1, f2;
27702773
if specialize === NoSpecialize
27712774
SplitFunction{iip, specialize, Any, Any, Any, Any, Any, Any, Any, Any, Any,
27722775
Any, Any, Any, Any, Any, Any, Any,
2773-
Any, Any, Union{Nothing, OverrideInitData}, Union{Nothing, ODE_NLProbData}}(f1, f2, mass_matrix, _func_cache,
2776+
Any, Any, Union{Nothing, OverrideInitData}, Union{Nothing, ODE_NLProbData}}(
2777+
f1, f2, mass_matrix, _func_cache,
27742778
analytic,
27752779
tgrad, jac, jvp, vjp, jac_prototype, W_prototype,
27762780
sparsity, Wfact, Wfact_t, paramjac,

test/downstream/initialization.jl

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
using OrdinaryDiffEq, Sundials, SciMLBase, Test
2+
3+
@testset "CheckInit" begin
4+
abstol = 1e-10
5+
@testset "Sundials + ODEProblem" begin
6+
function rhs(u, p, t)
7+
return [u[1] * t, u[1]^2 - u[2]^2]
8+
end
9+
function rhs!(du, u, p, t)
10+
du[1] = u[1] * t
11+
du[2] = u[1]^2 - u[2]^2
12+
end
13+
14+
oopfn = ODEFunction{false}(rhs, mass_matrix = [1 0; 0 0])
15+
iipfn = ODEFunction{true}(rhs!, mass_matrix = [1 0; 0 0])
16+
17+
@testset "Inplace = $(SciMLBase.isinplace(f))" for f in [oopfn, iipfn]
18+
prob = ODEProblem(f, [1.0, 1.0], (0.0, 1.0))
19+
integ = init(prob, Sundials.ARKODE())
20+
u0, _, success = SciMLBase.get_initial_values(
21+
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)); abstol)
22+
@test success
23+
@test u0 == prob.u0
24+
25+
integ.u[2] = 2.0
26+
@test_throws SciMLBase.CheckInitFailureError SciMLBase.get_initial_values(
27+
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)); abstol)
28+
end
29+
end
30+
31+
@testset "Sundials + DAEProblem" begin
32+
function daerhs(du, u, p, t)
33+
return [du[1] - u[1] * t - p, u[1]^2 - u[2]^2]
34+
end
35+
function daerhs!(resid, du, u, p, t)
36+
resid[1] = du[1] - u[1] * t - p
37+
resid[2] = u[1]^2 - u[2]^2
38+
end
39+
40+
oopfn = DAEFunction{false}(daerhs)
41+
iipfn = DAEFunction{true}(daerhs!)
42+
43+
@testset "Inplace = $(SciMLBase.isinplace(f))" for f in [oopfn, iipfn]
44+
prob = DAEProblem(f, [1.0, 0.0], [1.0, 1.0], (0.0, 1.0), 1.0)
45+
integ = init(prob, Sundials.IDA())
46+
u0, _, success = SciMLBase.get_initial_values(
47+
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)); abstol)
48+
@test success
49+
@test u0 == prob.u0
50+
51+
integ.u[2] = 2.0
52+
@test_throws SciMLBase.CheckInitFailureError SciMLBase.get_initial_values(
53+
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)); abstol)
54+
55+
integ.u[2] = 1.0
56+
integ.du[1] = 2.0
57+
@test_throws SciMLBase.CheckInitFailureError SciMLBase.get_initial_values(
58+
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)); abstol)
59+
end
60+
end
61+
end

0 commit comments

Comments
 (0)