Skip to content

Commit f2edda0

Browse files
committed
Standardize parts of SIAM FANL Equations
1 parent 73005ca commit f2edda0

File tree

7 files changed

+83
-87
lines changed

7 files changed

+83
-87
lines changed

docs/src/api/siamfanlequations.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# SIAMFANLEquations.jl
22

3-
This is an extension for importing solvers from [SIAMFANLEquations.jl](https://github.com/ctkelley/SIAMFANLEquations.jl) into the SciML
3+
This is an extension for importing solvers from
4+
[SIAMFANLEquations.jl](https://github.com/ctkelley/SIAMFANLEquations.jl) into the SciML
45
interface. Note that these solvers do not come by default, and thus one needs to install
56
the package before using these solvers:
67

docs/src/solvers/NonlinearSystemSolvers.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,3 +143,10 @@ Newton-Krylov form. However, KINSOL is known to be less stable than some other
143143
implementations, as it has no line search or globalizer (trust region).
144144

145145
- `KINSOL()`: The KINSOL method of the SUNDIALS C library
146+
147+
### SIAMFANLEquations.jl
148+
149+
SIAMFANLEquations.jl is a wrapper for the methods in the SIAMFANLEquations.jl library.
150+
151+
- `SIAMFANLEquationsJL()`: A wrapper for using the methods in
152+
[SIAMFANLEquations.jl](https://github.com/ctkelley/SIAMFANLEquations.jl)

ext/NonlinearSolveNLsolveExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::NLsolveJL, args...;
6868
df = OnceDifferentiable(f!, vec(u0), vec(resid); autodiff)
6969
end
7070

71-
abstol = NonlinearSolve.DEFAULT_TOLERANCE(abstol, eltype(u))
71+
abstol = NonlinearSolve.DEFAULT_TOLERANCE(abstol, eltype(u0))
7272

7373
original = nlsolve(df, vec(u0); ftol = abstol, iterations = maxiters, method,
7474
store_trace, extended_trace, linesearch, linsolve, factor, autoscale, m, beta,

ext/NonlinearSolveSIAMFANLEquationsExt.jl

Lines changed: 62 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -2,51 +2,58 @@ module NonlinearSolveSIAMFANLEquationsExt
22

33
using NonlinearSolve, SciMLBase
44
using SIAMFANLEquations
5-
import ConcreteStructs: @concrete
65
import UnPack: @unpack
76

8-
function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, args...; abstol = nothing,
9-
reltol = nothing, alias_u0::Bool = false, maxiters = 1000, termination_condition = nothing, kwargs...)
10-
@assert (termination_condition === nothing) || (termination_condition isa AbsNormTerminationMode) "SIAMFANLEquationsJL does not support termination conditions!"
7+
@inline function __siam_fanl_equations_retcode_mapping(sol)
8+
if sol.errcode == 0
9+
return ReturnCode.Success
10+
elseif sol.errcode == 10
11+
return ReturnCode.MaxIters
12+
elseif sol.errcode == 1
13+
return ReturnCode.Failure
14+
elseif sol.errcode == -1
15+
return ReturnCode.Default
16+
end
17+
end
18+
19+
# pseudo transient continuation has a fixed cost per iteration, iteration statistics are
20+
# not interesting here.
21+
@inline function __siam_fanl_equations_stats_mapping(method, sol)
22+
method === :pseudotransient && return nothing
23+
return SciMLBase.NLStats(sum(sol.stats.ifun), sum(sol.stats.ijac), 0, 0,
24+
sum(sol.stats.iarm))
25+
end
1126

12-
@unpack method, show_trace, delta, linsolve = alg
27+
function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, args...;
28+
abstol = nothing, reltol = nothing, alias_u0::Bool = false, maxiters = 1000,
29+
termination_condition = nothing, show_trace::Val{ShT} = Val(false),
30+
kwargs...) where {ShT}
31+
@assert (termination_condition ===
32+
nothing)||(termination_condition isa AbsNormTerminationMode) "SIAMFANLEquationsJL does not support termination conditions!"
33+
34+
@unpack method, delta, linsolve = alg
1335

1436
iip = SciMLBase.isinplace(prob)
15-
T = eltype(prob.u0)
1637

17-
atol = abstol === nothing ? real(oneunit(T)) * (eps(real(one(T))))^(4 // 5) : abstol
18-
rtol = reltol === nothing ? real(oneunit(T)) * (eps(real(one(T))))^(4 // 5) : reltol
38+
atol = NonlinearSolve.DEFAULT_TOLERANCE(abstol, eltype(prob.u0))
39+
rtol = NonlinearSolve.DEFAULT_TOLERANCE(reltol, eltype(prob.u0))
1940

2041
if prob.u0 isa Number
21-
f! = if iip
22-
function (u)
23-
du = similar(u)
24-
prob.f(du, u, prob.p)
25-
return du
26-
end
27-
else
28-
u -> prob.f(u, prob.p)
29-
end
42+
f = (u) -> prob.f(u, prob.p)
3043

3144
if method == :newton
32-
sol = nsolsc(f!, prob.u0; maxit = maxiters, atol = atol, rtol = rtol, printerr = show_trace)
45+
sol = nsolsc(f, prob.u0; maxit = maxiters, atol, rtol, printerr = ShT)
3346
elseif method == :pseudotransient
34-
sol = ptcsolsc(f!, prob.u0; delta0 = delta, maxit = maxiters, atol = atol, rtol=rtol, printerr = show_trace)
47+
sol = ptcsolsc(f, prob.u0; delta0 = delta, maxit = maxiters, atol, rtol,
48+
printerr = ShT)
3549
elseif method == :secant
36-
sol = secant(f!, prob.u0; maxit = maxiters, atol = atol, rtol = rtol, printerr = show_trace)
50+
sol = secant(f, prob.u0; maxit = maxiters, atol, rtol, printerr = ShT)
3751
end
3852

39-
if sol.errcode == 0
40-
retcode = ReturnCode.Success
41-
elseif sol.errcode == 10
42-
retcode = ReturnCode.MaxIters
43-
elseif sol.errcode == 1
44-
retcode = ReturnCode.Failure
45-
elseif sol.errcode == -1
46-
retcode = ReturnCode.Default
47-
end
48-
stats = method == :pseudotransient ? nothing : (SciMLBase.NLStats(sum(sol.stats.ifun), sum(sol.stats.ijac), 0, 0, sum(sol.stats.iarm)))
49-
return SciMLBase.build_solution(prob, alg, sol.solution, sol.history; retcode, stats, original = sol)
53+
retcode = __siam_fanl_equations_retcode_mapping(sol)
54+
stats = __siam_fanl_equations_stats_mapping(method, sol)
55+
return SciMLBase.build_solution(prob, alg, sol.solution, sol.history; retcode,
56+
stats, original = sol)
5057
else
5158
u = NonlinearSolve.__maybe_unaliased(prob.u0, alias_u0)
5259
end
@@ -71,67 +78,50 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, arg
7178
if linsolve !== nothing
7279
# Allocate ahead for Krylov basis
7380
JVS = linsolve == :gmres ? zeros(T, N, 3) : zeros(T, N)
74-
# `linsolve` as a Symbol to keep unified interface with other EXTs, SIAMFANLEquations directly use String to choose between different linear solvers
81+
# `linsolve` as a Symbol to keep unified interface with other EXTs,
82+
# SIAMFANLEquations directly use String to choose between different linear solvers
7583
linsolve_alg = String(linsolve)
7684

7785
if method == :newton
78-
sol = nsoli(f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol = atol, rtol = rtol, printerr = show_trace)
86+
sol = nsoli(f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol,
87+
rtol, printerr = ShT)
7988
elseif method == :pseudotransient
80-
sol = ptcsoli(f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol = atol, rtol = rtol, printerr = show_trace)
81-
end
82-
83-
if sol.errcode == 0
84-
retcode = ReturnCode.Success
85-
elseif sol.errcode == 10
86-
retcode = ReturnCode.MaxIters
87-
elseif sol.errcode == 1
88-
retcode = ReturnCode.Failure
89-
elseif sol.errcode == -1
90-
retcode = ReturnCode.Default
89+
sol = ptcsoli(f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol,
90+
rtol, printerr = ShT)
9191
end
92-
stats = method == :pseudotransient ? nothing : (SciMLBase.NLStats(sum(sol.stats.ifun), sum(sol.stats.ijac), 0, 0, sum(sol.stats.iarm)))
93-
return SciMLBase.build_solution(prob, alg, sol.solution, sol.history; retcode, stats, original = sol)
92+
93+
retcode = __siam_fanl_equations_retcode_mapping(sol)
94+
stats = __siam_fanl_equations_stats_mapping(method, sol)
95+
return SciMLBase.build_solution(prob, alg, sol.solution, sol.history; retcode,
96+
stats, original = sol)
9497
end
9598

9699
# Allocate ahead for Jacobian
97100
FPS = zeros(T, N, N)
98101
if prob.f.jac === nothing
99102
# Use the built-in Jacobian machinery
100103
if method == :newton
101-
sol = nsol(f!, u, FS, FPS;
102-
sham=1, atol = atol, rtol = rtol, maxit = maxiters,
103-
printerr = show_trace)
104+
sol = nsol(f!, u, FS, FPS; sham = 1, atol, rtol, maxit = maxiters,
105+
printerr = ShT)
104106
elseif method == :pseudotransient
105-
sol = ptcsol(f!, u, FS, FPS;
106-
atol = atol, rtol = rtol, maxit = maxiters,
107-
delta0 = delta, printerr = show_trace)
107+
sol = ptcsol(f!, u, FS, FPS; atol, rtol, maxit = maxiters,
108+
delta0 = delta, printerr = ShT)
108109
end
109110
else
110111
AJ!(J, u, x) = prob.f.jac(J, x, prob.p)
111112
if method == :newton
112-
sol = nsol(f!, u, FS, FPS, AJ!;
113-
sham=1, atol = atol, rtol = rtol, maxit = maxiters,
114-
printerr = show_trace)
113+
sol = nsol(f!, u, FS, FPS, AJ!; sham = 1, atol, rtol, maxit = maxiters,
114+
printerr = ShT)
115115
elseif method == :pseudotransient
116-
sol = ptcsol(f!, u, FS, FPS, AJ!;
117-
atol = atol, rtol = rtol, maxit = maxiters,
118-
delta0 = delta, printerr = show_trace)
116+
sol = ptcsol(f!, u, FS, FPS, AJ!; atol, rtol, maxit = maxiters,
117+
delta0 = delta, printerr = ShT)
119118
end
120119
end
121120

122-
if sol.errcode == 0
123-
retcode = ReturnCode.Success
124-
elseif sol.errcode == 10
125-
retcode = ReturnCode.MaxIters
126-
elseif sol.errcode == 1
127-
retcode = ReturnCode.Failure
128-
elseif sol.errcode == -1
129-
retcode = ReturnCode.Default
130-
end
131-
132-
# pseudo transient continuation has a fixed cost per iteration, iteration statistics are not interesting here.
133-
stats = method == :pseudotransient ? nothing : (SciMLBase.NLStats(sum(sol.stats.ifun), sum(sol.stats.ijac), 0, 0, sum(sol.stats.iarm)))
134-
return SciMLBase.build_solution(prob, alg, sol.solution, sol.history; retcode, stats, original = sol)
121+
retcode = __siam_fanl_equations_retcode_mapping(sol)
122+
stats = __siam_fanl_equations_stats_mapping(method, sol)
123+
return SciMLBase.build_solution(prob, alg, sol.solution, sol.history; retcode, stats,
124+
original = sol)
135125
end
136126

137-
end
127+
end

src/NonlinearSolve.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,8 @@ export RadiusUpdateSchemes
237237
export NewtonRaphson, TrustRegion, LevenbergMarquardt, DFSane, GaussNewton, PseudoTransient,
238238
Broyden, Klement, LimitedMemoryBroyden
239239
export LeastSquaresOptimJL,
240-
FastLevenbergMarquardtJL, CMINPACK, NLsolveJL, FixedPointAccelerationJL, SpeedMappingJL, SIAMFANLEquationsJL
240+
FastLevenbergMarquardtJL, CMINPACK, NLsolveJL, FixedPointAccelerationJL, SpeedMappingJL,
241+
SIAMFANLEquationsJL
241242
export NonlinearSolvePolyAlgorithm,
242243
RobustMultiNewton, FastShortcutNonlinearPolyalg, FastShortcutNLLSPolyalg
243244

src/extension_algs.jl

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,6 @@ function NLsolveJL(; method = :trust_region, autodiff = :central, store_trace =
247247
end
248248

249249
"""
250-
251250
SpeedMappingJL(; σ_min = 0.0, stabilize::Bool = false, check_obj::Bool = false,
252251
orders::Vector{Int} = [3, 3, 2], time_limit::Real = 1000)
253252
@@ -364,13 +363,11 @@ function FixedPointAccelerationJL(; algorithm = :Anderson, m = missing,
364363
end
365364

366365
"""
367-
368-
SIAMFANLEquationsJL(; method = :newton, autodiff = :central, show_trace = false, delta = 1e-3, linsolve = nothing)
366+
SIAMFANLEquationsJL(; method = :newton, delta = 1e-3, linsolve = nothing)
369367
370368
### Keyword Arguments
371369
372370
- `method`: the choice of method for solving the nonlinear system.
373-
- `show_trace`: whether to show the trace.
374371
- `delta`: initial pseudo time step, default is 1e-3.
375372
- `linsolve` : JFNK linear solvers, choices are `gmres` and `bicgstab`.
376373
@@ -380,16 +377,16 @@ end
380377
- `:pseudotransient`: Pseudo transient method.
381378
- `:secant`: Secant method for scalar equations.
382379
"""
383-
@concrete struct SIAMFANLEquationsJL <: AbstractNonlinearAlgorithm
380+
@concrete struct SIAMFANLEquationsJL{L <: Union{Symbol, Nothing}} <:
381+
AbstractNonlinearSolveAlgorithm
384382
method::Symbol
385-
show_trace::Bool
386383
delta
387-
linsolve::Union{Symbol, Nothing}
384+
linsolve::L
388385
end
389386

390-
function SIAMFANLEquationsJL(; method = :newton, show_trace = false, delta = 1e-3, linsolve = nothing)
387+
function SIAMFANLEquationsJL(; method = :newton, delta = 1e-3, linsolve = nothing)
391388
if Base.get_extension(@__MODULE__, :NonlinearSolveSIAMFANLEquationsExt) === nothing
392389
error("SIAMFANLEquationsJL requires SIAMFANLEquations.jl to be loaded")
393390
end
394-
return SIAMFANLEquationsJL(method, show_trace, delta, linsolve)
391+
return SIAMFANLEquationsJL(method, show_trace, delta, linsolve)
395392
end

test/siamfanlequations.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ end
7474
f_tol(u, p) = u^2 - 2
7575
prob_tol = NonlinearProblem(f_tol, 1.0)
7676
for tol in [1e-1, 1e-3, 1e-6, 1e-10, 1e-11]
77-
for method = [:newton, :pseudotransient, :secant]
77+
for method in [:newton, :pseudotransient, :secant]
7878
sol = solve(prob_tol, SIAMFANLEquationsJL(method = method), abstol = tol)
7979
@test abs(sol.u[1] - sqrt(2)) < tol
8080
end
@@ -141,12 +141,12 @@ f = NonlinearFunction(f!, jac = j!)
141141
p = A
142142

143143
ProbN = NonlinearProblem(f, init, p)
144-
for method = [:newton, :pseudotransient]
144+
for method in [:newton, :pseudotransient]
145145
sol = solve(ProbN, SIAMFANLEquationsJL(method = method), reltol = 1e-8, abstol = 1e-8)
146146
end
147147

148148
#= doesn't support complex numbers handling
149149
init = ones(Complex{Float64}, 152);
150150
ProbN = NonlinearProblem(f, init, p)
151151
sol = solve(ProbN, SIAMFANLEquationsJL(), reltol = 1e-8, abstol = 1e-8)
152-
=#
152+
=#

0 commit comments

Comments
 (0)