Skip to content

Commit fa16796

Browse files
Merge pull request #333 from ErikQQY/qqy/siam_ext
Add SIAMFANLEquations wrapper
2 parents 2edb478 + e2de813 commit fa16796

10 files changed

+351
-7
lines changed

Project.toml

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ FixedPointAcceleration = "817d07cb-a79a-5c30-9a31-890123675176"
3636
LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891"
3737
MINPACK = "4854310b-de5a-5eb6-a2a5-c1dee2bd17f9"
3838
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
39+
SIAMFANLEquations = "084e46ad-d928-497d-ad5e-07fa361a48c4"
3940
SpeedMapping = "f1835b91-879b-4a3f-a438-e4baacf14412"
4041
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
4142
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
@@ -47,6 +48,7 @@ NonlinearSolveFixedPointAccelerationExt = "FixedPointAcceleration"
4748
NonlinearSolveLeastSquaresOptimExt = "LeastSquaresOptim"
4849
NonlinearSolveMINPACKExt = "MINPACK"
4950
NonlinearSolveNLsolveExt = "NLsolve"
51+
NonlinearSolveSIAMFANLEquationsExt = "SIAMFANLEquations"
5052
NonlinearSolveSpeedMappingExt = "SpeedMapping"
5153
NonlinearSolveSymbolicsExt = "Symbolics"
5254
NonlinearSolveZygoteExt = "Zygote"
@@ -55,8 +57,8 @@ NonlinearSolveZygoteExt = "Zygote"
5557
ADTypes = "0.2.5"
5658
Aqua = "0.8"
5759
ArrayInterface = "7.7"
58-
BandedMatrices = "1.3"
59-
BenchmarkTools = "1"
60+
BandedMatrices = "1.4"
61+
BenchmarkTools = "1.4"
6062
ConcreteStructs = "0.2"
6163
DiffEqBase = "6.144"
6264
EnumX = "1"
@@ -86,6 +88,7 @@ Reexport = "1.2"
8688
SafeTestsets = "0.1"
8789
SciMLBase = "2.11"
8890
SciMLOperators = "0.3.7"
91+
SIAMFANLEquations = "1.0.1"
8992
SimpleNonlinearSolve = "1.0.2"
9093
SparseArrays = "1.9"
9194
SparseDiffTools = "2.14"
@@ -118,6 +121,7 @@ OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
118121
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
119122
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
120123
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
124+
SIAMFANLEquations = "084e46ad-d928-497d-ad5e-07fa361a48c4"
121125
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
122126
SpeedMapping = "f1835b91-879b-4a3f-a438-e4baacf14412"
123127
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
@@ -127,4 +131,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
127131
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
128132

129133
[targets]
130-
test = ["Aqua", "Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim", "FastLevenbergMarquardt", "NaNMath", "BandedMatrices", "DiffEqBase", "StableRNGs", "MINPACK", "NLsolve", "OrdinaryDiffEq", "SpeedMapping", "FixedPointAcceleration"]
134+
test = ["Aqua", "Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim", "FastLevenbergMarquardt", "NaNMath", "BandedMatrices", "DiffEqBase", "StableRNGs", "MINPACK", "NLsolve", "OrdinaryDiffEq", "SpeedMapping", "FixedPointAcceleration", "SIAMFANLEquations"]

docs/pages.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ pages = ["index.md",
3131
"api/leastsquaresoptim.md",
3232
"api/fastlevenbergmarquardt.md",
3333
"api/speedmapping.md",
34-
"api/fixedpointacceleration.md"],
34+
"api/fixedpointacceleration.md",
35+
"api/siamfanlequations.md"],
3536
"Release Notes" => "release_notes.md",
3637
]

docs/src/api/fastlevenbergmarquardt.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# FastLevenbergMarquardt.jl
22

3-
This is a extension for importing solvers from FastLevenbergMarquardt.jl into the SciML
3+
This is an extension for importing solvers from FastLevenbergMarquardt.jl into the SciML
44
interface. Note that these solvers do not come by default, and thus one needs to install
55
the package before using these solvers:
66

docs/src/api/leastsquaresoptim.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# LeastSquaresOptim.jl
22

3-
This is a extension for importing solvers from LeastSquaresOptim.jl into the SciML
3+
This is an extension for importing solvers from LeastSquaresOptim.jl into the SciML
44
interface. Note that these solvers do not come by default, and thus one needs to install
55
the package before using these solvers:
66

