1
+ module NonlinearSolveSIAMFANLEquationsExt
2
+
3
+ using NonlinearSolve, SciMLBase
4
+ using SIAMFANLEquations
5
+ import ConcreteStructs: @concrete
6
+ import UnPack: @unpack
7
+ import FiniteDiff, ForwardDiff
8
+
9
+ function SciMLBase. __solve (prob:: NonlinearProblem , alg:: SIAMFANLEquationsJL , args... ; abstol = 1e-8 ,
10
+ reltol = 1e-8 , alias_u0:: Bool = false , maxiters = 1000 , kwargs... )
11
+ @unpack method, autodiff, show_trace, delta, linsolve = alg
12
+
13
+ iip = SciMLBase. isinplace (prob)
14
+ if typeof (prob. u0) <: Number
15
+ f! = if iip
16
+ function (u)
17
+ du = similar (u)
18
+ prob. f (du, u, prob. p)
19
+ return du
20
+ end
21
+ else
22
+ u -> prob. f (u, prob. p)
23
+ end
24
+
25
+ if method == :newton
26
+ res = nsolsc (f!, prob. u0; maxit = maxiters, atol = abstol, rtol = reltol, printerr = show_trace)
27
+ elseif method == :pseudotransient
28
+ res = ptcsolsc (f!, prob. u0; delta0 = delta, maxit = maxiters, atol = abstol, rtol= reltol, printerr = show_trace)
29
+ elseif method == :secant
30
+ res = secant (f!, prob. u0; maxit = maxiters, atol = abstol, rtol = reltol, printerr = show_trace)
31
+ end
32
+
33
+ if res. errcode == 0
34
+ retcode = ReturnCode. Success
35
+ elseif res. errcode == 10
36
+ retcode = ReturnCode. MaxIters
37
+ elseif res. errcode == 1
38
+ retcode = ReturnCode. Failure
39
+ @error (" Line search failed" )
40
+ elseif res. errcode == - 1
41
+ retcode = ReturnCode. Default
42
+ @info (" Initial iterate satisfies the termination criteria" )
43
+ end
44
+ stats = method == :pseudotransient ? nothing : (SciMLBase. NLStats (res. stats. ifun[1 ], res. stats. ijac[1 ], - 1 , - 1 , res. stats. iarm[1 ]))
45
+ return SciMLBase. build_solution (prob, alg, res. solution, res. history; retcode, stats)
46
+ else
47
+ u = NonlinearSolve. __maybe_unaliased (prob. u0, alias_u0)
48
+ end
49
+
50
+ fu = NonlinearSolve. evaluate_f (prob, u)
51
+
52
+ if iip
53
+ f! = function (du, u)
54
+ prob. f (du, u, prob. p)
55
+ return du
56
+ end
57
+ else
58
+ f! = function (du, u)
59
+ du .= prob. f (u, prob. p)
60
+ return du
61
+ end
62
+ end
63
+
64
+ # Allocate ahead for function and Jacobian
65
+ N = length (u)
66
+ FS = zeros (eltype (u), N)
67
+ FPS = zeros (eltype (u), N, N)
68
+ # Allocate ahead for Krylov basis
69
+
70
+ # Jacobian free Newton Krylov
71
+ if linsolve != = nothing
72
+ JVS = linsolve == :gmres ? zeros (eltype (u), N, 3 ) : zeros (eltype (u), N)
73
+ # `linsolve` as a Symbol to keep unified interface with other EXTs, SIAMFANLEquations directly use String to choose between linear solvers
74
+ linsolve_alg = strip (repr (linsolve), ' :' )
75
+
76
+ if method == :newton
77
+ res = nsoli (f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol = abstol, rtol = reltol, printerr = show_trace)
78
+ elseif method == :pseudotransient
79
+ res = ptcsoli (f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol = abstol, rtol = reltol, printerr = show_trace)
80
+ end
81
+
82
+ if res. errcode == 0
83
+ retcode = ReturnCode. Success
84
+ elseif res. errcode == 10
85
+ retcode = ReturnCode. MaxIters
86
+ elseif res. errcode == 1
87
+ retcode = ReturnCode. Failure
88
+ @error (" Line search failed" )
89
+ elseif res. errcode == - 1
90
+ retcode = ReturnCode. Default
91
+ @info (" Initial iterate satisfies the termination criteria" )
92
+ end
93
+ stats = method == :pseudotransient ? nothing : (SciMLBase. NLStats (res. stats. ifun[1 ], res. stats. ijac[1 ], - 1 , - 1 , res. stats. iarm[1 ]))
94
+ return SciMLBase. build_solution (prob, alg, res. solution, res. history; retcode, stats)
95
+ end
96
+
97
+ if prob. f. jac === nothing
98
+ use_forward_diff = if alg. autodiff === nothing
99
+ ForwardDiff. can_dual (eltype (u))
100
+ else
101
+ alg. autodiff isa AutoForwardDiff
102
+ end
103
+ uf = SciMLBase. JacobianWrapper {iip} (prob. f, prob. p)
104
+ if use_forward_diff
105
+ cache = iip ? ForwardDiff. JacobianConfig (uf, fu, u) :
106
+ ForwardDiff. JacobianConfig (uf, u)
107
+ else
108
+ cache = FiniteDiff. JacobianCache (u, fu)
109
+ end
110
+ J! = if iip
111
+ if use_forward_diff
112
+ fu_cache = similar (fu)
113
+ function (J, x, p)
114
+ uf. p = p
115
+ ForwardDiff. jacobian! (J, uf, fu_cache, x, cache)
116
+ return J
117
+ end
118
+ else
119
+ function (J, x, p)
120
+ uf. p = p
121
+ FiniteDiff. finite_difference_jacobian! (J, uf, x, cache)
122
+ return J
123
+ end
124
+ end
125
+ else
126
+ if use_forward_diff
127
+ function (J, x, p)
128
+ uf. p = p
129
+ ForwardDiff. jacobian! (J, uf, x, cache)
130
+ return J
131
+ end
132
+ else
133
+ function (J, x, p)
134
+ uf. p = p
135
+ J_ = FiniteDiff. finite_difference_jacobian (uf, x, cache)
136
+ copyto! (J, J_)
137
+ return J
138
+ end
139
+ end
140
+ end
141
+ else
142
+ J! = prob. f. jac
143
+ end
144
+
145
+ AJ! (J, u, x) = J! (J, x, prob. p)
146
+
147
+ if method == :newton
148
+ res = nsol (f!, u, FS, FPS, AJ!;
149
+ sham= 1 , rtol = reltol, atol = abstol, maxit = maxiters,
150
+ printerr = show_trace)
151
+ elseif method == :pseudotransient
152
+ res = ptcsol (f!, u, FS, FPS, AJ!;
153
+ rtol = reltol, atol = abstol, maxit = maxiters,
154
+ delta0 = delta, printerr = show_trace)
155
+
156
+ end
157
+
158
+ if res. errcode == 0
159
+ retcode = ReturnCode. Success
160
+ elseif res. errcode == 10
161
+ retcode = ReturnCode. MaxIters
162
+ elseif res. errcode == 1
163
+ retcode = ReturnCode. Failure
164
+ @error (" Line search failed" )
165
+ elseif res. errcode == - 1
166
+ retcode = ReturnCode. Default
167
+ @info (" Initial iterate satisfies the termination criteria" )
168
+ end
169
+
170
+
171
+ # pseudo transient continuation has a fixed cost per iteration, iteration statistics are not interesting here.
172
+ stats = method == :pseudotransient ? nothing : (SciMLBase. NLStats (res. stats. ifun[1 ], res. stats. ijac[1 ], - 1 , - 1 , res. stats. iarm[1 ]))
173
+ return SciMLBase. build_solution (prob, alg, res. solution, res. history; retcode, stats)
174
+ end
175
+
176
+ end
0 commit comments