Skip to content

Commit fc437fd

Browse files
committed
Add SIAMFANLEquations wrapper
Signed-off-by: ErikQQY <2283984853@qq.com>
1 parent aa282b7 commit fc437fd

9 files changed

+374
-4
lines changed

Project.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ FastLevenbergMarquardt = "7a0df574-e128-4d35-8cbd-3d84502bf7ce"
3535
LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891"
3636
MINPACK = "4854310b-de5a-5eb6-a2a5-c1dee2bd17f9"
3737
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
38+
SIAMFANLEquations = "084e46ad-d928-497d-ad5e-07fa361a48c4"
3839
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
3940
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
4041

@@ -44,6 +45,7 @@ NonlinearSolveFastLevenbergMarquardtExt = "FastLevenbergMarquardt"
4445
NonlinearSolveLeastSquaresOptimExt = "LeastSquaresOptim"
4546
NonlinearSolveMINPACKExt = "MINPACK"
4647
NonlinearSolveNLsolveExt = "NLsolve"
48+
NonlinearSolveSIAMFANLEquationsExt = "SIAMFANLEquations"
4749
NonlinearSolveSymbolicsExt = "Symbolics"
4850
NonlinearSolveZygoteExt = "Zygote"
4951

@@ -80,6 +82,7 @@ Reexport = "1.2"
8082
SafeTestsets = "0.1"
8183
SciMLBase = "2.11"
8284
SciMLOperators = "0.3.7"
85+
SIAMFANLEquations = "1.0.1"
8386
SimpleNonlinearSolve = "1.0.2"
8487
SparseArrays = "<0.0.1, 1"
8588
SparseDiffTools = "2.14"
@@ -109,6 +112,7 @@ NonlinearProblemLibrary = "b7050fa9-e91f-4b37-bcee-a89a063da141"
109112
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
110113
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
111114
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
115+
SIAMFANLEquations = "084e46ad-d928-497d-ad5e-07fa361a48c4"
112116
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
113117
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
114118
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
@@ -117,4 +121,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
117121
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
118122

119123
[targets]
120-
test = ["Aqua", "Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim", "FastLevenbergMarquardt", "NaNMath", "BandedMatrices", "DiffEqBase", "StableRNGs", "MINPACK", "NLsolve"]
124+
test = ["Aqua", "Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim", "FastLevenbergMarquardt", "NaNMath", "BandedMatrices", "DiffEqBase", "StableRNGs", "MINPACK", "NLsolve", "SIAMFANLEquations"]

docs/src/api/fastlevenbergmarquardt.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# FastLevenbergMarquardt.jl
22

3-
This is a extension for importing solvers from FastLevenbergMarquardt.jl into the SciML
3+
This is an extension for importing solvers from FastLevenbergMarquardt.jl into the SciML
44
interface. Note that these solvers do not come by default, and thus one needs to install
55
the package before using these solvers:
66

docs/src/api/leastsquaresoptim.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# LeastSquaresOptim.jl
22

3-
This is a extension for importing solvers from LeastSquaresOptim.jl into the SciML
3+
This is an extension for importing solvers from LeastSquaresOptim.jl into the SciML
44
interface. Note that these solvers do not come by default, and thus one needs to install
55
the package before using these solvers:
66

