Skip to content

Commit da8d130

Browse files
Merge pull request #2640 from vyudu/rename
feat: allow solving NLLS in solve, rename SimpleIDSolve to IDSolve
2 parents 532231b + 2341d3c commit da8d130

File tree

6 files changed

+85
-33
lines changed

6 files changed

+85
-33
lines changed

lib/ImplicitDiscreteSolve/src/ImplicitDiscreteSolve.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,16 @@ using Reexport
1010
@reexport using DiffEqBase
1111

1212
"""
13-
SimpleIDSolve()
13+
IDSolve()
1414
1515
Simple solver for `ImplicitDiscreteSystems`. Uses `SimpleNewtonRaphson` to solve for the next state at every timestep.
1616
"""
17-
struct SimpleIDSolve <: OrdinaryDiffEqAlgorithm end
17+
struct IDSolve <: OrdinaryDiffEqAlgorithm end
1818

1919
include("cache.jl")
2020
include("solve.jl")
2121
include("alg_utils.jl")
2222

23-
export SimpleIDSolve
23+
export IDSolve
2424

2525
end
Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
1-
function SciMLBase.isautodifferentiable(alg::SimpleIDSolve)
1+
function SciMLBase.isautodifferentiable(alg::IDSolve)
22
true
33
end
4-
function SciMLBase.allows_arbitrary_number_types(alg::SimpleIDSolve)
4+
function SciMLBase.allows_arbitrary_number_types(alg::IDSolve)
55
true
66
end
7-
function SciMLBase.allowscomplex(alg::SimpleIDSolve)
7+
function SciMLBase.allowscomplex(alg::IDSolve)
88
true
99
end
1010

11-
SciMLBase.isdiscrete(alg::SimpleIDSolve) = true
11+
SciMLBase.isdiscrete(alg::IDSolve) = true
1212

13-
isfsal(alg::SimpleIDSolve) = false
14-
alg_order(alg::SimpleIDSolve) = 0
15-
beta2_default(alg::SimpleIDSolve) = 0
16-
beta1_default(alg::SimpleIDSolve, beta2) = 0
13+
isfsal(alg::IDSolve) = false
14+
alg_order(alg::IDSolve) = 0
15+
beta2_default(alg::IDSolve) = 0
16+
beta1_default(alg::IDSolve, beta2) = 0
1717

18-
dt_required(alg::SimpleIDSolve) = false
19-
isdiscretealg(alg::SimpleIDSolve) = true
18+
dt_required(alg::IDSolve) = false
19+
isdiscretealg(alg::IDSolve) = true

lib/ImplicitDiscreteSolve/src/cache.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,35 +4,35 @@ mutable struct ImplicitDiscreteState{uType, pType, tType}
44
t_next::tType
55
end
66

7-
mutable struct SimpleIDSolveCache{uType} <: OrdinaryDiffEqMutableCache
7+
mutable struct IDSolveCache{uType} <: OrdinaryDiffEqMutableCache
88
u::uType
99
uprev::uType
1010
state::ImplicitDiscreteState
1111
prob::Union{Nothing, SciMLBase.AbstractNonlinearProblem}
1212
end
1313

