@@ -2,51 +2,58 @@ module NonlinearSolveSIAMFANLEquationsExt
2
2
3
3
using NonlinearSolve, SciMLBase
4
4
using SIAMFANLEquations
5
- import ConcreteStructs: @concrete
6
5
import UnPack: @unpack
7
6
8
- function SciMLBase. __solve (prob:: NonlinearProblem , alg:: SIAMFANLEquationsJL , args... ; abstol = nothing ,
9
- reltol = nothing , alias_u0:: Bool = false , maxiters = 1000 , termination_condition = nothing , kwargs... )
10
- @assert (termination_condition === nothing ) || (termination_condition isa AbsNormTerminationMode) " SIAMFANLEquationsJL does not support termination conditions!"
7
+ @inline function __siam_fanl_equations_retcode_mapping (sol)
8
+ if sol. errcode == 0
9
+ return ReturnCode. Success
10
+ elseif sol. errcode == 10
11
+ return ReturnCode. MaxIters
12
+ elseif sol. errcode == 1
13
+ return ReturnCode. Failure
14
+ elseif sol. errcode == - 1
15
+ return ReturnCode. Default
16
+ end
17
+ end
18
+
19
+ # pseudo transient continuation has a fixed cost per iteration, iteration statistics are
20
+ # not interesting here.
21
+ @inline function __siam_fanl_equations_stats_mapping (method, sol)
22
+ method === :pseudotransient && return nothing
23
+ return SciMLBase. NLStats (sum (sol. stats. ifun), sum (sol. stats. ijac), 0 , 0 ,
24
+ sum (sol. stats. iarm))
25
+ end
11
26
12
- @unpack method, show_trace, delta, linsolve = alg
27
+ function SciMLBase. __solve (prob:: NonlinearProblem , alg:: SIAMFANLEquationsJL , args... ;
28
+ abstol = nothing , reltol = nothing , alias_u0:: Bool = false , maxiters = 1000 ,
29
+ termination_condition = nothing , show_trace:: Val{ShT} = Val (false ),
30
+ kwargs... ) where {ShT}
31
+ @assert (termination_condition ===
32
+ nothing )|| (termination_condition isa AbsNormTerminationMode) " SIAMFANLEquationsJL does not support termination conditions!"
33
+
34
+ @unpack method, delta, linsolve = alg
13
35
14
36
iip = SciMLBase. isinplace (prob)
15
- T = eltype (prob. u0)
16
37
17
- atol = abstol === nothing ? real ( oneunit (T)) * ( eps ( real ( one (T)))) ^ ( 4 // 5 ) : abstol
18
- rtol = reltol === nothing ? real ( oneunit (T)) * ( eps ( real ( one (T)))) ^ ( 4 // 5 ) : reltol
38
+ atol = NonlinearSolve . DEFAULT_TOLERANCE ( abstol, eltype (prob . u0))
39
+ rtol = NonlinearSolve . DEFAULT_TOLERANCE ( reltol, eltype (prob . u0))
19
40
20
41
if prob. u0 isa Number
21
- f! = if iip
22
- function (u)
23
- du = similar (u)
24
- prob. f (du, u, prob. p)
25
- return du
26
- end
27
- else
28
- u -> prob. f (u, prob. p)
29
- end
42
+ f = (u) -> prob. f (u, prob. p)
30
43
31
44
if method == :newton
32
- sol = nsolsc (f! , prob. u0; maxit = maxiters, atol = atol , rtol = rtol , printerr = show_trace )
45
+ sol = nsolsc (f, prob. u0; maxit = maxiters, atol, rtol, printerr = ShT )
33
46
elseif method == :pseudotransient
34
- sol = ptcsolsc (f!, prob. u0; delta0 = delta, maxit = maxiters, atol = atol, rtol= rtol, printerr = show_trace)
47
+ sol = ptcsolsc (f, prob. u0; delta0 = delta, maxit = maxiters, atol, rtol,
48
+ printerr = ShT)
35
49
elseif method == :secant
36
- sol = secant (f! , prob. u0; maxit = maxiters, atol = atol , rtol = rtol , printerr = show_trace )
50
+ sol = secant (f, prob. u0; maxit = maxiters, atol, rtol, printerr = ShT )
37
51
end
38
52
39
- if sol. errcode == 0
40
- retcode = ReturnCode. Success
41
- elseif sol. errcode == 10
42
- retcode = ReturnCode. MaxIters
43
- elseif sol. errcode == 1
44
- retcode = ReturnCode. Failure
45
- elseif sol. errcode == - 1
46
- retcode = ReturnCode. Default
47
- end
48
- stats = method == :pseudotransient ? nothing : (SciMLBase. NLStats (sum (sol. stats. ifun), sum (sol. stats. ijac), 0 , 0 , sum (sol. stats. iarm)))
49
- return SciMLBase. build_solution (prob, alg, sol. solution, sol. history; retcode, stats, original = sol)
53
+ retcode = __siam_fanl_equations_retcode_mapping (sol)
54
+ stats = __siam_fanl_equations_stats_mapping (method, sol)
55
+ return SciMLBase. build_solution (prob, alg, sol. solution, sol. history; retcode,
56
+ stats, original = sol)
50
57
else
51
58
u = NonlinearSolve. __maybe_unaliased (prob. u0, alias_u0)
52
59
end
@@ -71,67 +78,50 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, arg
71
78
if linsolve != = nothing
72
79
# Allocate ahead for Krylov basis
73
80
JVS = linsolve == :gmres ? zeros (T, N, 3 ) : zeros (T, N)
74
- # `linsolve` as a Symbol to keep unified interface with other EXTs, SIAMFANLEquations directly use String to choose between different linear solvers
81
+ # `linsolve` as a Symbol to keep unified interface with other EXTs,
82
+ # SIAMFANLEquations directly use String to choose between different linear solvers
75
83
linsolve_alg = String (linsolve)
76
84
77
85
if method == :newton
78
- sol = nsoli (f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol = atol, rtol = rtol, printerr = show_trace)
86
+ sol = nsoli (f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol,
87
+ rtol, printerr = ShT)
79
88
elseif method == :pseudotransient
80
- sol = ptcsoli (f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol = atol, rtol = rtol, printerr = show_trace)
81
- end
82
-
83
- if sol. errcode == 0
84
- retcode = ReturnCode. Success
85
- elseif sol. errcode == 10
86
- retcode = ReturnCode. MaxIters
87
- elseif sol. errcode == 1
88
- retcode = ReturnCode. Failure
89
- elseif sol. errcode == - 1
90
- retcode = ReturnCode. Default
89
+ sol = ptcsoli (f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol,
90
+ rtol, printerr = ShT)
91
91
end
92
- stats = method == :pseudotransient ? nothing : (SciMLBase. NLStats (sum (sol. stats. ifun), sum (sol. stats. ijac), 0 , 0 , sum (sol. stats. iarm)))
93
- return SciMLBase. build_solution (prob, alg, sol. solution, sol. history; retcode, stats, original = sol)
92
+
93
+ retcode = __siam_fanl_equations_retcode_mapping (sol)
94
+ stats = __siam_fanl_equations_stats_mapping (method, sol)
95
+ return SciMLBase. build_solution (prob, alg, sol. solution, sol. history; retcode,
96
+ stats, original = sol)
94
97
end
95
98
96
99
# Allocate ahead for Jacobian
97
100
FPS = zeros (T, N, N)
98
101
if prob. f. jac === nothing
99
102
# Use the built-in Jacobian machinery
100
103
if method == :newton
101
- sol = nsol (f!, u, FS, FPS;
102
- sham= 1 , atol = atol, rtol = rtol, maxit = maxiters,
103
- printerr = show_trace)
104
+ sol = nsol (f!, u, FS, FPS; sham = 1 , atol, rtol, maxit = maxiters,
105
+ printerr = ShT)
104
106
elseif method == :pseudotransient
105
- sol = ptcsol (f!, u, FS, FPS;
106
- atol = atol, rtol = rtol, maxit = maxiters,
107
- delta0 = delta, printerr = show_trace)
107
+ sol = ptcsol (f!, u, FS, FPS; atol, rtol, maxit = maxiters,
108
+ delta0 = delta, printerr = ShT)
108
109
end
109
110
else
110
111
AJ! (J, u, x) = prob. f. jac (J, x, prob. p)
111
112
if method == :newton
112
- sol = nsol (f!, u, FS, FPS, AJ!;
113
- sham= 1 , atol = atol, rtol = rtol, maxit = maxiters,
114
- printerr = show_trace)
113
+ sol = nsol (f!, u, FS, FPS, AJ!; sham = 1 , atol, rtol, maxit = maxiters,
114
+ printerr = ShT)
115
115
elseif method == :pseudotransient
116
- sol = ptcsol (f!, u, FS, FPS, AJ!;
117
- atol = atol, rtol = rtol, maxit = maxiters,
118
- delta0 = delta, printerr = show_trace)
116
+ sol = ptcsol (f!, u, FS, FPS, AJ!; atol, rtol, maxit = maxiters,
117
+ delta0 = delta, printerr = ShT)
119
118
end
120
119
end
121
120
122
- if sol. errcode == 0
123
- retcode = ReturnCode. Success
124
- elseif sol. errcode == 10
125
- retcode = ReturnCode. MaxIters
126
- elseif sol. errcode == 1
127
- retcode = ReturnCode. Failure
128
- elseif sol. errcode == - 1
129
- retcode = ReturnCode. Default
130
- end
131
-
132
- # pseudo transient continuation has a fixed cost per iteration, iteration statistics are not interesting here.
133
- stats = method == :pseudotransient ? nothing : (SciMLBase. NLStats (sum (sol. stats. ifun), sum (sol. stats. ijac), 0 , 0 , sum (sol. stats. iarm)))
134
- return SciMLBase. build_solution (prob, alg, sol. solution, sol. history; retcode, stats, original = sol)
121
+ retcode = __siam_fanl_equations_retcode_mapping (sol)
122
+ stats = __siam_fanl_equations_stats_mapping (method, sol)
123
+ return SciMLBase. build_solution (prob, alg, sol. solution, sol. history; retcode, stats,
124
+ original = sol)
135
125
end
136
126
137
- end
127
+ end
0 commit comments