Skip to content

Commit e1c0528

Browse files
authored
Merge pull request #390 from SciML/ap/fixes
Misc. Improvements
2 parents 52b3832 + 81edf48 commit e1c0528

15 files changed

+121
-47
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NonlinearSolve"
22
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
33
authors = ["SciML"]
4-
version = "3.7.3"
4+
version = "3.8.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/NonlinearSolve.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ import PrecompileTools: @recompile_invalidations, @compile_workload, @setup_work
99

1010
@recompile_invalidations begin
1111
using ADTypes, ConcreteStructs, DiffEqBase, FastBroadcast, FastClosures, LazyArrays,
12-
LineSearches, LinearAlgebra, LinearSolve, MaybeInplace, Preferences, Printf,
13-
SciMLBase, SimpleNonlinearSolve, SparseArrays, SparseDiffTools
12+
LinearAlgebra, LinearSolve, MaybeInplace, Preferences, Printf, SciMLBase,
13+
SimpleNonlinearSolve, SparseArrays, SparseDiffTools
1414

1515
import ArrayInterface: undefmatrix, can_setindex, restructure, fast_scalar_indexing
1616
import DiffEqBase: AbstractNonlinearTerminationMode,
@@ -20,6 +20,7 @@ import PrecompileTools: @recompile_invalidations, @compile_workload, @setup_work
2020
import FiniteDiff
2121
import ForwardDiff
2222
import ForwardDiff: Dual
23+
import LineSearches
2324
import LinearSolve: ComposePreconditioner, InvPreconditioner, needs_concrete_A
2425
import RecursiveArrayTools: recursivecopy!, recursivefill!
2526

@@ -29,7 +30,7 @@ import PrecompileTools: @recompile_invalidations, @compile_workload, @setup_work
2930
import StaticArraysCore: StaticArray, SVector, SArray, MArray, Size, SMatrix, MMatrix
3031
end
3132

32-
@reexport using ADTypes, LineSearches, SciMLBase, SimpleNonlinearSolve
33+
@reexport using ADTypes, SciMLBase, SimpleNonlinearSolve
3334

3435
const AbstractSparseADType = Union{ADTypes.AbstractSparseFiniteDifferences,
3536
ADTypes.AbstractSparseForwardMode, ADTypes.AbstractSparseReverseMode}
@@ -157,6 +158,7 @@ export NewtonDescent, SteepestDescent, Dogleg, DampedNewtonDescent, GeodesicAcce
157158
# Globalization
158159
## Line Search Algorithms
159160
export LineSearchesJL, NoLineSearch, RobustNonMonotoneLineSearch, LiFukushimaLineSearch
161+
export Static, HagerZhang, MoreThuente, StrongWolfe, BackTracking
160162
## Trust Region Algorithms
161163
export RadiusUpdateSchemes
162164