14-
function alg_cache(alg::SimpleIDSolve, u, rate_prototype, ::Type{uEltypeNoUnits},
14+
function alg_cache(alg::IDSolve, u, rate_prototype, ::Type{uEltypeNoUnits},
1515
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
1616
dt, reltol, p, calck,
1717
::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
1818

1919
state = ImplicitDiscreteState(isnothing(u) ? nothing : zero(u), p, t)
20-
SimpleIDSolveCache(u, uprev, state, nothing)
20+
IDSolveCache(u, uprev, state, nothing)
2121
end
2222

23-
isdiscretecache(cache::SimpleIDSolveCache) = true
23+
isdiscretecache(cache::IDSolveCache) = true
2424

25-
struct SimpleIDSolveConstantCache <: OrdinaryDiffEqConstantCache
25+
struct IDSolveConstantCache <: OrdinaryDiffEqConstantCache
2626
prob::Union{Nothing, SciMLBase.AbstractNonlinearProblem}
2727
end
2828

29-
function alg_cache(alg::SimpleIDSolve, u, rate_prototype, ::Type{uEltypeNoUnits},
29+
function alg_cache(alg::IDSolve, u, rate_prototype, ::Type{uEltypeNoUnits},
3030
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
3131
dt, reltol, p, calck,
3232
::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
3333

3434
state = ImplicitDiscreteState(isnothing(u) ? nothing : zero(u), p, t)
35-
SimpleIDSolveCache(u, uprev, state, nothing)
35+
IDSolveCache(u, uprev, state, nothing)
3636
end
3737

38-
get_fsalfirstlast(cache::SimpleIDSolveCache, rate_prototype) = (nothing, nothing)
38+
get_fsalfirstlast(cache::IDSolveCache, rate_prototype) = (nothing, nothing)

lib/ImplicitDiscreteSolve/src/solve.jl

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Remake the nonlinear problem, then update
2-
function perform_step!(integrator, cache::SimpleIDSolveCache, repeat_step = false)
2+
function perform_step!(integrator, cache::IDSolveCache, repeat_step = false)
33
@unpack alg, u, uprev, dt, t, f, p = integrator
44
@unpack state, prob = cache
55
state.u .= uprev
@@ -11,7 +11,7 @@ function perform_step!(integrator, cache::SimpleIDSolveCache, repeat_step = fals
1111
integrator.u = u
1212
end
1313

14-
function initialize!(integrator, cache::SimpleIDSolveCache)
14+
function initialize!(integrator, cache::IDSolveCache)
1515
integrator.u isa AbstractVector && (cache.state.u .= integrator.u)
1616
cache.state.p = integrator.p
1717
cache.state.t_next = integrator.t
@@ -22,11 +22,13 @@ function initialize!(integrator, cache::SimpleIDSolveCache)
2222
else
2323
(u_next, p) -> f(u_next, p.u, p.p, p.t_next)
2424
end
25+
u_len = isnothing(integrator.u) ? 0 : length(integrator.u)
26+
nlls = !isnothing(f.resid_prototype) && (length(f.resid_prototype) != u_len)
2527

26-
prob = if isinplace(f)
27-
NonlinearProblem{true}(_f, cache.state.u, cache.state)
28+
prob = if nlls
29+
NonlinearLeastSquaresProblem{isinplace(f)}(NonlinearFunction(_f; resid_prototype = f.resid_prototype), cache.state.u, cache.state)
2830
else
29-
NonlinearProblem{false}(_f, cache.state.u, cache.state)
31+
NonlinearProblem{isinplace(f)}(_f, cache.state.u, cache.state)
3032
end
3133
cache.prob = prob
3234
end
@@ -47,8 +49,18 @@ function _initialize_dae!(integrator, prob::ImplicitDiscreteProblem,
4749
else
4850
(u_next, p) -> f(u_next, p.u, p.p, p.t_next)
4951
end
50-
prob = NonlinearProblem{isinplace(f)}(_f, u, initstate)
52+
53+
nlls = !isnothing(f.resid_prototype) && (length(f.resid_prototype) != length(integrator.u))
54+
prob = if nlls
55+
NonlinearLeastSquaresProblem{isinplace(f)}(NonlinearFunction(_f; resid_prototype = f.resid_prototype), u, initstate)
56+
else
57+
NonlinearProblem{isinplace(f)}(_f, u, initstate)
58+
end
5159
sol = solve(prob, SimpleNewtonRaphson())
52-
integrator.u = sol
60+
if sol.retcode == ReturnCode.Success
61+
integrator.u = sol
62+
else
63+
integrator.sol = SciMLBase.solution_new_retcode(integrator.sol, ReturnCode.InitialFailure)
64+
end
5365
end
5466
end

lib/ImplicitDiscreteSolve/test/runtests.jl

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ using Test
33
using ImplicitDiscreteSolve
44
using OrdinaryDiffEqCore
55
using OrdinaryDiffEqSDIRK
6+
using SciMLBase
67

78
# Test implicit Euler using ImplicitDiscreteProblem
89
@testset "Implicit Euler" begin
@@ -20,7 +21,7 @@ using OrdinaryDiffEqSDIRK
2021
tspan = (0., 0.5)
2122

2223
idprob = ImplicitDiscreteProblem(f!, u0, tspan, []; dt = 0.01)
23-
idsol = solve(idprob, SimpleIDSolve())
24+
idsol = solve(idprob, IDSolve())
2425

2526
oprob = ODEProblem(lotkavolterra, u0, tspan)
2627
osol = solve(oprob, ImplicitEuler())
@@ -43,7 +44,7 @@ using OrdinaryDiffEqSDIRK
4344
tspan = (0, 0.2)
4445

4546
idprob = ImplicitDiscreteProblem(g!, u0, tspan, []; dt = 0.01)
46-
idsol = solve(idprob, SimpleIDSolve())
47+
idsol = solve(idprob, IDSolve())
4748

4849
oprob = ODEProblem(ff, u0, tspan)
4950
osol = solve(oprob, ImplicitEuler())
@@ -60,7 +61,7 @@ end
6061
tsteps = 15
6162
u0 = [1., 3.]
6263
idprob = ImplicitDiscreteProblem(periodic!, u0, (0, tsteps), [])
63-
integ = init(idprob, SimpleIDSolve())
64+
integ = init(idprob, IDSolve())
6465
@test integ.u[1]^2 + integ.u[2]^2 16
6566

6667
for ts in 1:tsteps
@@ -77,5 +78,44 @@ end
7778
tsteps = 5
7879
u0 = nothing
7980
idprob = ImplicitDiscreteProblem(empty, u0, (0, tsteps), [])
80-
@test_nowarn integ = init(idprob, SimpleIDSolve())
81+
@test_nowarn integ = init(idprob, IDSolve())
82+
end
83+
84+
@testset "Create NonlinearLeastSquaresProblem" begin
85+
function over(u_next, u, p, t)
86+
[u_next[1] - 1, u_next[2] - 1, u_next[1] - u_next[2]]
87+
end
88+
89+
tsteps = 5
90+
u0 = [1., 1.]
91+
idprob = ImplicitDiscreteProblem(ImplicitDiscreteFunction(over, resid_prototype = zeros(3)), u0, (0, tsteps), [])
92+
integ = init(idprob, IDSolve())
93+
@test integ.cache.prob isa NonlinearLeastSquaresProblem
94+
95+
function under(u_next, u, p, t)
96+
[u_next[1] - u_next[2] - 1]
97+
end
98+
idprob = ImplicitDiscreteProblem(ImplicitDiscreteFunction(under; resid_prototype = zeros(1)), u0, (0, tsteps), [])
99+
integ = init(idprob, IDSolve())
100+
@test integ.cache.prob isa NonlinearLeastSquaresProblem
101+
102+
function full(u_next, u, p, t)
103+
[u_next[1]^2 - 3, u_next[2] - u[1]]
104+
end
105+
idprob = ImplicitDiscreteProblem(ImplicitDiscreteFunction(full; resid_prototype = zeros(2)), u0, (0, tsteps), [])
106+
integ = init(idprob, IDSolve())
107+
@test integ.cache.prob isa NonlinearProblem
108+
end
109+
110+
@testset "InitialFailure thrown" begin
111+
function bad(u_next, u, p, t)
112+
[u_next[1] - u_next[2], u_next[1] - 3, u_next[2] - 4]
113+
end
114+
115+
u0 = [3., 4.]
116+
idprob = ImplicitDiscreteProblem(bad, u0, (0, 0), [])
117+
integ = init(idprob, IDSolve())
118+
@test check_error(integ) == ReturnCode.InitialFailure
119+
sol = solve(idprob, IDSolve())
120+
@test length(sol.u) == 1
81121
end

lib/OrdinaryDiffEqCore/src/solve.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -563,7 +563,7 @@ function DiffEqBase.__init(
563563
if initialize_integrator
564564
if isdae || SciMLBase.has_initializeprob(prob.f) || prob isa SciMLBase.ImplicitDiscreteProblem
565565
DiffEqBase.initialize_dae!(integrator)
566-
update_uprev!(integrator)
566+
!isnothing(integrator.u) && update_uprev!(integrator)
567567
end
568568

569569
if save_start

0 commit comments

Comments
 (0)