Skip to content

Commit be2d756

Browse files
Merge pull request #2466 from SciML/checkinit2
Implement CheckInit
2 parents 35ac8d7 + 540184e commit be2d756

File tree

6 files changed

+136
-5
lines changed

6 files changed

+136
-5
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ PrecompileTools = "1"
133133
Preferences = "1.3"
134134
RecursiveArrayTools = "2.36, 3"
135135
Reexport = "1.0"
136-
SciMLBase = "2.50.4"
136+
SciMLBase = "2.53"
137137
SciMLOperators = "0.3"
138138
SciMLStructures = "1"
139139
SimpleNonlinearSolve = "1"

lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ using DiffEqBase: check_error!, @def, _vec, _reshape
5858

5959
using FastBroadcast: @.., True, False
6060

61-
using SciMLBase: NoInit, _unwrap_val
61+
using SciMLBase: NoInit, CheckInit, _unwrap_val
6262

6363
import SciMLBase: alg_order
6464

lib/OrdinaryDiffEqCore/src/initialize_dae.jl

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,3 +176,83 @@ function _initialize_dae!(integrator, prob::Union{ODEProblem, DAEProblem},
176176
ReturnCode.InitialFailure)
177177
end
178178
end
179+
180+
## CheckInit
181+
struct CheckInitFailureError <: Exception
182+
normresid
183+
abstol
184+
end
185+
186+
function Base.showerror(io::IO, e::CheckInitFailureError)
187+
print(io, "CheckInit specified but initialization not satisifed. normresid = $(e.normresid) > abstol = $(e.abstol)")
188+
end
189+
190+
function _initialize_dae!(integrator, prob::ODEProblem, alg::CheckInit,
191+
isinplace::Val{true})
192+
@unpack p, t, f = integrator
193+
M = integrator.f.mass_matrix
194+
tmp = first(get_tmp_cache(integrator))
195+
u0 = integrator.u
196+
197+
algebraic_vars = [all(iszero, x) for x in eachcol(M)]
198+
algebraic_eqs = [all(iszero, x) for x in eachrow(M)]
199+
(iszero(algebraic_vars) || iszero(algebraic_eqs)) && return
200+
update_coefficients!(M, u0, p, t)
201+
f(tmp, u0, p, t)
202+
tmp .= ArrayInterface.restructure(tmp, algebraic_eqs .* _vec(tmp))
203+
204+
normresid = integrator.opts.internalnorm(tmp, t)
205+
if normresid > integrator.opts.abstol
206+
throw(CheckInitFailureError(normresid, integrator.opts.abstol))
207+
end
208+
end
209+
210+
function _initialize_dae!(integrator, prob::ODEProblem, alg::CheckInit,
211+
isinplace::Val{false})
212+
@unpack p, t, f = integrator
213+
u0 = integrator.u
214+
M = integrator.f.mass_matrix
215+
216+
algebraic_vars = [all(iszero, x) for x in eachcol(M)]
217+
algebraic_eqs = [all(iszero, x) for x in eachrow(M)]
218+
(iszero(algebraic_vars) || iszero(algebraic_eqs)) && return
219+
update_coefficients!(M, u0, p, t)
220+
du = f(u0, p, t)
221+
resid = _vec(du)[algebraic_eqs]
222+
223+
normresid = integrator.opts.internalnorm(resid, t)
224+
if normresid > integrator.opts.abstol
225+
throw(CheckInitFailureError(normresid, integrator.opts.abstol))
226+
end
227+
end
228+
229+
function _initialize_dae!(integrator, prob::DAEProblem,
230+
alg::CheckInit, isinplace::Val{true})
231+
@unpack p, t, f = integrator
232+
u0 = integrator.u
233+
resid = get_tmp_cache(integrator)[2]
234+
235+
f(resid, integrator.du, u0, p, t)
236+
normresid = integrator.opts.internalnorm(resid, t)
237+
if normresid > integrator.opts.abstol
238+
throw(CheckInitFailureError(normresid, integrator.opts.abstol))
239+
end
240+
end
241+
242+
function _initialize_dae!(integrator, prob::DAEProblem,
243+
alg::CheckInit, isinplace::Val{false})
244+
@unpack p, t, f = integrator
245+
u0 = integrator.u
246+
247+
nlequation_oop = u -> begin
248+
f((u - u0) / dt, u, p, t)
249+
end
250+
251+
nlequation = (u, _) -> nlequation_oop(u)
252+
253+
resid = f(integrator.du, u0, p, t)
254+
normresid = integrator.opts.internalnorm(resid, t)
255+
if normresid > integrator.opts.abstol
256+
throw(CheckInitFailureError(normresid, integrator.opts.abstol))
257+
end
258+
end