src/algorithms/klement.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ over this.
2525
differentiable problems.
2626
"""
2727
function Klement(; max_resets::Int = 100, linsolve = nothing, alpha = nothing,
28-
linesearch = NoLineSearch(), precs = DEFAULT_PRECS, autodiff = nothing,
29-
init_jacobian::Val{IJ} = Val(:identity)) where {IJ}
28+
linesearch = NoLineSearch(), precs = DEFAULT_PRECS,
29+
autodiff = nothing, init_jacobian::Val{IJ} = Val(:identity)) where {IJ}
3030
if !(linesearch isa AbstractNonlinearSolveLineSearchAlgorithm)
3131
Base.depwarn(
3232
"Passing in a `LineSearches.jl` algorithm directly is deprecated. \

src/algorithms/lbroyden.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,8 +159,8 @@ function LinearAlgebra.mul!(y::AbstractVector, x::AbstractVector, J::BroydenLowR
159159
return y
160160
end
161161

162-
function LinearAlgebra.mul!(
163-
J::BroydenLowRankJacobian, u, vᵀ::LinearAlgebra.AdjOrTransAbsVec, α::Bool, β::Bool)
162+
function LinearAlgebra.mul!(J::BroydenLowRankJacobian, u::AbstractArray,
163+
vᵀ::LinearAlgebra.AdjOrTransAbsVec, α::Bool, β::Bool)
164164
@assert α & β
165165
idx_update = mod1(J.idx + 1, size(J.U, 2))
166166
copyto!(@view(J.U[:, idx_update]), _vec(u))

src/core/approximate_jacobian.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@ function ApproximateJacobianSolveAlgorithm{concrete_jac, name}(;
6666
linesearch = LineSearchesJL(; method = linesearch)
6767
end
6868
return ApproximateJacobianSolveAlgorithm{concrete_jac, name}(
69-
linesearch, trustregion, descent, update_rule, reinit_rule,
70-
max_resets, max_shrink_times, initialization)
69+
linesearch, trustregion, descent, update_rule,
70+
reinit_rule, max_resets, max_shrink_times, initialization)
7171
end
7272

7373
@inline concrete_jac(::ApproximateJacobianSolveAlgorithm{CJ}) where {CJ} = CJ

src/default.jl

Lines changed: 48 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Poly Algorithms
22
"""
3-
NonlinearSolvePolyAlgorithm(algs, ::Val{pType} = Val(:NLS)) where {pType}
3+
NonlinearSolvePolyAlgorithm(algs, ::Val{pType} = Val(:NLS);
4+
start_index = 1) where {pType}
45
56
A general way to define PolyAlgorithms for `NonlinearProblem` and
67
`NonlinearLeastSquaresProblem`. This is a container for a tuple of algorithms that will be
@@ -15,6 +16,10 @@ residual is returned.
1516
`NonlinearLeastSquaresProblem`. This is used to determine the correct problem type to
1617
dispatch on.
1718
19+
### Keyword Arguments
20+
21+
- `start_index`: the index to start at. Defaults to `1`.
22+
1823
### Example
1924
2025
```julia
@@ -25,11 +30,14 @@ alg = NonlinearSolvePolyAlgorithm((NewtonRaphson(), Broyden()))
2530
"""
2631
struct NonlinearSolvePolyAlgorithm{pType, N, A} <: AbstractNonlinearSolveAlgorithm{:PolyAlg}
2732
algs::A
33+
start_index::Int
2834

29-
function NonlinearSolvePolyAlgorithm(algs, ::Val{pType} = Val(:NLS)) where {pType}
35+
function NonlinearSolvePolyAlgorithm(
36+
algs, ::Val{pType} = Val(:NLS); start_index::Int = 1) where {pType}
3037
@assert pType (:NLS, :NLLS)
38+
@assert 0 < start_index length(algs)
3139
algs = Tuple(algs)
32-
return new{pType, length(algs), typeof(algs)}(algs)
40+
return new{pType, length(algs), typeof(algs)}(algs, start_index)
3341
end
3442
end
3543

@@ -73,7 +81,7 @@ end
7381

7482
function reinit_cache!(cache::NonlinearSolvePolyAlgorithmCache, args...; kwargs...)
7583
foreach(c -> reinit_cache!(c, args...; kwargs...), cache.caches)
76-
cache.current = 1
84+
cache.current = cache.alg.start_index
7785
cache.nsteps = 0
7886
cache.total_time = 0.0
7987
end
@@ -91,7 +99,7 @@ for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProb
9199
alg.algs),
92100
alg,
93101
-1,
94-
1,
102+
alg.start_index,
95103
0,
96104
0.0,
97105
maxtime,
@@ -134,7 +142,7 @@ end
134142

135143
resids = map(x -> Symbol("$(x)_resid"), cache_syms)
136144
for (sym, resid) in zip(cache_syms, resids)
137-
push!(calls, :($(resid) = get_fu($(sym))))
145+
push!(calls, :($(resid) = @isdefined($(sym)) ? get_fu($(sym)) : nothing))
138146
end
139147
push!(calls,
140148
quote
@@ -194,25 +202,29 @@ for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProb
194202
@eval begin
195203
@generated function SciMLBase.__solve(
196204
prob::$probType, alg::$algType{N}, args...; kwargs...) where {N}
197-
calls = []
205+
calls = [:(current = alg.start_index)]
198206
sol_syms = [gensym("sol") for _ in 1:N]
199207
for i in 1:N
200208
cur_sol = sol_syms[i]
201209
push!(calls,
202210
quote
203-
$(cur_sol) = SciMLBase.__solve(prob, alg.algs[$(i)], args...; kwargs...)
204-
if SciMLBase.successful_retcode($(cur_sol))
205-
return SciMLBase.build_solution(
206-
prob, alg, $(cur_sol).u, $(cur_sol).resid;
207-
$(cur_sol).retcode, $(cur_sol).stats,
208-
original = $(cur_sol), trace = $(cur_sol).trace)
211+
if current == $i
212+
$(cur_sol) = SciMLBase.__solve(
213+
prob, alg.algs[$(i)], args...; kwargs...)
214+
if SciMLBase.successful_retcode($(cur_sol))
215+
return SciMLBase.build_solution(
216+
prob, alg, $(cur_sol).u, $(cur_sol).resid;
217+
$(cur_sol).retcode, $(cur_sol).stats,
218+
original = $(cur_sol), trace = $(cur_sol).trace)
219+
end
220+
current = $(i + 1)
209221
end
210222
end)
211223
end
212224

213225
resids = map(x -> Symbol("$(x)_resid"), sol_syms)
214226
for (sym, resid) in zip(sol_syms, resids)
215-
push!(calls, :($(resid) = $(sym).resid))
227+
push!(calls, :($(resid) = @isdefined($(sym)) ? $(sym).resid : nothing))
216228
end
217229

218230
push!(calls, quote
@@ -263,6 +275,7 @@ function RobustMultiNewton(::Type{T} = Float64; concrete_jac = nothing, linsolve
263275
algs = (TrustRegion(; concrete_jac, linsolve, precs, autodiff),
264276
TrustRegion(; concrete_jac, linsolve, precs, autodiff,
265277
radius_update_scheme = RadiusUpdateSchemes.Bastin),
278+
NewtonRaphson(; concrete_jac, linsolve, precs, autodiff),
266279
NewtonRaphson(; concrete_jac, linsolve, precs,
267280
linesearch = LineSearchesJL(; method = BackTracking()), autodiff),
268281
TrustRegion(; concrete_jac, linsolve, precs,
@@ -276,7 +289,8 @@ end
276289
"""
277290
FastShortcutNonlinearPolyalg(::Type{T} = Float64; concrete_jac = nothing,
278291
linsolve = nothing, precs = DEFAULT_PRECS, must_use_jacobian::Val = Val(false),
279-
prefer_simplenonlinearsolve::Val{SA} = Val(false), autodiff = nothing) where {T}
292+
prefer_simplenonlinearsolve::Val{SA} = Val(false), autodiff = nothing,
293+
u0_len::Union{Int, Nothing} = nothing) where {T}
280294
281295
A polyalgorithm focused on balancing speed and robustness. It first tries less robust methods
282296
for more performance and then tries more robust techniques if the faster ones fail.
@@ -285,12 +299,19 @@ for more performance and then tries more robust techniques if the faster ones fa
285299
286300
- `T`: The eltype of the initial guess. It is only used to check if some of the algorithms
287301
are compatible with the problem type. Defaults to `Float64`.
302+
303+
### Keyword Arguments
304+
305+
- `u0_len`: The length of the initial guess. If this is `nothing`, then the length of the
306+
initial guess is not checked. If this is an integer and it is less than `25`, we use
307+
jacobian based methods.
288308
"""
289309
function FastShortcutNonlinearPolyalg(
290310
::Type{T} = Float64; concrete_jac = nothing, linsolve = nothing,
291311
precs = DEFAULT_PRECS, must_use_jacobian::Val{JAC} = Val(false),
292312
prefer_simplenonlinearsolve::Val{SA} = Val(false),
293-
autodiff = nothing) where {T, JAC, SA}
313+
u0_len::Union{Int, Nothing} = nothing, autodiff = nothing) where {T, JAC, SA}
314+
start_index = 1
294315
if JAC
295316
if __is_complex(T)
296317
algs = (NewtonRaphson(; concrete_jac, linsolve, precs, autodiff),)
@@ -312,6 +333,7 @@ function FastShortcutNonlinearPolyalg(
312333
SimpleKlement(),
313334
NewtonRaphson(; concrete_jac, linsolve, precs, autodiff))
314335
else
336+
start_index = u0_len !== nothing ? (u0_len 25 ? 4 : 1) : 1
315337
algs = (SimpleBroyden(),
316338
Broyden(; init_jacobian = Val(:true_jacobian), autodiff),
317339
SimpleKlement(),
@@ -327,6 +349,8 @@ function FastShortcutNonlinearPolyalg(
327349
Klement(; linsolve, precs, autodiff),
328350
NewtonRaphson(; concrete_jac, linsolve, precs, autodiff))
329351
else
352+
# TODO: This number requires a bit rigorous testing
353+
start_index = u0_len !== nothing ? (u0_len 25 ? 4 : 1) : 1
330354
algs = (Broyden(; autodiff),
331355
Broyden(; init_jacobian = Val(:true_jacobian), autodiff),
332356
Klement(; linsolve, precs, autodiff),
@@ -339,7 +363,7 @@ function FastShortcutNonlinearPolyalg(
339363
end
340364
end
341365
end
342-
return NonlinearSolvePolyAlgorithm(algs, Val(:NLS))
366+
return NonlinearSolvePolyAlgorithm(algs, Val(:NLS); start_index)
343367
end
344368

345369
"""
@@ -392,17 +416,19 @@ end
392416
## can use that!
393417
function SciMLBase.__init(prob::NonlinearProblem, ::Nothing, args...; kwargs...)
394418
must_use_jacobian = Val(prob.f.jac !== nothing)
395-
return SciMLBase.__init(
396-
prob, FastShortcutNonlinearPolyalg(eltype(prob.u0); must_use_jacobian),
397-
args...; kwargs...)
419+
return SciMLBase.__init(prob,
420+
FastShortcutNonlinearPolyalg(
421+
eltype(prob.u0); must_use_jacobian, u0_len = length(prob.u0)),
422+
args...;
423+
kwargs...)
398424
end
399425

400426
function SciMLBase.__solve(prob::NonlinearProblem, ::Nothing, args...; kwargs...)
401427
must_use_jacobian = Val(prob.f.jac !== nothing)
402428
prefer_simplenonlinearsolve = Val(prob.u0 isa SArray)
403429
return SciMLBase.__solve(prob,
404-
FastShortcutNonlinearPolyalg(
405-
eltype(prob.u0); must_use_jacobian, prefer_simplenonlinearsolve),
430+
FastShortcutNonlinearPolyalg(eltype(prob.u0); must_use_jacobian,
431+
prefer_simplenonlinearsolve, u0_len = length(prob.u0)),
406432
args...;
407433
kwargs...)
408434
end

src/globalization/line_search.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,10 @@ end
5353

5454
LineSearchesJL(method; kwargs...) = LineSearchesJL(; method, kwargs...)
5555
function LineSearchesJL(; method = LineSearches.Static(), autodiff = nothing, α = true)
56+
if method isa LineSearchesJL # Prevent breaking old code
57+
return LineSearchesJL(method.method, α, autodiff)
58+
end
59+
5660
if method isa AbstractNonlinearSolveLineSearchAlgorithm
5761
Base.depwarn("Passing a native NonlinearSolve line search algorithm to \
5862
`LineSearchesJL` or `LineSearch` is deprecated. Pass the method \
@@ -65,6 +69,18 @@ end
6569

6670
Base.@deprecate_binding LineSearch LineSearchesJL true
6771

72+
Static(args...; kwargs...) = LineSearchesJL(LineSearches.Static(args...; kwargs...))
73+
HagerZhang(args...; kwargs...) = LineSearchesJL(LineSearches.HagerZhang(args...; kwargs...))
74+
function MoreThuente(args...; kwargs...)
75+
return LineSearchesJL(LineSearches.MoreThuente(args...; kwargs...))
76+
end
77+
function BackTracking(args...; kwargs...)
78+
return LineSearchesJL(LineSearches.BackTracking(args...; kwargs...))
79+
end
80+
function StrongWolfe(args...; kwargs...)
81+
return LineSearchesJL(LineSearches.StrongWolfe(args...; kwargs...))
82+
end
83+
6884
# Wrapper over LineSearches.jl algorithms
6985
@concrete mutable struct LineSearchesJLCache <: AbstractNonlinearSolveLineSearchCache
7086
f

src/internal/jacobian.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,9 @@ function JacobianCache(
8383
JacobianOperator(prob, fu, u; jvp_autodiff, vjp_autodiff)
8484
else
8585
if has_analytic_jac
86-
f.jac_prototype === nothing ? undefmatrix(u) : f.jac_prototype
86+
f.jac_prototype === nothing ?
87+
similar(fu, promote_type(eltype(fu), eltype(u)), length(fu), length(u)) :
88+
copy(f.jac_prototype)
8789
elseif f.jac_prototype === nothing
8890
init_jacobian(jac_cache; preserve_immutable = Val(true))
8991
else

src/utils.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ function __findmin_caches(f, caches)
9999
end
100100
function __findmin(f, x)
101101
return findmin(x) do xᵢ
102+
xᵢ === nothing && return Inf
102103
fx = f(xᵢ)
103104
return isnan(fx) ? Inf : fx
104105
end

test/core/23_test_problems_tests.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,16 @@ end
4040
export test_on_library, problems, dicts
4141
end
4242

43+
@testitem "PolyAlgorithms" setup=[RobustnessTesting] begin
44+
alg_ops = (RobustMultiNewton(), FastShortcutNonlinearPolyalg())
45+
46+
broken_tests = Dict(alg => Int[] for alg in alg_ops)
47+
broken_tests[alg_ops[1]] = []
48+
broken_tests[alg_ops[2]] = []
49+
50+
test_on_library(problems, dicts, alg_ops, broken_tests)
51+
end
52+
4353
@testitem "NewtonRaphson" setup=[RobustnessTesting] begin
4454
alg_ops = (NewtonRaphson(),)
4555

@@ -91,7 +101,7 @@ end
91101
test_on_library(problems, dicts, alg_ops, broken_tests)
92102
end
93103

94-
@testitem "Broyden" retries=5 setup=[RobustnessTesting] begin
104+
@testitem "Broyden" setup=[RobustnessTesting] begin
95105
alg_ops = (Broyden(), Broyden(; init_jacobian = Val(:true_jacobian)),
96106
Broyden(; update_rule = Val(:bad_broyden)),
97107
Broyden(; init_jacobian = Val(:true_jacobian), update_rule = Val(:bad_broyden)))

0 commit comments

Comments
 (0)