Skip to content

Commit 1ada921

Browse files
Merge pull request #337 from avik-pal/ap/fixed_point
FixedPointSolvers: Recommendations and Wrappers
2 parents d723578 + 29917d8 commit 1ada921

15 files changed

+440
-17
lines changed

Project.toml

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NonlinearSolve"
22
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
33
authors = ["SciML"]
4-
version = "3.1.2"
4+
version = "3.2.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -32,18 +32,22 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
3232
[weakdeps]
3333
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
3434
FastLevenbergMarquardt = "7a0df574-e128-4d35-8cbd-3d84502bf7ce"
35+
FixedPointAcceleration = "817d07cb-a79a-5c30-9a31-890123675176"
3536
LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891"
3637
MINPACK = "4854310b-de5a-5eb6-a2a5-c1dee2bd17f9"
3738
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
39+
SpeedMapping = "f1835b91-879b-4a3f-a438-e4baacf14412"
3840
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
3941
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
4042

4143
[extensions]
4244
NonlinearSolveBandedMatricesExt = "BandedMatrices"
4345
NonlinearSolveFastLevenbergMarquardtExt = "FastLevenbergMarquardt"
46+
NonlinearSolveFixedPointAccelerationExt = "FixedPointAcceleration"
4447
NonlinearSolveLeastSquaresOptimExt = "LeastSquaresOptim"
4548
NonlinearSolveMINPACKExt = "MINPACK"
4649
NonlinearSolveNLsolveExt = "NLsolve"
50+
NonlinearSolveSpeedMappingExt = "SpeedMapping"
4751
NonlinearSolveSymbolicsExt = "Symbolics"
4852
NonlinearSolveZygoteExt = "Zygote"
4953

@@ -60,6 +64,7 @@ Enzyme = "0.11.11"
6064
FastBroadcast = "0.2.8"
6165
FastLevenbergMarquardt = "0.1"
6266
FiniteDiff = "2.21"
67+
FixedPointAcceleration = "0.3"
6368
ForwardDiff = "0.10.36"
6469
LazyArrays = "1.8.2"
6570
LeastSquaresOptim = "0.8.5"
@@ -84,6 +89,7 @@ SciMLOperators = "0.3.7"
8489
SimpleNonlinearSolve = "1.0.2"
8590
SparseArrays = "<0.0.1, 1"
8691
SparseDiffTools = "2.14"
92+
SpeedMapping = "0.3"
8793
StableRNGs = "1"
8894
StaticArrays = "1.7"
8995
Symbolics = "5.13"
@@ -99,6 +105,7 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
99105
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
100106
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
101107
FastLevenbergMarquardt = "7a0df574-e128-4d35-8cbd-3d84502bf7ce"
108+
FixedPointAcceleration = "817d07cb-a79a-5c30-9a31-890123675176"
102109
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
103110
LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891"
104111
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -112,11 +119,12 @@ Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
112119
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
113120
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
114121
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
122+
SpeedMapping = "f1835b91-879b-4a3f-a438-e4baacf14412"
115123
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
116124
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
117125
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
118126
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
119127
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
120128

121129
[targets]
122-
test = ["Aqua", "Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim", "FastLevenbergMarquardt", "NaNMath", "BandedMatrices", "DiffEqBase", "StableRNGs", "MINPACK", "NLsolve", "OrdinaryDiffEq"]
130+
test = ["Aqua", "Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim", "FastLevenbergMarquardt", "NaNMath", "BandedMatrices", "DiffEqBase", "StableRNGs", "MINPACK", "NLsolve", "OrdinaryDiffEq", "SpeedMapping", "FixedPointAcceleration"]

