Skip to content

Commit 29917d8

Browse files
committed
Add wrapper for FixedPointAcceleration
1 parent 831dacc commit 29917d8

6 files changed

+220
-25
lines changed
Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,56 @@
11
module NonlinearSolveFixedPointAccelerationExt
22

3-
end
3+
using NonlinearSolve, FixedPointAcceleration, DiffEqBase, SciMLBase
4+
5+
function SciMLBase.__solve(prob::NonlinearProblem, alg::FixedPointAccelerationJL, args...;
6+
abstol = nothing, maxiters = 1000, alias_u0::Bool = false,
7+
show_trace::Val{PrintReports} = Val(false), termination_condition = nothing,
8+
kwargs...) where {PrintReports}
9+
@assert (termination_condition ===
10+
nothing)||(termination_condition isa AbsNormTerminationMode) "SpeedMappingJL does not support termination conditions!"
11+
12+
u0 = NonlinearSolve.__maybe_unaliased(prob.u0, alias_u0)
13+
u_size = size(u0)
14+
T = eltype(u0)
15+
iip = isinplace(prob)
16+
p = prob.p
17+
18+
if !iip && prob.u0 isa Number
19+
# FixedPointAcceleration makes the scalar problem into a vector problem
20+
f = (u) -> [prob.f(u[1], p) .+ u[1]]
21+
elseif !iip && prob.u0 isa AbstractVector
22+
f = (u) -> (prob.f(u, p) .+ u)
23+
elseif !iip && prob.u0 isa AbstractArray
24+
f = (u) -> vec(prob.f(reshape(u, u_size), p) .+ u)
25+
elseif iip && prob.u0 isa AbstractVector
26+
du = similar(u0)
27+
f = (u) -> (prob.f(du, u, p); du .+ u)
28+
else
29+
du = similar(u0)
30+
f = (u) -> (prob.f(du, reshape(u, u_size), p); vec(du) .+ u)
31+
end
32+
33+
tol = abstol === nothing ? real(oneunit(T)) * (eps(real(one(T))))^(4 // 5) : abstol
34+
35+
sol = fixed_point(f, NonlinearSolve._vec(u0); Algorithm = alg.algorithm,
36+
ConvergenceMetricThreshold = tol, MaxIter = maxiters, MaxM = alg.m,
37+
ExtrapolationPeriod = alg.extrapolation_period, Dampening = alg.dampening,
38+
PrintReports, ReplaceInvalids = alg.replace_invalids,
39+
ConditionNumberThreshold = alg.condition_number_threshold, quiet_errors = true)
40+
41+
res = prob.u0 isa Number ? first(sol.FixedPoint_) : sol.FixedPoint_
42+
if res === missing
43+
resid = NonlinearSolve.evaluate_f(prob, u0)
44+
res = u0
45+
converged = false
46+
else
47+
resid = NonlinearSolve.evaluate_f(prob, res)
48+
converged = maximum(abs, resid) tol
49+
end
50+
return SciMLBase.build_solution(prob, alg, res, resid;
51+
retcode = converged ? ReturnCode.Success : ReturnCode.Failure,
52+
stats = SciMLBase.NLStats(sol.Iterations_, 0, 0, 0, sol.Iterations_),
53+
original = sol)
54+
end
55+
56+
end

ext/NonlinearSolveMINPACKExt.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ using MINPACK
55

66
function SciMLBase.__solve(prob::Union{NonlinearProblem{uType, iip},
77
NonlinearLeastSquaresProblem{uType, iip}}, alg::CMINPACK, args...;
8-
abstol = 1e-6, maxiters = 100000, alias_u0::Bool = false,
8+
abstol = nothing, maxiters = 100000, alias_u0::Bool = false,
99
termination_condition = nothing, kwargs...) where {uType, iip}
1010
@assert (termination_condition ===
1111
nothing)||(termination_condition isa AbsNormTerminationMode) "CMINPACK does not support termination conditions!"
@@ -16,6 +16,7 @@ function SciMLBase.__solve(prob::Union{NonlinearProblem{uType, iip},
1616
u0 = NonlinearSolve.__maybe_unaliased(prob.u0, alias_u0)
1717
end
1818

19+
T = eltype(u0)
1920
sizeu = size(prob.u0)
2021
p = prob.p
2122

@@ -25,11 +26,11 @@ function SciMLBase.__solve(prob::Union{NonlinearProblem{uType, iip},
2526

2627
if !iip && prob.u0 isa Number
2728
f! = (du, u) -> (du .= prob.f(first(u), p); Cint(0))
28-
elseif !iip && prob.u0 isa Vector{Float64}
29+
elseif !iip && prob.u0 isa AbstractVector
2930
f! = (du, u) -> (du .= prob.f(u, p); Cint(0))
3031
elseif !iip && prob.u0 isa AbstractArray
3132
f! = (du, u) -> (du .= vec(prob.f(reshape(u, sizeu), p)); Cint(0))
32-
elseif prob.u0 isa Vector{Float64}
33+
elseif prob.u0 isa AbstractVector
3334
f! = (du, u) -> prob.f(du, u, p)
3435
else # Then it's an in-place function on an abstract array
3536
f! = (du, u) -> (prob.f(reshape(du, sizeu), reshape(u, sizeu), p); du = vec(du); 0)
@@ -43,14 +44,16 @@ function SciMLBase.__solve(prob::Union{NonlinearProblem{uType, iip},
4344
method = ifelse(alg.method === :auto,
4445
ifelse(prob isa NonlinearLeastSquaresProblem, :lm, :hybr), alg.method)
4546

47+
abstol = abstol === nothing ? real(oneunit(T)) * (eps(real(one(T))))^(4 // 5) : abstol
48+
4649
if SciMLBase.has_jac(prob.f)
4750
if !iip && prob.u0 isa Number
4851
g! = (du, u) -> (du .= prob.f.jac(first(u), p); Cint(0))
49-
elseif !iip && prob.u0 isa Vector{Float64}
52+
elseif !iip && prob.u0 isa AbstractVector
5053
g! = (du, u) -> (du .= prob.f.jac(u, p); Cint(0))
5154
elseif !iip && prob.u0 isa AbstractArray
5255
g! = (du, u) -> (du .= vec(prob.f.jac(reshape(u, sizeu), p)); Cint(0))
53-
elseif prob.u0 isa Vector{Float64}
56+
elseif prob.u0 isa AbstractVector
5457
g! = (du, u) -> prob.f.jac(du, u, p)
5558
else # Then it's an in-place function on an abstract array
5659
g! = function (du, u)

ext/NonlinearSolveNLsolveExt.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@ module NonlinearSolveNLsolveExt
33
using NonlinearSolve, NLsolve, DiffEqBase, SciMLBase
44
import UnPack: @unpack
55

6-
function SciMLBase.__solve(prob::NonlinearProblem, alg::NLsolveJL, args...; abstol = 1e-6,
7-
maxiters = 1000, alias_u0::Bool = false, termination_condition = nothing, kwargs...)
6+
function SciMLBase.__solve(prob::NonlinearProblem, alg::NLsolveJL, args...;
7+
abstol = nothing, maxiters = 1000, alias_u0::Bool = false,
8+
termination_condition = nothing, kwargs...)
89
@assert (termination_condition ===
910
nothing)||(termination_condition isa AbsNormTerminationMode) "NLsolveJL does not support termination conditions!"
1011

@@ -14,6 +15,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::NLsolveJL, args...; abst
1415
u0 = NonlinearSolve.__maybe_unaliased(prob.u0, alias_u0)
1516
end
1617

18+
T = eltype(u0)
1719
iip = isinplace(prob)
1820

1921
sizeu = size(prob.u0)
@@ -25,11 +27,11 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::NLsolveJL, args...; abst
2527

2628
if !iip && prob.u0 isa Number
2729
f! = (du, u) -> (du .= prob.f(first(u), p); Cint(0))
28-
elseif !iip && prob.u0 isa Vector{Float64}
30+
elseif !iip && prob.u0 isa AbstractVector
2931
f! = (du, u) -> (du .= prob.f(u, p); Cint(0))
3032
elseif !iip && prob.u0 isa AbstractArray
3133
f! = (du, u) -> (du .= vec(prob.f(reshape(u, sizeu), p)); Cint(0))
32-
elseif prob.u0 isa Vector{Float64}
34+
elseif prob.u0 isa AbstractVector
3335
f! = (du, u) -> prob.f(du, u, p)
3436
else # Then it's an in-place function on an abstract array
3537
f! = (du, u) -> (prob.f(reshape(du, sizeu), reshape(u, sizeu), p); du = vec(du); 0)
@@ -46,11 +48,11 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::NLsolveJL, args...; abst
4648
if SciMLBase.has_jac(prob.f)
4749
if !iip && prob.u0 isa Number
4850
g! = (du, u) -> (du .= prob.f.jac(first(u), p); Cint(0))
49-
elseif !iip && prob.u0 isa Vector{Float64}
51+
elseif !iip && prob.u0 isa AbstractVector
5052
g! = (du, u) -> (du .= prob.f.jac(u, p); Cint(0))
5153
elseif !iip && prob.u0 isa AbstractArray
5254
g! = (du, u) -> (du .= vec(prob.f.jac(reshape(u, sizeu), p)); Cint(0))
53-
elseif prob.u0 isa Vector{Float64}
55+
elseif prob.u0 isa AbstractVector
5456
g! = (du, u) -> prob.f.jac(du, u, p)
5557
else # Then it's an in-place function on an abstract array
5658
g! = function (du, u)
@@ -68,6 +70,8 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::NLsolveJL, args...; abst
6870
df = OnceDifferentiable(f!, vec(u0), vec(resid); autodiff)
6971
end
7072

73+
abstol = abstol === nothing ? real(oneunit(T)) * (eps(real(one(T))))^(4 // 5) : abstol
74+
7175
original = nlsolve(df, vec(u0); ftol = abstol, iterations = maxiters, method,
7276
store_trace, extended_trace, linesearch, linsolve, factor, autoscale, m, beta,
7377
show_trace)

ext/NonlinearSolveSpeedMappingExt.jl

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
module NonlinearSolveSpeedMappingExt
22

33
using NonlinearSolve, SpeedMapping, DiffEqBase, SciMLBase
4-
import UnPack: @unpack
54

65
function SciMLBase.__solve(prob::NonlinearProblem, alg::SpeedMappingJL, args...;
76
abstol = nothing, maxiters = 1000, alias_u0::Bool = false,
@@ -20,12 +19,6 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SpeedMappingJL, args...;
2019
iip = isinplace(prob)
2120
p = prob.p
2221

23-
if prob.u0 isa Number
24-
resid = [NonlinearSolve.evaluate_f(prob, first(u0))]
25-
else
26-
resid = NonlinearSolve.evaluate_f(prob, u0)
27-
end
28-
2922
if !iip && prob.u0 isa Number
3023
m! = (du, u) -> (du .= prob.f(first(u), p) .+ first(u))
3124
elseif !iip
@@ -46,4 +39,4 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SpeedMappingJL, args...;
4639
stats = SciMLBase.NLStats(sol.maps, 0, 0, 0, sol.maps), original = sol)
4740
end
4841

49-
end
42+
end

src/extension_algs.jl

Lines changed: 79 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ for solving `NonlinearLeastSquaresProblem`.
1919
2020
This algorithm is only available if `LeastSquaresOptim.jl` is installed.
2121
"""
22-
struct LeastSquaresOptimJL{alg, linsolve} <: AbstractNonlinearAlgorithm
22+
struct LeastSquaresOptimJL{alg, linsolve} <: AbstractNonlinearSolveAlgorithm
2323
autodiff::Symbol
2424
end
2525

@@ -58,7 +58,7 @@ for solving `NonlinearLeastSquaresProblem`.
5858
5959
This algorithm is only available if `FastLevenbergMarquardt.jl` is installed.
6060
"""
61-
@concrete struct FastLevenbergMarquardtJL{linsolve} <: AbstractNonlinearAlgorithm
61+
@concrete struct FastLevenbergMarquardtJL{linsolve} <: AbstractNonlinearSolveAlgorithm
6262
autodiff
6363
factor
6464
factoraccept
@@ -128,7 +128,7 @@ then the following methods are allowed:
128128
The default choice of `:auto` selects `:hybr` for NonlinearProblem and `:lm` for
129129
NonlinearLeastSquaresProblem.
130130
"""
131-
struct CMINPACK <: AbstractNonlinearAlgorithm
131+
struct CMINPACK <: AbstractNonlinearSolveAlgorithm
132132
show_trace::Bool
133133
tracing::Bool
134134
method::Symbol
@@ -181,7 +181,7 @@ Choices for methods in `NLsolveJL`:
181181
these arguments, consult the
182182
[NLsolve.jl documentation](https://github.com/JuliaNLSolvers/NLsolve.jl).
183183
"""
184-
@concrete struct NLsolveJL <: AbstractNonlinearAlgorithm
184+
@concrete struct NLsolveJL <: AbstractNonlinearSolveAlgorithm
185185
method::Symbol
186186
autodiff::Symbol
187187
store_trace::Bool
@@ -232,7 +232,7 @@ Fixed Point Problems. We allow using this algorithm to solve root finding proble
232232
- N. Lepage-Saucier, Alternating cyclic extrapolation methods for optimization algorithms,
233233
arXiv:2104.04974 (2021). https://arxiv.org/abs/2104.04974.
234234
"""
235-
@concrete struct SpeedMappingJL <: AbstractNonlinearAlgorithm
235+
@concrete struct SpeedMappingJL <: AbstractNonlinearSolveAlgorithm
236236
σ_min
237237
stabilize::Bool
238238
check_obj::Bool
@@ -248,3 +248,77 @@ function SpeedMappingJL(; σ_min = 0.0, stabilize::Bool = false, check_obj::Bool
248248

249249
return SpeedMappingJL(σ_min, stabilize, check_obj, orders, time_limit)
250250
end
251+
252+
"""
253+
FixedPointAccelerationJL(; algorithm = :Anderson, m = missing,
254+
condition_number_threshold = missing, extrapolation_period = missing,
255+
replace_invalids = :NoAction)
256+
257+
Wrapper over [FixedPointAcceleration.jl](https://s-baumann.github.io/FixedPointAcceleration.jl/)
258+
for solving Fixed Point Problems. We allow using this algorithm to solve root finding
259+
problems as well.
260+
261+
## Arguments:
262+
263+
- `algorithm`: The algorithm to use. Can be `:Anderson`, `:MPE`, `:RRE`, `:VEA`, `:SEA`,
264+
`:Simple`, `:Aitken` or `:Newton`.
265+
- `m`: The number of previous iterates to use for the extrapolation. Only valid for
266+
`:Anderson`.
267+
- `condition_number_threshold`: The condition number threshold for Least Squares Problem.
268+
Only valid for `:Anderson`.
269+
- `extrapolation_period`: The number of iterates between extrapolations. Only valid for
270+
`:MPE`, `:RRE`, `:VEA` and `:SEA`. Defaults to `7` for `:MPE` & `:RRE`, and `6` for
271+
`:SEA` and `:VEA`. For `:SEA` and `:VEA`, this must be a multiple of `2`.
272+
- `replace_invalids`: The method to use for replacing invalid iterates. Can be
273+
`:ReplaceInvalids`, `:ReplaceVector` or `:NoAction`.
274+
"""
275+
@concrete struct FixedPointAccelerationJL <: AbstractNonlinearSolveAlgorithm
276+
algorithm::Symbol
277+
extrapolation_period::Int
278+
replace_invalids::Symbol
279+
dampening
280+
m::Int
281+
condition_number_threshold
282+
end
283+
284+
function FixedPointAccelerationJL(; algorithm = :Anderson, m = missing,
285+
condition_number_threshold = missing, extrapolation_period = missing,
286+
replace_invalids = :NoAction, dampening = 1.0)
287+
if Base.get_extension(@__MODULE__, :NonlinearSolveFixedPointAccelerationExt) === nothing
288+
error("FixedPointAccelerationJL requires FixedPointAcceleration.jl to be loaded")
289+
end
290+
291+
@assert algorithm in (:Anderson, :MPE, :RRE, :VEA, :SEA, :Simple, :Aitken, :Newton)
292+
@assert replace_invalids in (:ReplaceInvalids, :ReplaceVector, :NoAction)
293+
294+
if algorithm !== :Anderson
295+
if condition_number_threshold !== missing
296+
error("`condition_number_threshold` is only valid for Anderson acceleration")
297+
end
298+
if m !== missing
299+
error("`m` is only valid for Anderson acceleration")
300+
end
301+
end
302+
condition_number_threshold === missing && (condition_number_threshold = 1e3)
303+
m === missing && (m = 10)
304+
305+
if algorithm !== :MPE && algorithm !== :RRE && algorithm !== :VEA && algorithm !== :SEA
306+
if extrapolation_period !== missing
307+
error("`extrapolation_period` is only valid for MPE, RRE, VEA and SEA")
308+
end
309+
end
310+
if extrapolation_period === missing
311+
if algorithm === :SEA || algorithm === :VEA
312+
extrapolation_period = 6
313+
else
314+
extrapolation_period = 7
315+
end
316+
else
317+
if (algorithm === :SEA || algorithm === :VEA) && extrapolation_period % 2 != 0
318+
error("`extrapolation_period` must be multiples of 2 for SEA and VEA")
319+
end
320+
end
321+
322+
return FixedPointAccelerationJL(algorithm, extrapolation_period, replace_invalids,
323+
dampening, m, condition_number_threshold)
324+
end

test/fixed_point_acceleration.jl

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
using NonlinearSolve, FixedPointAcceleration, LinearAlgebra, Test
2+
3+
# Simple Scalar Problem
4+
@testset "Simple Scalar Problem" begin
5+
f1(x, p) = cos(x) - x
6+
prob = NonlinearProblem(f1, 1.1)
7+
8+
for alg in (:Anderson, :MPE, :RRE, :VEA, :SEA, :Simple, :Aitken, :Newton)
9+
@test abs(solve(prob, FixedPointAccelerationJL()).resid) 1e-10
10+
end
11+
end
12+
13+
# Simple Vector Problem
14+
@testset "Simple Vector Problem" begin
15+
f2(x, p) = cos.(x) .- x
16+
prob = NonlinearProblem(f2, [1.1, 1.1])
17+
18+
for alg in (:Anderson, :MPE, :RRE, :VEA, :SEA, :Simple, :Aitken, :Newton)
19+
@test maximum(abs.(solve(prob, FixedPointAccelerationJL()).resid)) 1e-10
20+
end
21+
end
22+
23+
# Fixed Point for Power Method
24+
# Taken from https://github.com/NicolasL-S/SpeedMapping.jl/blob/95951db8f8a4457093090e18802ad382db1c76da/test/runtests.jl
25+
@testset "Power Method" begin
26+
C = [1 2 3; 4 5 6; 7 8 9]
27+
A = C + C'
28+
B = Hermitian(ones(10) * ones(10)' .* im + Diagonal(1:10))
29+
30+
function power_method!(du, u, A)
31+
mul!(du, A, u)
32+
du ./= norm(du, Inf)
33+
du .-= u # Convert to a root finding problem
34+
return nothing
35+
end
36+
37+
prob = NonlinearProblem(power_method!, ones(3), A)
38+
39+
for alg in (:Anderson, :MPE, :RRE, :VEA, :SEA, :Simple, :Aitken, :Newton)
40+
sol = solve(prob, FixedPointAccelerationJL(; algorithm = alg))
41+
if SciMLBase.successful_retcode(sol)
42+
@test sol.u' * A[:, 3] 32.916472867168096
43+
else
44+
@warn "Power Method failed for FixedPointAccelerationJL(; algorithm = $alg)"
45+
@test_broken sol.u' * A[:, 3] 32.916472867168096
46+
end
47+
end
48+
49+
# Non vector inputs
50+
function power_method_nonvec!(du, u, A)
51+
mul!(vec(du), A, vec(u))
52+
du ./= norm(du, Inf)
53+
du .-= u # Convert to a root finding problem
54+
return nothing
55+
end
56+
57+
prob = NonlinearProblem(power_method_nonvec!, ones(1, 3, 1), A)
58+
59+
for alg in (:Anderson, :MPE, :RRE, :VEA, :SEA, :Simple, :Aitken, :Newton)
60+
sol = solve(prob, FixedPointAccelerationJL(; algorithm = alg))
61+
if SciMLBase.successful_retcode(sol)
62+
@test sol.u' * A[:, 3] 32.916472867168096
63+
else
64+
@warn "Power Method failed for FixedPointAccelerationJL(; algorithm = $alg)"
65+
@test_broken sol.u' * A[:, 3] 32.916472867168096
66+
end
67+
end
68+
end

0 commit comments

Comments
 (0)