Skip to content

Commit 6f68b2a

Browse files
committed
Use start index
1 parent 5911835 commit 6f68b2a

File tree

6 files changed

+62
-34
lines changed

6 files changed

+62
-34
lines changed

src/algorithms/klement.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ over this.
2525
differentiable problems.
2626
"""
2727
function Klement(; max_resets::Int = 100, linsolve = nothing, alpha = nothing,
28-
linesearch = NoLineSearch(), precs = DEFAULT_PRECS, autodiff = nothing,
29-
init_jacobian::Val{IJ} = Val(:identity)) where {IJ}
28+
linesearch = NoLineSearch(), precs = DEFAULT_PRECS,
29+
autodiff = nothing, init_jacobian::Val{IJ} = Val(:identity)) where {IJ}
3030
if !(linesearch isa AbstractNonlinearSolveLineSearchAlgorithm)
3131
Base.depwarn(
3232
"Passing in a `LineSearches.jl` algorithm directly is deprecated. \

src/core/approximate_jacobian.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@ function ApproximateJacobianSolveAlgorithm{concrete_jac, name}(;
6666
linesearch = LineSearchesJL(; method = linesearch)
6767
end
6868
return ApproximateJacobianSolveAlgorithm{concrete_jac, name}(
69-
linesearch, trustregion, descent, update_rule, reinit_rule,
70-
max_resets, max_shrink_times, initialization)
69+
linesearch, trustregion, descent, update_rule,
70+
reinit_rule, max_resets, max_shrink_times, initialization)
7171
end
7272

7373
@inline concrete_jac(::ApproximateJacobianSolveAlgorithm{CJ}) where {CJ} = CJ

src/default.jl

Lines changed: 48 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Poly Algorithms
22
"""
3-
NonlinearSolvePolyAlgorithm(algs, ::Val{pType} = Val(:NLS)) where {pType}
3+
NonlinearSolvePolyAlgorithm(algs, ::Val{pType} = Val(:NLS);
4+
start_index = 1) where {pType}
45
56
A general way to define PolyAlgorithms for `NonlinearProblem` and
67
`NonlinearLeastSquaresProblem`. This is a container for a tuple of algorithms that will be
@@ -15,6 +16,10 @@ residual is returned.
1516
`NonlinearLeastSquaresProblem`. This is used to determine the correct problem type to
1617
dispatch on.
1718
19+
### Keyword Arguments
20+
21+
- `start_index`: the index to start at. Defaults to `1`.
22+
1823
### Example
1924
2025
```julia
@@ -25,11 +30,14 @@ alg = NonlinearSolvePolyAlgorithm((NewtonRaphson(), Broyden()))
2530
"""
2631
struct NonlinearSolvePolyAlgorithm{pType, N, A} <: AbstractNonlinearSolveAlgorithm{:PolyAlg}
2732
algs::A
33+
start_index::Int
2834

29-
function NonlinearSolvePolyAlgorithm(algs, ::Val{pType} = Val(:NLS)) where {pType}
35+
function NonlinearSolvePolyAlgorithm(
36+
algs, ::Val{pType} = Val(:NLS); start_index::Int = 1) where {pType}
3037
@assert pType (:NLS, :NLLS)
38+
@assert 0 < start_index length(algs)
3139
algs = Tuple(algs)
32-
return new{pType, length(algs), typeof(algs)}(algs)
40+
return new{pType, length(algs), typeof(algs)}(algs, start_index)
3341
end
3442
end
3543

@@ -73,7 +81,7 @@ end
7381

7482
function reinit_cache!(cache::NonlinearSolvePolyAlgorithmCache, args...; kwargs...)
7583
foreach(c -> reinit_cache!(c, args...; kwargs...), cache.caches)
76-
cache.current = 1
84+
cache.current = cache.alg.start_index
7785
cache.nsteps = 0
7886
cache.total_time = 0.0
7987
end
@@ -91,7 +99,7 @@ for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProb
9199
alg.algs),
92100
alg,
93101
-1,
94-
1,
102+
alg.start_index,
95103
0,
96104
0.0,
97105
maxtime,
@@ -140,6 +148,7 @@ end
140148
quote
141149
fus = tuple($(Tuple(resids)...))
142150
minfu, idx = __findmin(cache.internalnorm, fus)
151+
idx += cache.alg.start_index - 1
143152
stats = __compile_stats(cache.caches[idx])
144153
u = get_u(cache.caches[idx])
145154
retcode = cache.caches[idx].retcode
@@ -194,18 +203,22 @@ for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProb
194203
@eval begin
195204
@generated function SciMLBase.__solve(
196205
prob::$probType, alg::$algType{N}, args...; kwargs...) where {N}
197-
calls = []
206+
calls = [:(current = alg.start_index)]
198207
sol_syms = [gensym("sol") for _ in 1:N]
199208
for i in 1:N
200209
cur_sol = sol_syms[i]
201210
push!(calls,
202211
quote
203-
$(cur_sol) = SciMLBase.__solve(prob, alg.algs[$(i)], args...; kwargs...)
204-
if SciMLBase.successful_retcode($(cur_sol))
205-
return SciMLBase.build_solution(
206-
prob, alg, $(cur_sol).u, $(cur_sol).resid;
207-
$(cur_sol).retcode, $(cur_sol).stats,
208-
original = $(cur_sol), trace = $(cur_sol).trace)
212+
if current == $i
213+
$(cur_sol) = SciMLBase.__solve(
214+
prob, alg.algs[$(i)], args...; kwargs...)
215+
if SciMLBase.successful_retcode($(cur_sol))
216+
return SciMLBase.build_solution(
217+
prob, alg, $(cur_sol).u, $(cur_sol).resid;
218+
$(cur_sol).retcode, $(cur_sol).stats,
219+
original = $(cur_sol), trace = $(cur_sol).trace)
220+
end
221+
current = $(i + 1)
209222
end
210223
end)
211224
end
@@ -218,6 +231,7 @@ for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProb
218231
push!(calls, quote
219232
resids = tuple($(Tuple(resids)...))
220233
minfu, idx = __findmin(DEFAULT_NORM, resids)
234+
idx += alg.start_index - 1
221235
end)
222236

223237
for i in 1:N
@@ -263,6 +277,7 @@ function RobustMultiNewton(::Type{T} = Float64; concrete_jac = nothing, linsolve
263277
algs = (TrustRegion(; concrete_jac, linsolve, precs, autodiff),
264278
TrustRegion(; concrete_jac, linsolve, precs, autodiff,
265279
radius_update_scheme = RadiusUpdateSchemes.Bastin),
280+
NewtonRaphson(; concrete_jac, linsolve, precs, autodiff),
266281
NewtonRaphson(; concrete_jac, linsolve, precs,
267282
linesearch = LineSearchesJL(; method = BackTracking()), autodiff),
268283
TrustRegion(; concrete_jac, linsolve, precs,
@@ -276,7 +291,8 @@ end
276291
"""
277292
FastShortcutNonlinearPolyalg(::Type{T} = Float64; concrete_jac = nothing,
278293
linsolve = nothing, precs = DEFAULT_PRECS, must_use_jacobian::Val = Val(false),
279-
prefer_simplenonlinearsolve::Val{SA} = Val(false), autodiff = nothing) where {T}
294+
prefer_simplenonlinearsolve::Val{SA} = Val(false), autodiff = nothing,
295+
u0_len::Union{Int, Nothing} = nothing) where {T}
280296
281297
A polyalgorithm focused on balancing speed and robustness. It first tries less robust methods
282298
for more performance and then tries more robust techniques if the faster ones fail.
@@ -285,12 +301,19 @@ for more performance and then tries more robust techniques if the faster ones fa
285301
286302
- `T`: The eltype of the initial guess. It is only used to check if some of the algorithms
287303
are compatible with the problem type. Defaults to `Float64`.
304+
305+
### Keyword Arguments
306+
307+
- `u0_len`: The length of the initial guess. If this is `nothing`, then the length of the
308+
initial guess is not checked. If this is an integer and it is less than `25`, we use
309+
jacobian based methods.
288310
"""
289311
function FastShortcutNonlinearPolyalg(
290312
::Type{T} = Float64; concrete_jac = nothing, linsolve = nothing,
291313
precs = DEFAULT_PRECS, must_use_jacobian::Val{JAC} = Val(false),
292314
prefer_simplenonlinearsolve::Val{SA} = Val(false),
293-
autodiff = nothing) where {T, JAC, SA}
315+
u0_len::Union{Int, Nothing} = nothing, autodiff = nothing) where {T, JAC, SA}
316+
start_index = 1
294317
if JAC
295318
if __is_complex(T)
296319
algs = (NewtonRaphson(; concrete_jac, linsolve, precs, autodiff),)
@@ -312,6 +335,7 @@ function FastShortcutNonlinearPolyalg(
312335
SimpleKlement(),
313336
NewtonRaphson(; concrete_jac, linsolve, precs, autodiff))
314337
else
338+
start_index = u0_len !== nothing ? (u0_len 25 ? 4 : 1) : 1
315339
algs = (SimpleBroyden(),
316340
Broyden(; init_jacobian = Val(:true_jacobian), autodiff),
317341
SimpleKlement(),
@@ -327,6 +351,8 @@ function FastShortcutNonlinearPolyalg(
327351
Klement(; linsolve, precs, autodiff),
328352
NewtonRaphson(; concrete_jac, linsolve, precs, autodiff))
329353
else
354+
# TODO: This number requires a bit rigorous testing
355+
start_index = u0_len !== nothing ? (u0_len 25 ? 4 : 1) : 1
330356
algs = (Broyden(; autodiff),
331357
Broyden(; init_jacobian = Val(:true_jacobian), autodiff),
332358
Klement(; linsolve, precs, autodiff),
@@ -339,7 +365,7 @@ function FastShortcutNonlinearPolyalg(
339365
end
340366
end
341367
end
342-
return NonlinearSolvePolyAlgorithm(algs, Val(:NLS))
368+
return NonlinearSolvePolyAlgorithm(algs, Val(:NLS); start_index)
343369
end
344370

345371
"""
@@ -392,17 +418,19 @@ end
392418
## can use that!
393419
function SciMLBase.__init(prob::NonlinearProblem, ::Nothing, args...; kwargs...)
394420
must_use_jacobian = Val(prob.f.jac !== nothing)
395-
return SciMLBase.__init(
396-
prob, FastShortcutNonlinearPolyalg(eltype(prob.u0); must_use_jacobian),
397-
args...; kwargs...)
421+
return SciMLBase.__init(prob,
422+
FastShortcutNonlinearPolyalg(
423+
eltype(prob.u0); must_use_jacobian, u0_len = length(prob.u0)),
424+
args...;
425+
kwargs...)
398426
end
399427

400428
function SciMLBase.__solve(prob::NonlinearProblem, ::Nothing, args...; kwargs...)
401429
must_use_jacobian = Val(prob.f.jac !== nothing)
402430
prefer_simplenonlinearsolve = Val(prob.u0 isa SArray)
403431
return SciMLBase.__solve(prob,
404-
FastShortcutNonlinearPolyalg(
405-
eltype(prob.u0); must_use_jacobian, prefer_simplenonlinearsolve),
432+
FastShortcutNonlinearPolyalg(eltype(prob.u0); must_use_jacobian,
433+
prefer_simplenonlinearsolve, u0_len = length(prob.u0)),
406434
args...;
407435
kwargs...)
408436
end

test/core/forward_ad_tests.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,8 @@ end
6464

6565
@testitem "ForwardDiff.jl Integration" setup=[ForwardADTesting] begin
6666
for alg in (NewtonRaphson(), TrustRegion(), LevenbergMarquardt(),
67-
PseudoTransient(; alpha_initial = 10.0), Broyden(), Klement(), DFSane(), nothing,
68-
NLsolveJL(), CMINPACK(), KINSOL(; globalization_strategy = :LineSearch))
67+
PseudoTransient(; alpha_initial = 10.0), Broyden(), Klement(), DFSane(),
68+
nothing, NLsolveJL(), CMINPACK(), KINSOL(; globalization_strategy = :LineSearch))
6969
us = (2.0, @SVector[1.0, 1.0], [1.0, 1.0], ones(2, 2), @SArray ones(2, 2))
7070

7171
@testset "Scalar AD" begin

test/core/rootfind_tests.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -476,8 +476,8 @@ end
476476

477477
@testitem "Broyden" setup=[CoreRootfindTesting] begin
478478
@testset "LineSearch: $(_nameof(lsmethod)) LineSearch AD: $(_nameof(ad)) Init Jacobian: $(init_jacobian) Update Rule: $(update_rule)" for lsmethod in (
479-
Static(), StrongWolfe(), BackTracking(), HagerZhang(),
480-
MoreThuente(), LiFukushimaLineSearch()),
479+
Static(), StrongWolfe(), BackTracking(),
480+
HagerZhang(), MoreThuente(), LiFukushimaLineSearch()),
481481
ad in (AutoFiniteDiff(), AutoZygote()),
482482
init_jacobian in (Val(:identity), Val(:true_jacobian)),
483483
update_rule in (Val(:good_broyden), Val(:bad_broyden), Val(:diagonal))
@@ -575,8 +575,8 @@ end
575575

576576
@testitem "LimitedMemoryBroyden" setup=[CoreRootfindTesting] begin
577577
@testset "LineSearch: $(_nameof(lsmethod)) LineSearch AD: $(_nameof(ad))" for lsmethod in (
578-
Static(), StrongWolfe(), BackTracking(), HagerZhang(),
579-
MoreThuente(), LiFukushimaLineSearch()),
578+
Static(), StrongWolfe(), BackTracking(),
579+
HagerZhang(), MoreThuente(), LiFukushimaLineSearch()),
580580
ad in (AutoFiniteDiff(), AutoZygote())
581581

582582
linesearch = LineSearchesJL(; method = lsmethod, autodiff = ad)

test/misc/matrix_resizing_tests.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
vecprob = NonlinearProblem(ff, vec(u0), p)
88
prob = NonlinearProblem(ff, u0, p)
99

10-
for alg in (NewtonRaphson(), TrustRegion(), LevenbergMarquardt(), PseudoTransient(),
11-
RobustMultiNewton(), FastShortcutNonlinearPolyalg(),
10+
for alg in (NewtonRaphson(), TrustRegion(), LevenbergMarquardt(),
11+
PseudoTransient(), RobustMultiNewton(), FastShortcutNonlinearPolyalg(),
1212
Broyden(), Klement(), LimitedMemoryBroyden(; threshold = 2))
1313
@test vec(solve(prob, alg).u) == solve(vecprob, alg).u
1414
end
@@ -23,8 +23,8 @@ end
2323
vecprob = NonlinearProblem(fiip, vec(u0), p)
2424
prob = NonlinearProblem(fiip, u0, p)
2525

26-
for alg in (NewtonRaphson(), TrustRegion(), LevenbergMarquardt(), PseudoTransient(),
27-
RobustMultiNewton(), FastShortcutNonlinearPolyalg(),
26+
for alg in (NewtonRaphson(), TrustRegion(), LevenbergMarquardt(),
27+
PseudoTransient(), RobustMultiNewton(), FastShortcutNonlinearPolyalg(),
2828
Broyden(), Klement(), LimitedMemoryBroyden(; threshold = 2))
2929
@test vec(solve(prob, alg).u) == solve(vecprob, alg).u
3030
end

0 commit comments

Comments
 (0)