docs/pages.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ pages = ["index.md",
2020
"solvers/BracketingSolvers.md",
2121
"solvers/SteadyStateSolvers.md",
2222
"solvers/NonlinearLeastSquaresSolvers.md",
23+
"solvers/FixedPointSolvers.md",
2324
"solvers/LineSearch.md"],
2425
"Detailed Solver APIs" => Any["api/nonlinearsolve.md",
2526
"api/simplenonlinearsolve.md",
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# FixedPointAcceleration.jl
2+
3+
This is a extension for importing solvers from FixedPointAcceleration.jl into the SciML
4+
interface. Note that these solvers do not come by default, and thus one needs to install
5+
the package before using these solvers:
6+
7+
```julia
8+
using Pkg
9+
Pkg.add("FixedPointAcceleration")
10+
using FixedPointAcceleration, NonlinearSolve
11+
```
12+
13+
## Solver API
14+
15+
```@docs
16+
FixedPointAccelerationJL
17+
```

docs/src/api/speedmapping.md

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# SpeedMapping.jl
2+
3+
This is a extension for importing solvers from SpeedMapping.jl into the SciML
4+
interface. Note that these solvers do not come by default, and thus one needs to install
5+
the package before using these solvers:
6+
7+
```julia
8+
using Pkg
9+
Pkg.add("SpeedMapping")
10+
using SpeedMapping, NonlinearSolve
11+
```
12+
13+
## Solver API
14+
15+
```@docs
16+
SpeedMappingJL
17+
```

docs/src/solvers/FixedPointSolvers.md

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Fixed Point Solvers
2+
3+
Currently we don't have an API to directly specify Fixed Point Solvers. However, a Fixed
4+
Point Problem can be trivially converted to a Root Finding Problem. Say we want to solve:
5+
6+
```math
7+
f(u) = u
8+
```
9+
10+
This can be written as:
11+
12+
```math
13+
g(u) = f(u) - u = 0
14+
```
15+
16+
``g(u) = 0`` is a root finding problem. Note that we can use any root finding
17+
algorithm to solve this problem. However, this is often not the most efficient way to
18+
solve a fixed point problem. We provide a few algorithms available via extensions that
19+
are more efficient for fixed point problems.
20+
21+
Note that even if you use one of the Fixed Point Solvers mentioned here, you must still
22+
use the `NonlinearProblem` API to specify the problem, i.e., ``g(u) = 0``.
23+
24+
## Recommended Methods
25+
26+
Using [native NonlinearSolve.jl methods](@ref nonlinearsystemsolvers) is the recommended
27+
approach. For systems where constructing Jacobian Matrices are expensive, we recommend
28+
using a Krylov Method with one of those solvers.
29+
30+
## Full List of Methods
31+
32+
We are only listing the methods that natively solve fixed point problems.
33+
34+
### SpeedMapping.jl
35+
36+
- `SpeedMappingJL()`: accelerates the convergence of a mapping to a fixed point by the
37+
Alternating cyclic extrapolation algorithm (ACX).
38+
39+
### FixedPointAcceleration.jl
40+
41+
- `FixedPointAccelerationJL()`: accelerates the convergence of a mapping to a fixed point
42+
by the Anderson acceleration algorithm and a few other methods.

docs/src/solvers/NonlinearSystemSolvers.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# [Nonlinear System Solvers](@id nonlinearsystemsolvers)
22

3-
`solve(prob::NonlinearProblem,alg;kwargs)`
3+
`solve(prob::NonlinearProblem, alg; kwargs)`
44

55
Solves for ``f(u)=0`` in the problem defined by `prob` using the algorithm
66
`alg`. If no algorithm is given, a default algorithm will be chosen.
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
module NonlinearSolveFixedPointAccelerationExt
2+
3+
using NonlinearSolve, FixedPointAcceleration, DiffEqBase, SciMLBase
4+
5+
function SciMLBase.__solve(prob::NonlinearProblem, alg::FixedPointAccelerationJL, args...;
6+
abstol = nothing, maxiters = 1000, alias_u0::Bool = false,
7+
show_trace::Val{PrintReports} = Val(false), termination_condition = nothing,
8+
kwargs...) where {PrintReports}
9+
@assert (termination_condition ===
10+
nothing)||(termination_condition isa AbsNormTerminationMode) "SpeedMappingJL does not support termination conditions!"
11+
12+
u0 = NonlinearSolve.__maybe_unaliased(prob.u0, alias_u0)
13+
u_size = size(u0)
14+
T = eltype(u0)
15+
iip = isinplace(prob)
16+
p = prob.p
17+
18+
if !iip && prob.u0 isa Number
19+
# FixedPointAcceleration makes the scalar problem into a vector problem
20+
f = (u) -> [prob.f(u[1], p) .+ u[1]]
21+
elseif !iip && prob.u0 isa AbstractVector
22+
f = (u) -> (prob.f(u, p) .+ u)
23+
elseif !iip && prob.u0 isa AbstractArray
24+
f = (u) -> vec(prob.f(reshape(u, u_size), p) .+ u)
25+
elseif iip && prob.u0 isa AbstractVector
26+
du = similar(u0)
27+
f = (u) -> (prob.f(du, u, p); du .+ u)
28+
else
29+
du = similar(u0)
30+
f = (u) -> (prob.f(du, reshape(u, u_size), p); vec(du) .+ u)
31+
end
32+
33+
tol = abstol === nothing ? real(oneunit(T)) * (eps(real(one(T))))^(4 // 5) : abstol
34+
35+
sol = fixed_point(f, NonlinearSolve._vec(u0); Algorithm = alg.algorithm,
36+
ConvergenceMetricThreshold = tol, MaxIter = maxiters, MaxM = alg.m,
37+
ExtrapolationPeriod = alg.extrapolation_period, Dampening = alg.dampening,
38+
PrintReports, ReplaceInvalids = alg.replace_invalids,
39+
ConditionNumberThreshold = alg.condition_number_threshold, quiet_errors = true)
40+
41+
res = prob.u0 isa Number ? first(sol.FixedPoint_) : sol.FixedPoint_
42+
if res === missing
43+
resid = NonlinearSolve.evaluate_f(prob, u0)
44+
res = u0
45+
converged = false
46+
else
47+
resid = NonlinearSolve.evaluate_f(prob, res)
48+
converged = maximum(abs, resid) tol
49+
end
50+
return SciMLBase.build_solution(prob, alg, res, resid;
51+
retcode = converged ? ReturnCode.Success : ReturnCode.Failure,
52+
stats = SciMLBase.NLStats(sol.Iterations_, 0, 0, 0, sol.Iterations_),
53+
original = sol)
54+
end
55+
56+
end

ext/NonlinearSolveMINPACKExt.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ using MINPACK
55

66
function SciMLBase.__solve(prob::Union{NonlinearProblem{uType, iip},
77
NonlinearLeastSquaresProblem{uType, iip}}, alg::CMINPACK, args...;
8-
abstol = 1e-6, maxiters = 100000, alias_u0::Bool = false,
8+
abstol = nothing, maxiters = 100000, alias_u0::Bool = false,
99
termination_condition = nothing, kwargs...) where {uType, iip}
1010
@assert (termination_condition ===
1111
nothing)||(termination_condition isa AbsNormTerminationMode) "CMINPACK does not support termination conditions!"
@@ -16,6 +16,7 @@ function SciMLBase.__solve(prob::Union{NonlinearProblem{uType, iip},
1616
u0 = NonlinearSolve.__maybe_unaliased(prob.u0, alias_u0)
1717
end
1818

19+
T = eltype(u0)
1920
sizeu = size(prob.u0)
2021
p = prob.p
2122

@@ -25,11 +26,11 @@ function SciMLBase.__solve(prob::Union{NonlinearProblem{uType, iip},
2526

2627
if !iip && prob.u0 isa Number
2728
f! = (du, u) -> (du .= prob.f(first(u), p); Cint(0))
28-
elseif !iip && prob.u0 isa Vector{Float64}
29+
elseif !iip && prob.u0 isa AbstractVector
2930
f! = (du, u) -> (du .= prob.f(u, p); Cint(0))
3031
elseif !iip && prob.u0 isa AbstractArray
3132
f! = (du, u) -> (du .= vec(prob.f(reshape(u, sizeu), p)); Cint(0))
32-
elseif prob.u0 isa Vector{Float64}
33+
elseif prob.u0 isa AbstractVector
3334
f! = (du, u) -> prob.f(du, u, p)
3435
else # Then it's an in-place function on an abstract array
3536
f! = (du, u) -> (prob.f(reshape(du, sizeu), reshape(u, sizeu), p); du = vec(du); 0)
@@ -43,14 +44,16 @@ function SciMLBase.__solve(prob::Union{NonlinearProblem{uType, iip},
4344
method = ifelse(alg.method === :auto,
4445
ifelse(prob isa NonlinearLeastSquaresProblem, :lm, :hybr), alg.method)
4546

47+
abstol = abstol === nothing ? real(oneunit(T)) * (eps(real(one(T))))^(4 // 5) : abstol
48+
4649
if SciMLBase.has_jac(prob.f)
4750
if !iip && prob.u0 isa Number
4851
g! = (du, u) -> (du .= prob.f.jac(first(u), p); Cint(0))
49-
elseif !iip && prob.u0 isa Vector{Float64}
52+
elseif !iip && prob.u0 isa AbstractVector
5053
g! = (du, u) -> (du .= prob.f.jac(u, p); Cint(0))
5154
elseif !iip && prob.u0 isa AbstractArray
5255
g! = (du, u) -> (du .= vec(prob.f.jac(reshape(u, sizeu), p)); Cint(0))
53-
elseif prob.u0 isa Vector{Float64}
56+
elseif prob.u0 isa AbstractVector
5457
g! = (du, u) -> prob.f.jac(du, u, p)
5558
else # Then it's an in-place function on an abstract array
5659
g! = function (du, u)

ext/NonlinearSolveNLsolveExt.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@ module NonlinearSolveNLsolveExt
33
using NonlinearSolve, NLsolve, DiffEqBase, SciMLBase
44
import UnPack: @unpack
55

6-
function SciMLBase.__solve(prob::NonlinearProblem, alg::NLsolveJL, args...; abstol = 1e-6,
7-
maxiters = 1000, alias_u0::Bool = false, termination_condition = nothing, kwargs...)
6+
function SciMLBase.__solve(prob::NonlinearProblem, alg::NLsolveJL, args...;
7+
abstol = nothing, maxiters = 1000, alias_u0::Bool = false,
8+
termination_condition = nothing, kwargs...)
89
@assert (termination_condition ===
910
nothing)||(termination_condition isa AbsNormTerminationMode) "NLsolveJL does not support termination conditions!"
1011

@@ -14,6 +15,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::NLsolveJL, args...; abst
1415
u0 = NonlinearSolve.__maybe_unaliased(prob.u0, alias_u0)
1516
end
1617

18+
T = eltype(u0)
1719
iip = isinplace(prob)
1820

1921
sizeu = size(prob.u0)
@@ -25,11 +27,11 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::NLsolveJL, args...; abst
2527

2628
if !iip && prob.u0 isa Number
2729
f! = (du, u) -> (du .= prob.f(first(u), p); Cint(0))
28-
elseif !iip && prob.u0 isa Vector{Float64}
30+
elseif !iip && prob.u0 isa AbstractVector
2931
f! = (du, u) -> (du .= prob.f(u, p); Cint(0))
3032
elseif !iip && prob.u0 isa AbstractArray
3133
f! = (du, u) -> (du .= vec(prob.f(reshape(u, sizeu), p)); Cint(0))
32-
elseif prob.u0 isa Vector{Float64}
34+
elseif prob.u0 isa AbstractVector
3335
f! = (du, u) -> prob.f(du, u, p)
3436
else # Then it's an in-place function on an abstract array
3537
f! = (du, u) -> (prob.f(reshape(du, sizeu), reshape(u, sizeu), p); du = vec(du); 0)
@@ -46,11 +48,11 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::NLsolveJL, args...; abst
4648
if SciMLBase.has_jac(prob.f)
4749
if !iip && prob.u0 isa Number
4850
g! = (du, u) -> (du .= prob.f.jac(first(u), p); Cint(0))
49-
elseif !iip && prob.u0 isa Vector{Float64}
51+
elseif !iip && prob.u0 isa AbstractVector
5052
g! = (du, u) -> (du .= prob.f.jac(u, p); Cint(0))
5153
elseif !iip && prob.u0 isa AbstractArray
5254
g! = (du, u) -> (du .= vec(prob.f.jac(reshape(u, sizeu), p)); Cint(0))
53-
elseif prob.u0 isa Vector{Float64}
55+
elseif prob.u0 isa AbstractVector
5456
g! = (du, u) -> prob.f.jac(du, u, p)
5557
else # Then it's an in-place function on an abstract array
5658
g! = function (du, u)
@@ -68,6 +70,8 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::NLsolveJL, args...; abst
6870
df = OnceDifferentiable(f!, vec(u0), vec(resid); autodiff)
6971
end
7072

73+
abstol = abstol === nothing ? real(oneunit(T)) * (eps(real(one(T))))^(4 // 5) : abstol
74+
7175
original = nlsolve(df, vec(u0); ftol = abstol, iterations = maxiters, method,
7276
store_trace, extended_trace, linesearch, linsolve, factor, autoscale, m, beta,
7377
show_trace)

ext/NonlinearSolveSpeedMappingExt.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
module NonlinearSolveSpeedMappingExt
2+
3+
using NonlinearSolve, SpeedMapping, DiffEqBase, SciMLBase
4+
5+
function SciMLBase.__solve(prob::NonlinearProblem, alg::SpeedMappingJL, args...;
6+
abstol = nothing, maxiters = 1000, alias_u0::Bool = false,
7+
store_trace::Val{store_info} = Val(false), termination_condition = nothing,
8+
kwargs...) where {store_info}
9+
@assert (termination_condition ===
10+
nothing)||(termination_condition isa AbsNormTerminationMode) "SpeedMappingJL does not support termination conditions!"
11+
12+
if typeof(prob.u0) <: Number
13+
u0 = [prob.u0]
14+
else
15+
u0 = NonlinearSolve.__maybe_unaliased(prob.u0, alias_u0)
16+
end
17+
18+
T = eltype(u0)
19+
iip = isinplace(prob)
20+
p = prob.p
21+
22+
if !iip && prob.u0 isa Number
23+
m! = (du, u) -> (du .= prob.f(first(u), p) .+ first(u))
24+
elseif !iip
25+
m! = (du, u) -> (du .= prob.f(u, p) .+ u)
26+
else
27+
m! = (du, u) -> (prob.f(du, u, p); du .+= u)
28+
end
29+
30+
tol = abstol === nothing ? real(oneunit(T)) * (eps(real(one(T))))^(4 // 5) : abstol
31+
32+
sol = speedmapping(u0; m!, tol, Lp = Inf, maps_limit = maxiters, alg.orders,
33+
alg.check_obj, store_info, alg.σ_min, alg.stabilize)
34+
res = prob.u0 isa Number ? first(sol.minimizer) : sol.minimizer
35+
resid = NonlinearSolve.evaluate_f(prob, sol.minimizer)
36+
37+
return SciMLBase.build_solution(prob, alg, res, resid;
38+
retcode = sol.converged ? ReturnCode.Success : ReturnCode.Failure,
39+
stats = SciMLBase.NLStats(sol.maps, 0, 0, 0, sol.maps), original = sol)
40+
end
41+
42+
end

0 commit comments

Comments
 (0)