docs/src/api/siamfanlequations.md

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# SIAMFANLEquations.jl
2+
3+
This is an extension for importing solvers from [SIAMFANLEquations.jl](https://github.com/ctkelley/SIAMFANLEquations.jl) into the SciML
4+
interface. Note that these solvers do not come by default, and thus one needs to install
5+
the package before using these solvers:
6+
7+
```julia
8+
using Pkg
9+
Pkg.add("SIAMFANLEquations")
10+
using SIAMFANLEquations, NonlinearSolve
11+
```
12+
13+
## Solver API
14+
15+
```@docs
16+
SIAMFANLEquationsJL
17+
```
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
module NonlinearSolveSIAMFANLEquationsExt
2+
3+
using NonlinearSolve, SciMLBase
4+
using SIAMFANLEquations
5+
import ConcreteStructs: @concrete
6+
import UnPack: @unpack
7+
8+
function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, args...; abstol = nothing,
9+
reltol = nothing, alias_u0::Bool = false, maxiters = 1000, termination_condition = nothing, kwargs...)
10+
@assert (termination_condition === nothing) || (termination_condition isa AbsNormTerminationMode) "SIAMFANLEquationsJL does not support termination conditions!"
11+
12+
@unpack method, show_trace, delta, linsolve = alg
13+
14+
iip = SciMLBase.isinplace(prob)
15+
T = eltype(prob.u0)
16+
17+
atol = abstol === nothing ? real(oneunit(T)) * (eps(real(one(T))))^(4 // 5) : abstol
18+
rtol = reltol === nothing ? real(oneunit(T)) * (eps(real(one(T))))^(4 // 5) : reltol
19+
20+
if prob.u0 isa Number
21+
f! = if iip
22+
function (u)
23+
du = similar(u)
24+
prob.f(du, u, prob.p)
25+
return du
26+
end
27+
else
28+
u -> prob.f(u, prob.p)
29+
end
30+
31+
if method == :newton
32+
sol = nsolsc(f!, prob.u0; maxit = maxiters, atol = atol, rtol = rtol, printerr = show_trace)
33+
elseif method == :pseudotransient
34+
sol = ptcsolsc(f!, prob.u0; delta0 = delta, maxit = maxiters, atol = atol, rtol=rtol, printerr = show_trace)
35+
elseif method == :secant
36+
sol = secant(f!, prob.u0; maxit = maxiters, atol = atol, rtol = rtol, printerr = show_trace)
37+
end
38+
39+
if sol.errcode == 0
40+
retcode = ReturnCode.Success
41+
elseif sol.errcode == 10
42+
retcode = ReturnCode.MaxIters
43+
elseif sol.errcode == 1
44+
retcode = ReturnCode.Failure
45+
elseif sol.errcode == -1
46+
retcode = ReturnCode.Default
47+
end
48+
stats = method == :pseudotransient ? nothing : (SciMLBase.NLStats(sum(sol.stats.ifun), sum(sol.stats.ijac), 0, 0, sum(sol.stats.iarm)))
49+
return SciMLBase.build_solution(prob, alg, sol.solution, sol.history; retcode, stats, original = sol)
50+
else
51+
u = NonlinearSolve.__maybe_unaliased(prob.u0, alias_u0)
52+
end
53+
54+
if iip
55+
f! = function (du, u)
56+
prob.f(du, u, prob.p)
57+
return du
58+
end
59+
else
60+
f! = function (du, u)
61+
du .= prob.f(u, prob.p)
62+
return du
63+
end
64+
end
65+
66+
# Allocate ahead for function
67+
N = length(u)
68+
FS = zeros(T, N)
69+
70+
# Jacobian free Newton Krylov
71+
if linsolve !== nothing
72+
# Allocate ahead for Krylov basis
73+
JVS = linsolve == :gmres ? zeros(T, N, 3) : zeros(T, N)
74+
# `linsolve` as a Symbol to keep unified interface with other EXTs, SIAMFANLEquations directly use String to choose between different linear solvers
75+
linsolve_alg = String(linsolve)
76+
77+
if method == :newton
78+
sol = nsoli(f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol = atol, rtol = rtol, printerr = show_trace)
79+
elseif method == :pseudotransient
80+
sol = ptcsoli(f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol = atol, rtol = rtol, printerr = show_trace)
81+
end
82+
83+
if sol.errcode == 0
84+
retcode = ReturnCode.Success
85+
elseif sol.errcode == 10
86+
retcode = ReturnCode.MaxIters
87+
elseif sol.errcode == 1
88+
retcode = ReturnCode.Failure
89+
elseif sol.errcode == -1
90+
retcode = ReturnCode.Default
91+
end
92+
stats = method == :pseudotransient ? nothing : (SciMLBase.NLStats(sum(sol.stats.ifun), sum(sol.stats.ijac), 0, 0, sum(sol.stats.iarm)))
93+
return SciMLBase.build_solution(prob, alg, sol.solution, sol.history; retcode, stats, original = sol)
94+
end
95+
96+
# Allocate ahead for Jacobian
97+
FPS = zeros(T, N, N)
98+
if prob.f.jac === nothing
99+
# Use the built-in Jacobian machinery
100+
if method == :newton
101+
sol = nsol(f!, u, FS, FPS;
102+
sham=1, atol = atol, rtol = rtol, maxit = maxiters,
103+
printerr = show_trace)
104+
elseif method == :pseudotransient
105+
sol = ptcsol(f!, u, FS, FPS;
106+
atol = atol, rtol = rtol, maxit = maxiters,
107+
delta0 = delta, printerr = show_trace)
108+
end
109+
else
110+
AJ!(J, u, x) = prob.f.jac(J, x, prob.p)
111+
if method == :newton
112+
sol = nsol(f!, u, FS, FPS, AJ!;
113+
sham=1, atol = atol, rtol = rtol, maxit = maxiters,
114+
printerr = show_trace)
115+
elseif method == :pseudotransient
116+
sol = ptcsol(f!, u, FS, FPS, AJ!;
117+
atol = atol, rtol = rtol, maxit = maxiters,
118+
delta0 = delta, printerr = show_trace)
119+
end
120+
end
121+
122+
if sol.errcode == 0
123+
retcode = ReturnCode.Success
124+
elseif sol.errcode == 10
125+
retcode = ReturnCode.MaxIters
126+
elseif sol.errcode == 1
127+
retcode = ReturnCode.Failure
128+
elseif sol.errcode == -1
129+
retcode = ReturnCode.Default
130+
end
131+
132+
# pseudo transient continuation has a fixed cost per iteration, iteration statistics are not interesting here.
133+
stats = method == :pseudotransient ? nothing : (SciMLBase.NLStats(sum(sol.stats.ifun), sum(sol.stats.ijac), 0, 0, sum(sol.stats.iarm)))
134+
return SciMLBase.build_solution(prob, alg, sol.solution, sol.history; retcode, stats, original = sol)
135+
end
136+
137+
end

src/NonlinearSolve.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ export RadiusUpdateSchemes
237237
export NewtonRaphson, TrustRegion, LevenbergMarquardt, DFSane, GaussNewton, PseudoTransient,
238238
Broyden, Klement, LimitedMemoryBroyden
239239
export LeastSquaresOptimJL,
240-
FastLevenbergMarquardtJL, CMINPACK, NLsolveJL, FixedPointAccelerationJL, SpeedMappingJL
240+
FastLevenbergMarquardtJL, CMINPACK, NLsolveJL, FixedPointAccelerationJL, SpeedMappingJL, SIAMFANLEquationsJL
241241
export NonlinearSolvePolyAlgorithm,
242242
RobustMultiNewton, FastShortcutNonlinearPolyalg, FastShortcutNLLSPolyalg
243243

src/extension_algs.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,7 @@ function NLsolveJL(; method = :trust_region, autodiff = :central, store_trace =
208208
end
209209

210210
"""
211+
211212
SpeedMappingJL(; σ_min = 0.0, stabilize::Bool = false, check_obj::Bool = false,
212213
orders::Vector{Int} = [3, 3, 2], time_limit::Real = 1000)
213214
@@ -322,3 +323,34 @@ function FixedPointAccelerationJL(; algorithm = :Anderson, m = missing,
322323
return FixedPointAccelerationJL(algorithm, extrapolation_period, replace_invalids,
323324
dampening, m, condition_number_threshold)
324325
end
326+
327+
"""
328+
329+
SIAMFANLEquationsJL(; method = :newton, autodiff = :central, show_trace = false, delta = 1e-3, linsolve = nothing)
330+
331+
### Keyword Arguments
332+
333+
- `method`: the choice of method for solving the nonlinear system.
334+
- `show_trace`: whether to show the trace.
335+
- `delta`: initial pseudo time step, default is 1e-3.
336+
- `linsolve` : JFNK linear solvers, choices are `gmres` and `bicgstab`.
337+
338+
### Submethod Choice
339+
340+
- `:newton`: Classical Newton method.
341+
- `:pseudotransient`: Pseudo transient method.
342+
- `:secant`: Secant method for scalar equations.
343+
"""
344+
@concrete struct SIAMFANLEquationsJL <: AbstractNonlinearAlgorithm
345+
method::Symbol
346+
show_trace::Bool
347+
delta
348+
linsolve::Union{Symbol, Nothing}
349+
end
350+
351+
function SIAMFANLEquationsJL(; method = :newton, show_trace = false, delta = 1e-3, linsolve = nothing)
352+
if Base.get_extension(@__MODULE__, :NonlinearSolveSIAMFANLEquationsExt) === nothing
353+
error("SIAMFANLEquationsJL requires SIAMFANLEquations.jl to be loaded")
354+
end
355+
return SIAMFANLEquationsJL(method, show_trace, delta, linsolve)
356+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ end
2424
if GROUP == "All" || GROUP == "Wrappers"
2525
@time @safetestset "MINPACK" include("minpack.jl")
2626
@time @safetestset "NLsolve" include("nlsolve.jl")
27+
@time @safetestset "SIAMFANLEquations" include("siamfanlequations.jl")
2728
@time @safetestset "SpeedMapping" include("speedmapping.jl")
2829
@time @safetestset "FixedPointAcceleration" include("fixed_point_acceleration.jl")
2930
end

0 commit comments

Comments
 (0)