Skip to content

Commit ce52260

Browse files
committed
Fix stats
Signed-off-by: ErikQQY <2283984853@qq.com>
1 parent ddfe753 commit ce52260

File tree

1 file changed

+26
-27
lines changed

1 file changed

+26
-27
lines changed

ext/NonlinearSolveSIAMFANLEquationsExt.jl

Lines changed: 26 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -25,24 +25,24 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, arg
2525
end
2626

2727
if method == :newton
28-
res = nsolsc(f!, prob.u0; maxit = maxiters, atol = abstol, rtol = reltol, printerr = show_trace)
28+
sol = nsolsc(f!, prob.u0; maxit = maxiters, atol = abstol, rtol = reltol, printerr = show_trace)
2929
elseif method == :pseudotransient
30-
res = ptcsolsc(f!, prob.u0; delta0 = delta, maxit = maxiters, atol = abstol, rtol=reltol, printerr = show_trace)
30+
sol = ptcsolsc(f!, prob.u0; delta0 = delta, maxit = maxiters, atol = abstol, rtol=reltol, printerr = show_trace)
3131
elseif method == :secant
32-
res = secant(f!, prob.u0; maxit = maxiters, atol = abstol, rtol = reltol, printerr = show_trace)
32+
sol = secant(f!, prob.u0; maxit = maxiters, atol = abstol, rtol = reltol, printerr = show_trace)
3333
end
3434

35-
if res.errcode == 0
35+
if sol.errcode == 0
3636
retcode = ReturnCode.Success
37-
elseif res.errcode == 10
37+
elseif sol.errcode == 10
3838
retcode = ReturnCode.MaxIters
39-
elseif res.errcode == 1
39+
elseif sol.errcode == 1
4040
retcode = ReturnCode.Failure
41-
elseif res.errcode == -1
41+
elseif sol.errcode == -1
4242
retcode = ReturnCode.Default
4343
end
44-
stats = method == :pseudotransient ? nothing : (SciMLBase.NLStats(res.stats.ifun[1], res.stats.ijac[1], 0, 0, res.stats.iarm[1]))
45-
return SciMLBase.build_solution(prob, alg, res.solution, res.history; retcode, stats)
44+
stats = method == :pseudotransient ? nothing : (SciMLBase.NLStats(sum(sol.stats.ifun), sum(sol.stats.ijac), 0, 0, sum(sol.stats.iarm)))
45+
return SciMLBase.build_solution(prob, alg, sol.solution, sol.history; retcode, stats, original = sol)
4646
else
4747
u = NonlinearSolve.__maybe_unaliased(prob.u0, alias_u0)
4848
end
@@ -74,22 +74,22 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, arg
7474
linsolve_alg = String(linsolve)
7575

7676
if method == :newton
77-
res = nsoli(f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol = abstol, rtol = reltol, printerr = show_trace)
77+
sol = nsoli(f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol = abstol, rtol = reltol, printerr = show_trace)
7878
elseif method == :pseudotransient
79-
res = ptcsoli(f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol = abstol, rtol = reltol, printerr = show_trace)
79+
sol = ptcsoli(f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol = abstol, rtol = reltol, printerr = show_trace)
8080
end
8181

82-
if res.errcode == 0
82+
if sol.errcode == 0
8383
retcode = ReturnCode.Success
84-
elseif res.errcode == 10
84+
elseif sol.errcode == 10
8585
retcode = ReturnCode.MaxIters
86-
elseif res.errcode == 1
86+
elseif sol.errcode == 1
8787
retcode = ReturnCode.Failure
88-
elseif res.errcode == -1
88+
elseif sol.errcode == -1
8989
retcode = ReturnCode.Default
9090
end
91-
stats = method == :pseudotransient ? nothing : (SciMLBase.NLStats(res.stats.ifun[1], res.stats.ijac[1], 0, 0, res.stats.iarm[1]))
92-
return SciMLBase.build_solution(prob, alg, res.solution, res.history; retcode, stats)
91+
stats = method == :pseudotransient ? nothing : (SciMLBase.NLStats(sum(sol.stats.ifun), sum(sol.stats.ijac), 0, 0, sum(sol.stats.iarm)))
92+
return SciMLBase.build_solution(prob, alg, sol.solution, sol.history; retcode, stats, original = sol)
9393
end
9494

9595
if prob.f.jac === nothing
@@ -143,30 +143,29 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, arg
143143
AJ!(J, u, x) = J!(J, x, prob.p)
144144

145145
if method == :newton
146-
res = nsol(f!, u, FS, FPS, AJ!;
146+
sol = nsol(f!, u, FS, FPS, AJ!;
147147
sham=1, rtol = reltol, atol = abstol, maxit = maxiters,
148148
printerr = show_trace)
149149
elseif method == :pseudotransient
150-
res = ptcsol(f!, u, FS, FPS, AJ!;
150+
sol = ptcsol(f!, u, FS, FPS, AJ!;
151151
rtol = reltol, atol = abstol, maxit = maxiters,
152152
delta0 = delta, printerr = show_trace)
153-
154153
end
155154

156-
if res.errcode == 0
155+
if sol.errcode == 0
157156
retcode = ReturnCode.Success
158-
elseif res.errcode == 10
157+
elseif sol.errcode == 10
159158
retcode = ReturnCode.MaxIters
160-
elseif res.errcode == 1
159+
elseif sol.errcode == 1
161160
retcode = ReturnCode.Failure
162-
elseif res.errcode == -1
161+
elseif sol.errcode == -1
163162
retcode = ReturnCode.Default
164163
end
165164

166-
167165
# pseudo transient continuation has a fixed cost per iteration, iteration statistics are not interesting here.
168-
stats = method == :pseudotransient ? nothing : (SciMLBase.NLStats(res.stats.ifun[1], res.stats.ijac[1], 0, 0, res.stats.iarm[1]))
169-
return SciMLBase.build_solution(prob, alg, res.solution, res.history; retcode, stats)
166+
stats = method == :pseudotransient ? nothing : (SciMLBase.NLStats(sum(sol.stats.ifun), sum(sol.stats.ijac), 0, 0, sum(sol.stats.iarm)))
167+
println(sol.stats)
168+
return SciMLBase.build_solution(prob, alg, sol.solution, sol.history; retcode, stats, original = sol)
170169
end
171170

172171
end

0 commit comments

Comments
 (0)