lib/OrdinaryDiffEqNonlinearSolve/src/OrdinaryDiffEqNonlinearSolve.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ end
3434
using OrdinaryDiffEqCore: resize_nlsolver!, _initialize_dae!,
3535
AbstractNLSolverAlgorithm, AbstractNLSolverCache,
3636
AbstractNLSolver, NewtonAlgorithm, @unpack,
37-
OverrideInit, ShampineCollocationInit, BrownFullBasicInit, _vec,
38-
_unwrap_val, DAEAlgorithm,
37+
OverrideInit, ShampineCollocationInit, BrownFullBasicInit,
38+
_vec, _unwrap_val, DAEAlgorithm,
3939
_reshape, calculate_residuals, calculate_residuals!,
4040
has_special_newton_error, isadaptive,
4141
TryAgain, DIRK, COEFFICIENT_MULTISTEP, NORDSIECK_MULTISTEP, GLM,

test/interface/checkinit_tests.jl

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
using OrdinaryDiffEqBDF, OrdinaryDiffEqRosenbrock, LinearAlgebra, ForwardDiff, Test
2+
using OrdinaryDiffEqCore
3+
4+
function rober(du, u, p, t)
5+
y₁, y₂, y₃ = u
6+
k₁, k₂, k₃ = p
7+
du[1] = -k₁ * y₁ + k₃ * y₂ * y₃
8+
du[2] = k₁ * y₁ - k₃ * y₂ * y₃ - k₂ * y₂^2
9+
du[3] = y₁ + y₂ + y₃ - 1
10+
nothing
11+
end
12+
function rober(u, p, t)
13+
y₁, y₂, y₃ = u
14+
k₁, k₂, k₃ = p
15+
[-k₁ * y₁ + k₃ * y₂ * y₃,
16+
k₁ * y₁ - k₃ * y₂ * y₃ - k₂ * y₂^2,
17+
y₁ + y₂ + y₃ - 1]
18+
end
19+
M = [1.0 0 0
20+
0 1.0 0
21+
0 0 0]
22+
roberf = ODEFunction(rober, mass_matrix = M)
23+
roberf_oop = ODEFunction{false}(rober, mass_matrix = M)
24+
prob_mm = ODEProblem(roberf, [1.0, 0.0, 0.2], (0.0, 1e5), (0.04, 3e7, 1e4))
25+
prob_mm_oop = ODEProblem(roberf_oop, [1.0, 0.0, 0.2], (0.0, 1e5), (0.04, 3e7, 1e4))
26+
27+
@test_throws OrdinaryDiffEqCore.CheckInitFailureError solve(prob_mm, Rodas5P(), reltol = 1e-8, abstol = 1e-8, initializealg = SciMLBase.CheckInit())
28+
@test_throws OrdinaryDiffEqCore.CheckInitFailureError solve(prob_mm_oop, Rodas5P(), reltol = 1e-8, abstol = 1e-8, initializealg = SciMLBase.CheckInit())
29+
30+
f_oop = function (du, u, p, t)
31+
out1 = -0.04u[1] + 1e4 * u[2] * u[3] - du[1]
32+
out2 = +0.04u[1] - 3e7 * u[2]^2 - 1e4 * u[2] * u[3] - du[2]
33+
out3 = u[1] + u[2] + u[3] - 1.0
34+
[out1, out2, out3]
35+
end
36+
37+
f = function (resid, du, u, p, t)
38+
resid[1] = -0.04u[1] + 1e4 * u[2] * u[3] - du[1]
39+
resid[2] = +0.04u[1] - 3e7 * u[2]^2 - 1e4 * u[2] * u[3] - du[2]
40+
resid[3] = u[1] + u[2] + u[3] - 1.0
41+
end
42+
43+
u₀ = [1.0, 0, 0.2]
44+
du₀ = [0.0, 0.0, 0.0]
45+
tspan = (0.0, 100000.0)
46+
differential_vars = [true, true, false]
47+
prob = DAEProblem(f, du₀, u₀, tspan, differential_vars = differential_vars)
48+
prob_oop = DAEProblem(f_oop, du₀, u₀, tspan, differential_vars = differential_vars)
49+
@test_throws OrdinaryDiffEqCore.CheckInitFailureError solve(prob, DFBDF(), reltol = 1e-8, abstol = 1e-8, initializealg = SciMLBase.CheckInit())
50+
@test_throws OrdinaryDiffEqCore.CheckInitFailureError solve(prob_oop, DFBDF(), reltol = 1e-8, abstol = 1e-8, initializealg = SciMLBase.CheckInit())

test/runtests.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,8 @@ end
6565
@time @safetestset "Linear Solver Split ODE Tests" include("interface/linear_solver_split_ode_test.jl")
6666
@time @safetestset "Sparse Diff Tests" include("interface/sparsediff_tests.jl")
6767
@time @safetestset "Enum Tests" include("interface/enums.jl")
68-
@time @safetestset "Enum Tests" include("interface/get_du.jl")
68+
@time @safetestset "CheckInit Tests" include("interface/checkinit_tests.jl")
69+
@time @safetestset "Get du Tests" include("interface/get_du.jl")
6970
@time @safetestset "Mass Matrix Tests" include("interface/mass_matrix_tests.jl")
7071
@time @safetestset "W-Operator prototype tests" include("interface/wprototype_tests.jl")
7172
end

0 commit comments

Comments
 (0)