1
1
# Poly Algorithms
2
2
"""
3
- NonlinearSolvePolyAlgorithm(algs, ::Val{pType} = Val(:NLS)) where {pType}
3
+ NonlinearSolvePolyAlgorithm(algs, ::Val{pType} = Val(:NLS);
4
+ start_index = 1) where {pType}
4
5
5
6
A general way to define PolyAlgorithms for `NonlinearProblem` and
6
7
`NonlinearLeastSquaresProblem`. This is a container for a tuple of algorithms that will be
@@ -15,6 +16,10 @@ residual is returned.
15
16
`NonlinearLeastSquaresProblem`. This is used to determine the correct problem type to
16
17
dispatch on.
17
18
19
+ ### Keyword Arguments
20
+
21
+ - `start_index`: the index to start at. Defaults to `1`.
22
+
18
23
### Example
19
24
20
25
```julia
@@ -25,11 +30,14 @@ alg = NonlinearSolvePolyAlgorithm((NewtonRaphson(), Broyden()))
25
30
"""
26
31
struct NonlinearSolvePolyAlgorithm{pType, N, A} <: AbstractNonlinearSolveAlgorithm{:PolyAlg}
27
32
algs:: A
33
+ start_index:: Int
28
34
29
- function NonlinearSolvePolyAlgorithm (algs, :: Val{pType} = Val (:NLS )) where {pType}
35
+ function NonlinearSolvePolyAlgorithm (
36
+ algs, :: Val{pType} = Val (:NLS ); start_index:: Int = 1 ) where {pType}
30
37
@assert pType ∈ (:NLS , :NLLS )
38
+ @assert 0 < start_index ≤ length (algs)
31
39
algs = Tuple (algs)
32
- return new {pType, length(algs), typeof(algs)} (algs)
40
+ return new {pType, length(algs), typeof(algs)} (algs, start_index )
33
41
end
34
42
end
35
43
73
81
74
82
function reinit_cache! (cache:: NonlinearSolvePolyAlgorithmCache , args... ; kwargs... )
75
83
foreach (c -> reinit_cache! (c, args... ; kwargs... ), cache. caches)
76
- cache. current = 1
84
+ cache. current = cache . alg . start_index
77
85
cache. nsteps = 0
78
86
cache. total_time = 0.0
79
87
end
@@ -91,7 +99,7 @@ for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProb
91
99
alg. algs),
92
100
alg,
93
101
- 1 ,
94
- 1 ,
102
+ alg . start_index ,
95
103
0 ,
96
104
0.0 ,
97
105
maxtime,
134
142
135
143
resids = map (x -> Symbol (" $(x) _resid" ), cache_syms)
136
144
for (sym, resid) in zip (cache_syms, resids)
137
- push! (calls, :($ (resid) = get_fu ($ (sym))))
145
+ push! (calls, :($ (resid) = @isdefined ( $ (sym)) ? get_fu ($ (sym)) : nothing ))
138
146
end
139
147
push! (calls,
140
148
quote
@@ -194,25 +202,29 @@ for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProb
194
202
@eval begin
195
203
@generated function SciMLBase. __solve (
196
204
prob:: $probType , alg:: $algType{N} , args... ; kwargs... ) where {N}
197
- calls = []
205
+ calls = [:(current = alg . start_index) ]
198
206
sol_syms = [gensym (" sol" ) for _ in 1 : N]
199
207
for i in 1 : N
200
208
cur_sol = sol_syms[i]
201
209
push! (calls,
202
210
quote
203
- $ (cur_sol) = SciMLBase. __solve (prob, alg. algs[$ (i)], args... ; kwargs... )
204
- if SciMLBase. successful_retcode ($ (cur_sol))
205
- return SciMLBase. build_solution (
206
- prob, alg, $ (cur_sol). u, $ (cur_sol). resid;
207
- $ (cur_sol). retcode, $ (cur_sol). stats,
208
- original = $ (cur_sol), trace = $ (cur_sol). trace)
211
+ if current == $ i
212
+ $ (cur_sol) = SciMLBase. __solve (
213
+ prob, alg. algs[$ (i)], args... ; kwargs... )
214
+ if SciMLBase. successful_retcode ($ (cur_sol))
215
+ return SciMLBase. build_solution (
216
+ prob, alg, $ (cur_sol). u, $ (cur_sol). resid;
217
+ $ (cur_sol). retcode, $ (cur_sol). stats,
218
+ original = $ (cur_sol), trace = $ (cur_sol). trace)
219
+ end
220
+ current = $ (i + 1 )
209
221
end
210
222
end )
211
223
end
212
224
213
225
resids = map (x -> Symbol (" $(x) _resid" ), sol_syms)
214
226
for (sym, resid) in zip (sol_syms, resids)
215
- push! (calls, :($ (resid) = $ (sym). resid))
227
+ push! (calls, :($ (resid) = @isdefined ( $ (sym)) ? $ (sym) . resid : nothing ))
216
228
end
217
229
218
230
push! (calls, quote
@@ -263,6 +275,7 @@ function RobustMultiNewton(::Type{T} = Float64; concrete_jac = nothing, linsolve
263
275
algs = (TrustRegion (; concrete_jac, linsolve, precs, autodiff),
264
276
TrustRegion (; concrete_jac, linsolve, precs, autodiff,
265
277
radius_update_scheme = RadiusUpdateSchemes. Bastin),
278
+ NewtonRaphson (; concrete_jac, linsolve, precs, autodiff),
266
279
NewtonRaphson (; concrete_jac, linsolve, precs,
267
280
linesearch = LineSearchesJL (; method = BackTracking ()), autodiff),
268
281
TrustRegion (; concrete_jac, linsolve, precs,
276
289
"""
277
290
FastShortcutNonlinearPolyalg(::Type{T} = Float64; concrete_jac = nothing,
278
291
linsolve = nothing, precs = DEFAULT_PRECS, must_use_jacobian::Val = Val(false),
279
- prefer_simplenonlinearsolve::Val{SA} = Val(false), autodiff = nothing) where {T}
292
+ prefer_simplenonlinearsolve::Val{SA} = Val(false), autodiff = nothing,
293
+ u0_len::Union{Int, Nothing} = nothing) where {T}
280
294
281
295
A polyalgorithm focused on balancing speed and robustness. It first tries less robust methods
282
296
for more performance and then tries more robust techniques if the faster ones fail.
@@ -285,12 +299,19 @@ for more performance and then tries more robust techniques if the faster ones fa
285
299
286
300
- `T`: The eltype of the initial guess. It is only used to check if some of the algorithms
287
301
are compatible with the problem type. Defaults to `Float64`.
302
+
303
+ ### Keyword Arguments
304
+
305
+ - `u0_len`: The length of the initial guess. If this is `nothing`, then the length of the
306
+ initial guess is not checked. If this is an integer and it is less than `25`, we use
307
+ jacobian based methods.
288
308
"""
289
309
function FastShortcutNonlinearPolyalg (
290
310
:: Type{T} = Float64; concrete_jac = nothing , linsolve = nothing ,
291
311
precs = DEFAULT_PRECS, must_use_jacobian:: Val{JAC} = Val (false ),
292
312
prefer_simplenonlinearsolve:: Val{SA} = Val (false ),
293
- autodiff = nothing ) where {T, JAC, SA}
313
+ u0_len:: Union{Int, Nothing} = nothing , autodiff = nothing ) where {T, JAC, SA}
314
+ start_index = 1
294
315
if JAC
295
316
if __is_complex (T)
296
317
algs = (NewtonRaphson (; concrete_jac, linsolve, precs, autodiff),)
@@ -312,6 +333,7 @@ function FastShortcutNonlinearPolyalg(
312
333
SimpleKlement (),
313
334
NewtonRaphson (; concrete_jac, linsolve, precs, autodiff))
314
335
else
336
+ start_index = u0_len != = nothing ? (u0_len ≤ 25 ? 4 : 1 ) : 1
315
337
algs = (SimpleBroyden (),
316
338
Broyden (; init_jacobian = Val (:true_jacobian ), autodiff),
317
339
SimpleKlement (),
@@ -327,6 +349,8 @@ function FastShortcutNonlinearPolyalg(
327
349
Klement (; linsolve, precs, autodiff),
328
350
NewtonRaphson (; concrete_jac, linsolve, precs, autodiff))
329
351
else
352
+ # TODO : This number requires a bit rigorous testing
353
+ start_index = u0_len != = nothing ? (u0_len ≤ 25 ? 4 : 1 ) : 1
330
354
algs = (Broyden (; autodiff),
331
355
Broyden (; init_jacobian = Val (:true_jacobian ), autodiff),
332
356
Klement (; linsolve, precs, autodiff),
@@ -339,7 +363,7 @@ function FastShortcutNonlinearPolyalg(
339
363
end
340
364
end
341
365
end
342
- return NonlinearSolvePolyAlgorithm (algs, Val (:NLS ))
366
+ return NonlinearSolvePolyAlgorithm (algs, Val (:NLS ); start_index )
343
367
end
344
368
345
369
"""
@@ -392,17 +416,19 @@ end
392
416
# # can use that!
393
417
function SciMLBase. __init (prob:: NonlinearProblem , :: Nothing , args... ; kwargs... )
394
418
must_use_jacobian = Val (prob. f. jac != = nothing )
395
- return SciMLBase. __init (
396
- prob, FastShortcutNonlinearPolyalg (eltype (prob. u0); must_use_jacobian),
397
- args... ; kwargs... )
419
+ return SciMLBase. __init (prob,
420
+ FastShortcutNonlinearPolyalg (
421
+ eltype (prob. u0); must_use_jacobian, u0_len = length (prob. u0)),
422
+ args... ;
423
+ kwargs... )
398
424
end
399
425
400
426
function SciMLBase. __solve (prob:: NonlinearProblem , :: Nothing , args... ; kwargs... )
401
427
must_use_jacobian = Val (prob. f. jac != = nothing )
402
428
prefer_simplenonlinearsolve = Val (prob. u0 isa SArray)
403
429
return SciMLBase. __solve (prob,
404
- FastShortcutNonlinearPolyalg (
405
- eltype (prob. u0); must_use_jacobian, prefer_simplenonlinearsolve ),
430
+ FastShortcutNonlinearPolyalg (eltype (prob . u0); must_use_jacobian,
431
+ prefer_simplenonlinearsolve, u0_len = length (prob. u0)),
406
432
args... ;
407
433
kwargs... )
408
434
end
0 commit comments