docs/src/api/siamfanlequations.md

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# SIAMFANLEquations.jl
2+
3+
This is an extension for importing solvers from [SIAMFANLEquations.jl](https://github.com/ctkelley/SIAMFANLEquations.jl) into the SciML
4+
interface. Note that these solvers do not come by default, and thus one needs to install
5+
the package before using these solvers:
6+
7+
```julia
8+
using Pkg
9+
Pkg.add("SIAMFANLEquations")
10+
using SIAMFANLEquations, NonlinearSolve
11+
```
12+
13+
## Solver API
14+
15+
```@docs
16+
SIAMFANLEquationsJL
17+
```
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
module NonlinearSolveSIAMFANLEquationsExt
2+
3+
using NonlinearSolve, SciMLBase
4+
using SIAMFANLEquations
5+
import ConcreteStructs: @concrete
6+
import UnPack: @unpack
7+
import FiniteDiff, ForwardDiff
8+
9+
function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, args...; abstol = 1e-8,
10+
reltol = 1e-8, alias_u0::Bool = false, maxiters = 1000, kwargs...)
11+
@unpack method, autodiff, show_trace, delta, linsolve = alg
12+
13+
iip = SciMLBase.isinplace(prob)
14+
if typeof(prob.u0) <: Number
15+
f! = if iip
16+
function (u)
17+
du = similar(u)
18+
prob.f(du, u, prob.p)
19+
return du
20+
end
21+
else
22+
u -> prob.f(u, prob.p)
23+
end
24+
25+
if method == :newton
26+
res = nsolsc(f!, prob.u0; maxit = maxiters, atol = abstol, rtol = reltol, printerr = show_trace)
27+
elseif method == :pseudotransient
28+
res = ptcsolsc(f!, prob.u0; delta0 = delta, maxit = maxiters, atol = abstol, rtol=reltol, printerr = show_trace)
29+
elseif method == :secant
30+
res = secant(f!, prob.u0; maxit = maxiters, atol = abstol, rtol = reltol, printerr = show_trace)
31+
end
32+
33+
if res.errcode == 0
34+
retcode = ReturnCode.Success
35+
elseif res.errcode == 10
36+
retcode = ReturnCode.MaxIters
37+
elseif res.errcode == 1
38+
retcode = ReturnCode.Failure
39+
@error("Line search failed")
40+
elseif res.errcode == -1
41+
retcode = ReturnCode.Default
42+
@info("Initial iterate satisfies the termination criteria")
43+
end
44+
stats = method == :pseudotransient ? nothing : (SciMLBase.NLStats(res.stats.ifun[1], res.stats.ijac[1], -1, -1, res.stats.iarm[1]))
45+
return SciMLBase.build_solution(prob, alg, res.solution, res.history; retcode, stats)
46+
else
47+
u = NonlinearSolve.__maybe_unaliased(prob.u0, alias_u0)
48+
end
49+
50+
fu = NonlinearSolve.evaluate_f(prob, u)
51+
52+
if iip
53+
f! = function (du, u)
54+
prob.f(du, u, prob.p)
55+
return du
56+
end
57+
else
58+
f! = function (du, u)
59+
du .= prob.f(u, prob.p)
60+
return du
61+
end
62+
end
63+
64+
# Allocate ahead for function and Jacobian
65+
N = length(u)
66+
FS = zeros(eltype(u), N)
67+
FPS = zeros(eltype(u), N, N)
68+
# Allocate ahead for Krylov basis
69+
70+
# Jacobian free Newton Krylov
71+
if linsolve !== nothing
72+
JVS = linsolve == :gmres ? zeros(eltype(u), N, 3) : zeros(eltype(u), N)
73+
# `linsolve` as a Symbol to keep unified interface with other EXTs, SIAMFANLEquations directly use String to choose between linear solvers
74+
linsolve_alg = strip(repr(linsolve), ':')
75+
76+
if method == :newton
77+
res = nsoli(f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol = abstol, rtol = reltol, printerr = show_trace)
78+
elseif method == :pseudotransient
79+
res = ptcsoli(f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol = abstol, rtol = reltol, printerr = show_trace)
80+
end
81+
82+
if res.errcode == 0
83+
retcode = ReturnCode.Success
84+
elseif res.errcode == 10
85+
retcode = ReturnCode.MaxIters
86+
elseif res.errcode == 1
87+
retcode = ReturnCode.Failure
88+
@error("Line search failed")
89+
elseif res.errcode == -1
90+
retcode = ReturnCode.Default
91+
@info("Initial iterate satisfies the termination criteria")
92+
end
93+
stats = method == :pseudotransient ? nothing : (SciMLBase.NLStats(res.stats.ifun[1], res.stats.ijac[1], -1, -1, res.stats.iarm[1]))
94+
return SciMLBase.build_solution(prob, alg, res.solution, res.history; retcode, stats)
95+
end
96+
97+
if prob.f.jac === nothing
98+
use_forward_diff = if alg.autodiff === nothing
99+
ForwardDiff.can_dual(eltype(u))
100+
else
101+
alg.autodiff isa AutoForwardDiff
102+
end
103+
uf = SciMLBase.JacobianWrapper{iip}(prob.f, prob.p)
104+
if use_forward_diff
105+
cache = iip ? ForwardDiff.JacobianConfig(uf, fu, u) :
106+
ForwardDiff.JacobianConfig(uf, u)
107+
else
108+
cache = FiniteDiff.JacobianCache(u, fu)
109+
end
110+
J! = if iip
111+
if use_forward_diff
112+
fu_cache = similar(fu)
113+
function (J, x, p)
114+
uf.p = p
115+
ForwardDiff.jacobian!(J, uf, fu_cache, x, cache)
116+
return J
117+
end
118+
else
119+
function (J, x, p)
120+
uf.p = p
121+
FiniteDiff.finite_difference_jacobian!(J, uf, x, cache)
122+
return J
123+
end
124+
end
125+
else
126+
if use_forward_diff
127+
function (J, x, p)
128+
uf.p = p
129+
ForwardDiff.jacobian!(J, uf, x, cache)
130+
return J
131+
end
132+
else
133+
function (J, x, p)
134+
uf.p = p
135+
J_ = FiniteDiff.finite_difference_jacobian(uf, x, cache)
136+
copyto!(J, J_)
137+
return J
138+
end
139+
end
140+
end
141+
else
142+
J! = prob.f.jac
143+
end
144+
145+
AJ!(J, u, x) = J!(J, x, prob.p)
146+
147+
if method == :newton
148+
res = nsol(f!, u, FS, FPS, AJ!;
149+
sham=1, rtol = reltol, atol = abstol, maxit = maxiters,
150+
printerr = show_trace)
151+
elseif method == :pseudotransient
152+
res = ptcsol(f!, u, FS, FPS, AJ!;
153+
rtol = reltol, atol = abstol, maxit = maxiters,
154+
delta0 = delta, printerr = show_trace)
155+
156+
end
157+
158+
if res.errcode == 0
159+
retcode = ReturnCode.Success
160+
elseif res.errcode == 10
161+
retcode = ReturnCode.MaxIters
162+
elseif res.errcode == 1
163+
retcode = ReturnCode.Failure
164+
@error("Line search failed")
165+
elseif res.errcode == -1
166+
retcode = ReturnCode.Default
167+
@info("Initial iterate satisfies the termination criteria")
168+
end
169+
170+
171+
# pseudo transient continuation has a fixed cost per iteration, iteration statistics are not interesting here.
172+
stats = method == :pseudotransient ? nothing : (SciMLBase.NLStats(res.stats.ifun[1], res.stats.ijac[1], -1, -1, res.stats.iarm[1]))
173+
return SciMLBase.build_solution(prob, alg, res.solution, res.history; retcode, stats)
174+
end
175+
176+
end

