@@ -6,14 +6,19 @@ import ConcreteStructs: @concrete
6
6
import UnPack: @unpack
7
7
import FiniteDiff, ForwardDiff
8
8
9
- function SciMLBase. __solve (prob:: NonlinearProblem , alg:: SIAMFANLEquationsJL , args... ; abstol = 1e-8 ,
10
- reltol = 1e-8 , alias_u0:: Bool = false , maxiters = 1000 , termination_condition = nothing , kwargs... )
9
+ function SciMLBase. __solve (prob:: NonlinearProblem , alg:: SIAMFANLEquationsJL , args... ; abstol = nothing ,
10
+ reltol = nothing , alias_u0:: Bool = false , maxiters = 1000 , termination_condition = nothing , kwargs... )
11
11
@assert (termination_condition === nothing ) || (termination_condition isa AbsNormTerminationMode) " SIAMFANLEquationsJL does not support termination conditions!"
12
12
13
13
@unpack method, autodiff, show_trace, delta, linsolve = alg
14
14
15
15
iip = SciMLBase. isinplace (prob)
16
- if typeof (prob. u0) <: Number
16
+ T = eltype (u0)
17
+
18
+ atol = abstol === nothing ? real (oneunit (T)) * (eps (real (one (T))))^ (4 // 5 ) : abstol
19
+ rtol = reltol === nothing ? real (oneunit (T)) * (eps (real (one (T))))^ (4 // 5 ) : reltol
20
+
21
+ if prob. u0 isa Number
17
22
f! = if iip
18
23
function (u)
19
24
du = similar (u)
@@ -25,11 +30,11 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, arg
25
30
end
26
31
27
32
if method == :newton
28
- sol = nsolsc (f!, prob. u0; maxit = maxiters, atol = abstol , rtol = reltol , printerr = show_trace)
33
+ sol = nsolsc (f!, prob. u0; maxit = maxiters, atol = atol , rtol = rtol , printerr = show_trace)
29
34
elseif method == :pseudotransient
30
- sol = ptcsolsc (f!, prob. u0; delta0 = delta, maxit = maxiters, atol = abstol , rtol= reltol , printerr = show_trace)
35
+ sol = ptcsolsc (f!, prob. u0; delta0 = delta, maxit = maxiters, atol = atol , rtol= rtol , printerr = show_trace)
31
36
elseif method == :secant
32
- sol = secant (f!, prob. u0; maxit = maxiters, atol = abstol , rtol = reltol , printerr = show_trace)
37
+ sol = secant (f!, prob. u0; maxit = maxiters, atol = atol , rtol = rtol , printerr = show_trace)
33
38
end
34
39
35
40
if sol. errcode == 0
@@ -61,22 +66,21 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, arg
61
66
end
62
67
end
63
68
64
- # Allocate ahead for function and Jacobian
69
+ # Allocate ahead for function
65
70
N = length (u)
66
- FS = zeros (eltype (u), N)
67
- FPS = zeros (eltype (u), N, N)
68
- # Allocate ahead for Krylov basis
71
+ FS = zeros (T, N)
69
72
70
73
# Jacobian free Newton Krylov
71
74
if linsolve != = nothing
72
- JVS = linsolve == :gmres ? zeros (eltype (u), N, 3 ) : zeros (eltype (u), N)
75
+ # Allocate ahead for Krylov basis
76
+ JVS = linsolve == :gmres ? zeros (T, N, 3 ) : zeros (T, N)
73
77
# `linsolve` as a Symbol to keep unified interface with other EXTs, SIAMFANLEquations directly use String to choose between different linear solvers
74
78
linsolve_alg = String (linsolve)
75
79
76
80
if method == :newton
77
- sol = nsoli (f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol = abstol , rtol = reltol , printerr = show_trace)
81
+ sol = nsoli (f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol = atol , rtol = rtol , printerr = show_trace)
78
82
elseif method == :pseudotransient
79
- sol = ptcsoli (f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol = abstol , rtol = reltol , printerr = show_trace)
83
+ sol = ptcsoli (f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol = atol , rtol = rtol , printerr = show_trace)
80
84
end
81
85
82
86
if sol. errcode == 0
@@ -92,64 +96,30 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, arg
92
96
return SciMLBase. build_solution (prob, alg, sol. solution, sol. history; retcode, stats, original = sol)
93
97
end
94
98
99
+ # Allocate ahead for Jacobian
100
+ FPS = zeros (T, N, N)
95
101
if prob. f. jac === nothing
96
- use_forward_diff = if alg. autodiff === nothing
97
- ForwardDiff. can_dual (eltype (u))
98
- else
99
- alg. autodiff isa AutoForwardDiff
100
- end
101
- uf = SciMLBase. JacobianWrapper {iip} (prob. f, prob. p)
102
- if use_forward_diff
103
- cache = iip ? ForwardDiff. JacobianConfig (uf, fu, u) :
104
- ForwardDiff. JacobianConfig (uf, u)
105
- else
106
- cache = FiniteDiff. JacobianCache (u, fu)
107
- end
108
- J! = if iip
109
- if use_forward_diff
110
- fu_cache = similar (fu)
111
- function (J, x, p)
112
- uf. p = p
113
- ForwardDiff. jacobian! (J, uf, fu_cache, x, cache)
114
- return J
115
- end
116
- else
117
- function (J, x, p)
118
- uf. p = p
119
- FiniteDiff. finite_difference_jacobian! (J, uf, x, cache)
120
- return J
121
- end
122
- end
123
- else
124
- if use_forward_diff
125
- function (J, x, p)
126
- uf. p = p
127
- ForwardDiff. jacobian! (J, uf, x, cache)
128
- return J
129
- end
130
- else
131
- function (J, x, p)
132
- uf. p = p
133
- J_ = FiniteDiff. finite_difference_jacobian (uf, x, cache)
134
- copyto! (J, J_)
135
- return J
136
- end
137
- end
102
+ # Use the built-in Jacobian machinery
103
+ if method == :newton
104
+ sol = nsol (f!, u, FS, FPS;
105
+ sham= 1 , atol = atol, rtol = rtol, maxit = maxiters,
106
+ printerr = show_trace)
107
+ elseif method == :pseudotransient
108
+ sol = ptcsol (f!, u, FS, FPS;
109
+ atol = atol, rtol = rtol, maxit = maxiters,
110
+ delta0 = delta, printerr = show_trace)
138
111
end
139
112
else
140
- J! = prob. f. jac
141
- end
142
-
143
- AJ! (J, u, x) = J! (J, x, prob. p)
144
-
145
- if method == :newton
146
- sol = nsol (f!, u, FS, FPS, AJ!;
147
- sham= 1 , rtol = reltol, atol = abstol, maxit = maxiters,
148
- printerr = show_trace)
149
- elseif method == :pseudotransient
150
- sol = ptcsol (f!, u, FS, FPS, AJ!;
151
- rtol = reltol, atol = abstol, maxit = maxiters,
152
- delta0 = delta, printerr = show_trace)
113
+ AJ! (J, u, x) = prob. f. jac (J, x, prob. p)
114
+ if method == :newton
115
+ sol = nsol (f!, u, FS, FPS, AJ!;
116
+ sham= 1 , atol = atol, rtol = rtol, maxit = maxiters,
117
+ printerr = show_trace)
118
+ elseif method == :pseudotransient
119
+ sol = ptcsol (f!, u, FS, FPS, AJ!;
120
+ atol = atol, rtol = rtol, maxit = maxiters,
121
+ delta0 = delta, printerr = show_trace)
122
+ end
153
123
end
154
124
155
125
if sol. errcode == 0
0 commit comments