Skip to content

Commit 8ee980b

Browse files
committed
Fix abstol, reltol and use the built-in Jacobian
Signed-off-by: ErikQQY <2283984853@qq.com>
1 parent ce52260 commit 8ee980b

File tree

1 file changed

+38
-68
lines changed

1 file changed

+38
-68
lines changed

ext/NonlinearSolveSIAMFANLEquationsExt.jl

Lines changed: 38 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,19 @@ import ConcreteStructs: @concrete
66
import UnPack: @unpack
77
import FiniteDiff, ForwardDiff
88

9-
function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, args...; abstol = 1e-8,
10-
reltol = 1e-8, alias_u0::Bool = false, maxiters = 1000, termination_condition = nothing, kwargs...)
9+
function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, args...; abstol = nothing,
10+
reltol = nothing, alias_u0::Bool = false, maxiters = 1000, termination_condition = nothing, kwargs...)
1111
@assert (termination_condition === nothing) || (termination_condition isa AbsNormTerminationMode) "SIAMFANLEquationsJL does not support termination conditions!"
1212

1313
@unpack method, autodiff, show_trace, delta, linsolve = alg
1414

1515
iip = SciMLBase.isinplace(prob)
16-
if typeof(prob.u0) <: Number
16+
T = eltype(u0)
17+
18+
atol = abstol === nothing ? real(oneunit(T)) * (eps(real(one(T))))^(4 // 5) : abstol
19+
rtol = reltol === nothing ? real(oneunit(T)) * (eps(real(one(T))))^(4 // 5) : reltol
20+
21+
if prob.u0 isa Number
1722
f! = if iip
1823
function (u)
1924
du = similar(u)
@@ -25,11 +30,11 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, arg
2530
end
2631

2732
if method == :newton
28-
sol = nsolsc(f!, prob.u0; maxit = maxiters, atol = abstol, rtol = reltol, printerr = show_trace)
33+
sol = nsolsc(f!, prob.u0; maxit = maxiters, atol = atol, rtol = rtol, printerr = show_trace)
2934
elseif method == :pseudotransient
30-
sol = ptcsolsc(f!, prob.u0; delta0 = delta, maxit = maxiters, atol = abstol, rtol=reltol, printerr = show_trace)
35+
sol = ptcsolsc(f!, prob.u0; delta0 = delta, maxit = maxiters, atol = atol, rtol=rtol, printerr = show_trace)
3136
elseif method == :secant
32-
sol = secant(f!, prob.u0; maxit = maxiters, atol = abstol, rtol = reltol, printerr = show_trace)
37+
sol = secant(f!, prob.u0; maxit = maxiters, atol = atol, rtol = rtol, printerr = show_trace)
3338
end
3439

3540
if sol.errcode == 0
@@ -61,22 +66,21 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, arg
6166
end
6267
end
6368

64-
# Allocate ahead for function and Jacobian
69+
# Allocate ahead for function
6570
N = length(u)
66-
FS = zeros(eltype(u), N)
67-
FPS = zeros(eltype(u), N, N)
68-
# Allocate ahead for Krylov basis
71+
FS = zeros(T, N)
6972

7073
# Jacobian free Newton Krylov
7174
if linsolve !== nothing
72-
JVS = linsolve == :gmres ? zeros(eltype(u), N, 3) : zeros(eltype(u), N)
75+
# Allocate ahead for Krylov basis
76+
JVS = linsolve == :gmres ? zeros(T, N, 3) : zeros(T, N)
7377
# `linsolve` as a Symbol to keep unified interface with other EXTs, SIAMFANLEquations directly use String to choose between different linear solvers
7478
linsolve_alg = String(linsolve)
7579

7680
if method == :newton
77-
sol = nsoli(f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol = abstol, rtol = reltol, printerr = show_trace)
81+
sol = nsoli(f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol = atol, rtol = rtol, printerr = show_trace)
7882
elseif method == :pseudotransient
79-
sol = ptcsoli(f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol = abstol, rtol = reltol, printerr = show_trace)
83+
sol = ptcsoli(f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol = atol, rtol = rtol, printerr = show_trace)
8084
end
8185

8286
if sol.errcode == 0
@@ -92,64 +96,30 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, arg
9296
return SciMLBase.build_solution(prob, alg, sol.solution, sol.history; retcode, stats, original = sol)
9397
end
9498

99+
# Allocate ahead for Jacobian
100+
FPS = zeros(T, N, N)
95101
if prob.f.jac === nothing
96-
use_forward_diff = if alg.autodiff === nothing
97-
ForwardDiff.can_dual(eltype(u))
98-
else
99-
alg.autodiff isa AutoForwardDiff
100-
end
101-
uf = SciMLBase.JacobianWrapper{iip}(prob.f, prob.p)
102-
if use_forward_diff
103-
cache = iip ? ForwardDiff.JacobianConfig(uf, fu, u) :
104-
ForwardDiff.JacobianConfig(uf, u)
105-
else
106-
cache = FiniteDiff.JacobianCache(u, fu)
107-
end
108-
J! = if iip
109-
if use_forward_diff
110-
fu_cache = similar(fu)
111-
function (J, x, p)
112-
uf.p = p
113-
ForwardDiff.jacobian!(J, uf, fu_cache, x, cache)
114-
return J
115-
end
116-
else
117-
function (J, x, p)
118-
uf.p = p
119-
FiniteDiff.finite_difference_jacobian!(J, uf, x, cache)
120-
return J
121-
end
122-
end
123-
else
124-
if use_forward_diff
125-
function (J, x, p)
126-
uf.p = p
127-
ForwardDiff.jacobian!(J, uf, x, cache)
128-
return J
129-
end
130-
else
131-
function (J, x, p)
132-
uf.p = p
133-
J_ = FiniteDiff.finite_difference_jacobian(uf, x, cache)
134-
copyto!(J, J_)
135-
return J
136-
end
137-
end
102+
# Use the built-in Jacobian machinery
103+
if method == :newton
104+
sol = nsol(f!, u, FS, FPS;
105+
sham=1, atol = atol, rtol = rtol, maxit = maxiters,
106+
printerr = show_trace)
107+
elseif method == :pseudotransient
108+
sol = ptcsol(f!, u, FS, FPS;
109+
atol = atol, rtol = rtol, maxit = maxiters,
110+
delta0 = delta, printerr = show_trace)
138111
end
139112
else
140-
J! = prob.f.jac
141-
end
142-
143-
AJ!(J, u, x) = J!(J, x, prob.p)
144-
145-
if method == :newton
146-
sol = nsol(f!, u, FS, FPS, AJ!;
147-
sham=1, rtol = reltol, atol = abstol, maxit = maxiters,
148-
printerr = show_trace)
149-
elseif method == :pseudotransient
150-
sol = ptcsol(f!, u, FS, FPS, AJ!;
151-
rtol = reltol, atol = abstol, maxit = maxiters,
152-
delta0 = delta, printerr = show_trace)
113+
AJ!(J, u, x) = prob.f.jac(J, x, prob.p)
114+
if method == :newton
115+
sol = nsol(f!, u, FS, FPS, AJ!;
116+
sham=1, atol = atol, rtol = rtol, maxit = maxiters,
117+
printerr = show_trace)
118+
elseif method == :pseudotransient
119+
sol = ptcsol(f!, u, FS, FPS, AJ!;
120+
atol = atol, rtol = rtol, maxit = maxiters,
121+
delta0 = delta, printerr = show_trace)
122+
end
153123
end
154124

155125
if sol.errcode == 0

0 commit comments

Comments
 (0)