Skip to content

Commit 9679b8b

Browse files
Better warning on lack of initialization algorithm and slim Rosenbrock
This fixes the overly large Rosenbrock by giving and erroring mechanism to tell users that adding a nonlinear solver is required for DAEs. Before, we simply put it as a dep to Rosenbrock to fix the issues with the split PR with a note that this would be removed shortly. This does the removal and instead throws a contextual error saying you just need to add that library if you want that functionality. This should make Rosenbrocks a lot leaner for ODEs, and this will make cases where DAE intiialization is required on an explicit method give a much better error. In some cases, an error will go away since the trivial case is now contained in Core, meaning that if a nonlinear solver isn't actually required (analytically solved initialization) it will just handle it without erroring whereas before it would require the nonlinear solver libraries (and then not use them).
1 parent 7a98e43 commit 9679b8b

File tree

5 files changed

+89
-59
lines changed

5 files changed

+89
-59
lines changed

lib/OrdinaryDiffEqCore/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "OrdinaryDiffEqCore"
22
uuid = "bbf590c4-e513-4bbe-9b18-05decba2e5d8"
33
authors = ["ParamThakkar123 <paramthakkar864@gmail.com>"]
4-
version = "1.4.1"
4+
version = "1.5.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

lib/OrdinaryDiffEqCore/src/initialize_dae.jl

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,65 @@ function _initialize_dae!(integrator, prob::DAEProblem,
104104
end
105105
end
106106

107+
## Nonlinear Solver Defaulting
108+
109+
## If an alg is given use it
110+
default_nlsolve(alg, isinplace, u, initprob, autodiff = false) = alg
111+
112+
## If the initialization is trivial just use nothing alg
113+
function default_nlsolve(
114+
::Nothing, isinplace, u::Nothing, ::NonlinearProblem, autodiff = false)
115+
nothing
116+
end
117+
118+
function default_nlsolve(
119+
::Nothing, isinplace, u::Nothing, ::NonlinearLeastSquaresProblem, autodiff = false)
120+
nothing
121+
end
122+
123+
function OrdinaryDiffEqCore.default_nlsolve(::Nothing, isinplace::Val{true}, u, ::NonlinearProblem, autodiff = false)
124+
error("This ODE requires a DAE initialization and thus a nonlinear solve but no nonlinear solve has been loaded. To solve this problem, do `using OrdinaryDiffEqNonlinearSolve` or pass a custom `nlsolve` choice into the `initializealg`.")
125+
end
126+
127+
function OrdinaryDiffEqCore.default_nlsolve(::Nothing, isinplace::Val{true}, u, ::NonlinearLeastSquaresProblem, autodiff = false)
128+
error("This ODE requires a DAE initialization and thus a nonlinear solve but no nonlinear solve has been loaded. To solve this problem, do `using OrdinaryDiffEqNonlinearSolve` or pass a custom `nlsolve` choice into the `initializealg`.")
129+
end
130+
107131
## NoInit
108132

109133
function _initialize_dae!(integrator, prob::Union{ODEProblem, DAEProblem},
110134
alg::NoInit, x::Union{Val{true}, Val{false}})
111135
end
136+
137+
## OverrideInit
138+
139+
function _initialize_dae!(integrator, prob::Union{ODEProblem, DAEProblem},
140+
alg::OverrideInit, isinplace::Union{Val{true}, Val{false}})
141+
initializeprob = prob.f.initializeprob
142+
143+
# If it doesn't have autodiff, assume it comes from symbolic system like ModelingToolkit
144+
# Since then it's the case of not a DAE but has initializeprob
145+
# In which case, it should be differentiable
146+
isAD = if initializeprob.u0 === nothing
147+
AutoForwardDiff
148+
elseif has_autodiff(integrator.alg)
149+
alg_autodiff(integrator.alg) isa AutoForwardDiff
150+
else
151+
true
152+
end
153+
154+
alg = default_nlsolve(alg.nlsolve, isinplace, initializeprob.u0, initializeprob, isAD)
155+
nlsol = solve(initializeprob, alg)
156+
if isinplace === Val{true}()
157+
integrator.u .= prob.f.initializeprobmap(nlsol)
158+
elseif isinplace === Val{false}()
159+
integrator.u = prob.f.initializeprobmap(nlsol)
160+
else
161+
error("Unreachable reached. Report this error.")
162+
end
163+
164+
if nlsol.retcode != ReturnCode.Success
165+
integrator.sol = SciMLBase.solution_new_retcode(integrator.sol,
166+
ReturnCode.InitialFailure)
167+
end
168+
end

lib/OrdinaryDiffEqNonlinearSolve/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "OrdinaryDiffEqNonlinearSolve"
22
uuid = "127b3ac7-2247-4354-8eb6-78cf4e7c58e8"
33
authors = ["Chris Rackauckas <accounts@chrisrackauckas.com>", "Yingbo Ma <mayingbo5@gmail.com>"]
4-
version = "1.1.0"
4+
version = "1.2.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

lib/OrdinaryDiffEqNonlinearSolve/src/initialize_dae.jl

Lines changed: 29 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,61 +1,36 @@
1-
default_nlsolve(alg, isinplace, u, initprob, autodiff = false) = alg
2-
3-
function default_nlsolve(
4-
::Nothing, isinplace, u::Nothing, ::NonlinearProblem, autodiff = false)
5-
nothing
6-
end
7-
function default_nlsolve(::Nothing, isinplace, u, ::NonlinearProblem, autodiff = false)
8-
FastShortcutNonlinearPolyalg(;
9-
autodiff = autodiff ? AutoForwardDiff() : AutoFiniteDiff())
10-
end
11-
function default_nlsolve(::Nothing, isinplace::Val{false}, u::StaticArray,
1+
if isdefined(OrdinaryDiffEqCore, :default_nlsolve)
2+
function OrdinaryDiffEqCore.default_nlsolve(::Nothing, isinplace::Val{true}, u, ::NonlinearProblem, autodiff = false)
3+
FastShortcutNonlinearPolyalg(;
4+
autodiff = autodiff ? AutoForwardDiff() : AutoFiniteDiff())
5+
end
6+
function OrdinaryDiffEqCore.default_nlsolve(
7+
::Nothing, isinplace::Val{true}, u, ::NonlinearLeastSquaresProblem, autodiff = false)
8+
FastShortcutNLLSPolyalg(; autodiff = autodiff ? AutoForwardDiff() : AutoFiniteDiff())
9+
end
10+
function OrdinaryDiffEqCore.default_nlsolve(::Nothing, isinplace::Val{false}, u::StaticArray,
1211
::NonlinearProblem, autodiff = false)
13-
SimpleTrustRegion(autodiff = autodiff ? AutoForwardDiff() : AutoFiniteDiff())
14-
end
15-
16-
function default_nlsolve(
17-
::Nothing, isinplace, u::Nothing, ::NonlinearLeastSquaresProblem, autodiff = false)
18-
nothing
19-
end
20-
function default_nlsolve(
12+
SimpleTrustRegion(autodiff = autodiff ? AutoForwardDiff() : AutoFiniteDiff())
13+
end
14+
function OrdinaryDiffEqCore.default_nlsolve(::Nothing, isinplace::Val{false}, u::StaticArray,
15+
::NonlinearLeastSquaresProblem, autodiff = false)
16+
SimpleGaussNewton(autodiff = autodiff ? AutoForwardDiff() : AutoFiniteDiff())
17+
end
18+
else
19+
function default_nlsolve(::Nothing, isinplace, u, ::NonlinearProblem, autodiff = false)
20+
FastShortcutNonlinearPolyalg(;
21+
autodiff = autodiff ? AutoForwardDiff() : AutoFiniteDiff())
22+
end
23+
function default_nlsolve(
2124
::Nothing, isinplace, u, ::NonlinearLeastSquaresProblem, autodiff = false)
22-
FastShortcutNLLSPolyalg(; autodiff = autodiff ? AutoForwardDiff() : AutoFiniteDiff())
23-
end
24-
function default_nlsolve(::Nothing, isinplace::Val{false}, u::StaticArray,
25-
::NonlinearLeastSquaresProblem, autodiff = false)
26-
SimpleGaussNewton(autodiff = autodiff ? AutoForwardDiff() : AutoFiniteDiff())
27-
end
28-
29-
## OverrideInit
30-
31-
function _initialize_dae!(integrator, prob::Union{ODEProblem, DAEProblem},
32-
alg::OverrideInit, isinplace::Union{Val{true}, Val{false}})
33-
initializeprob = prob.f.initializeprob
34-
35-
# If it doesn't have autodiff, assume it comes from symbolic system like ModelingToolkit
36-
# Since then it's the case of not a DAE but has initializeprob
37-
# In which case, it should be differentiable
38-
isAD = if initializeprob.u0 === nothing
39-
AutoForwardDiff
40-
elseif has_autodiff(integrator.alg)
41-
alg_autodiff(integrator.alg) isa AutoForwardDiff
42-
else
43-
true
25+
FastShortcutNLLSPolyalg(; autodiff = autodiff ? AutoForwardDiff() : AutoFiniteDiff())
4426
end
45-
46-
alg = default_nlsolve(alg.nlsolve, isinplace, initializeprob.u0, initializeprob, isAD)
47-
nlsol = solve(initializeprob, alg)
48-
if isinplace === Val{true}()
49-
integrator.u .= prob.f.initializeprobmap(nlsol)
50-
elseif isinplace === Val{false}()
51-
integrator.u = prob.f.initializeprobmap(nlsol)
52-
else
53-
error("Unreachable reached. Report this error.")
27+
function default_nlsolve(::Nothing, isinplace::Val{false}, u::StaticArray,
28+
::NonlinearProblem, autodiff = false)
29+
SimpleTrustRegion(autodiff = autodiff ? AutoForwardDiff() : AutoFiniteDiff())
5430
end
55-
56-
if nlsol.retcode != ReturnCode.Success
57-
integrator.sol = SciMLBase.solution_new_retcode(integrator.sol,
58-
ReturnCode.InitialFailure)
31+
function default_nlsolve(::Nothing, isinplace::Val{false}, u::StaticArray,
32+
::NonlinearLeastSquaresProblem, autodiff = false)
33+
SimpleGaussNewton(autodiff = autodiff ? AutoForwardDiff() : AutoFiniteDiff())
5934
end
6035
end
6136

lib/OrdinaryDiffEqRosenbrock/Project.toml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "OrdinaryDiffEqRosenbrock"
22
uuid = "43230ef6-c299-4910-a778-202eb28ce4ce"
33
authors = ["ParamThakkar123 <paramthakkar864@gmail.com>"]
4-
version = "1.1.1"
4+
version = "1.2.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -15,7 +15,6 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1515
MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"
1616
OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8"
1717
OrdinaryDiffEqDifferentiation = "4302a76b-040a-498a-8c04-15b101fed76b"
18-
OrdinaryDiffEqNonlinearSolve = "127b3ac7-2247-4354-8eb6-78cf4e7c58e8"
1918
Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588"
2019
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
2120
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
@@ -37,7 +36,6 @@ MuladdMacro = "0.2.4"
3736
ODEProblemLibrary = "0.1.8"
3837
OrdinaryDiffEqCore = "1.1"
3938
OrdinaryDiffEqDifferentiation = "<0.0.1, 1"
40-
OrdinaryDiffEqNonlinearSolve = "<0.0.1, 1"
4139
Polyester = "0.7.16"
4240
PrecompileTools = "1.2.1"
4341
Preferences = "1.4.3"

0 commit comments

Comments
 (0)