src/NonlinearSolve.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ export RadiusUpdateSchemes
236236

237237
export NewtonRaphson, TrustRegion, LevenbergMarquardt, DFSane, GaussNewton, PseudoTransient,
238238
Broyden, Klement, LimitedMemoryBroyden
239-
export LeastSquaresOptimJL, FastLevenbergMarquardtJL, CMINPACK, NLsolveJL
239+
export LeastSquaresOptimJL, FastLevenbergMarquardtJL, CMINPACK, NLsolveJL, SIAMFANLEquationsJL
240240
export NonlinearSolvePolyAlgorithm,
241241
RobustMultiNewton, FastShortcutNonlinearPolyalg, FastShortcutNLLSPolyalg
242242

src/extension_algs.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,3 +206,35 @@ function NLsolveJL(; method = :trust_region, autodiff = :central, store_trace =
206206
return NLsolveJL(method, autodiff, store_trace, extended_trace, linesearch, linsolve,
207207
factor, autoscale, m, beta, show_trace)
208208
end
209+
210+
"""
211+
SIAMFANLEquationsJL(; method = :newton, autodiff = :central)
212+
213+
### Keyword Arguments
214+
215+
- `method`: the choice of method for solving the nonlinear system.
216+
- `autodiff`: the choice of method for generating the Jacobian. Defaults to `:central` or
217+
central differencing via FiniteDiff.jl. The other choices are `:forward`.
218+
- `show_trace`: whether to show the trace.
219+
- `delta`: initial pseudo time step, default is 1e-3.
220+
- `linsolve` : JFNK linear solvers, choices are `gmres` and `bicgstab`.
221+
222+
### Submethod Choice
223+
224+
- `:newton`: Classical Newton method.
225+
- `:pseudotransient`:
226+
"""
227+
@concrete struct SIAMFANLEquationsJL <: AbstractNonlinearAlgorithm
228+
method::Symbol
229+
autodiff::Symbol
230+
show_trace::Bool
231+
delta
232+
linsolve::Union{Symbol, Nothing}
233+
end
234+
235+
function SIAMFANLEquationsJL(; method = :newton, autodiff = :central, show_trace = false, delta = 1e-3, linsolve = nothing)
236+
if Base.get_extension(@__MODULE__, :NonlinearSolveSIAMFANLEquationsExt) === nothing
237+
error("SIAMFANLEquationsJL requires SIAMFANLEquations.jl to be loaded")
238+
end
239+
return SIAMFANLEquationsJL(method, autodiff, show_trace, delta, linsolve)
240+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ end
2323
if GROUP == "All" || GROUP == "Wrappers"
2424
@time @safetestset "MINPACK" include("minpack.jl")
2525
@time @safetestset "NLsolve" include("nlsolve.jl")
26+
@time @safetestset "SIAMFANLEquations" include("siamfanlequations.jl")
2627
end
2728

2829
if GROUP == "All" || GROUP == "23TestProblems"

0 commit comments

Comments
 (0)