Skip to content

Commit 1530326

Browse files
committed
Propagate stats from MINPACK
1 parent 4602f34 commit 1530326

File tree

3 files changed

+23
-16
lines changed

3 files changed

+23
-16
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NonlinearSolve"
22
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
33
authors = ["SciML"]
4-
version = "3.1.0"
4+
version = "3.1.1"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

ext/NonlinearSolveMINPACKExt.jl

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
module NonlinearSolveMINPACKExt
22

3-
using NonlinearSolve, SciMLBase
3+
using NonlinearSolve, DiffEqBase, SciMLBase
44
using MINPACK
55

66
function SciMLBase.__solve(prob::Union{NonlinearProblem{uType, iip},
77
NonlinearLeastSquaresProblem{uType, iip}}, alg::CMINPACK, args...;
88
abstol = 1e-6, maxiters = 100000, alias_u0::Bool = false,
99
termination_condition = nothing, kwargs...) where {uType, iip}
10-
@assert termination_condition===nothing "CMINPACK does not support termination conditions!"
10+
@assert (termination_condition ===
11+
nothing)||(termination_condition isa AbsNormTerminationMode) "CMINPACK does not support termination conditions!"
1112

1213
if prob.u0 isa Number
1314
u0 = [prob.u0]
@@ -57,22 +58,26 @@ function SciMLBase.__solve(prob::Union{NonlinearProblem{uType, iip},
5758
return Cint(0)
5859
end
5960
end
60-
original = MINPACK.fsolve(f!, g!, u0, m; tol = abstol, show_trace, tracing, method,
61-
iterations = maxiters, kwargs...)
61+
original = MINPACK.fsolve(f!, g!, vec(u0), m; tol = abstol, show_trace, tracing,
62+
method, iterations = maxiters, kwargs...)
6263
else
63-
original = MINPACK.fsolve(f!, u0, m; tol = abstol, show_trace, tracing, method,
64-
iterations = maxiters, kwargs...)
64+
original = MINPACK.fsolve(f!, vec(u0), m; tol = abstol, show_trace, tracing,
65+
method, iterations = maxiters, kwargs...)
6566
end
6667

6768
u = reshape(original.x, size(u))
6869
resid = original.f
6970
# retcode = original.converged ? ReturnCode.Success : ReturnCode.Failure
7071
# MINPACK lies about convergence? or maybe uses some other criteria?
7172
# We just check for absolute tolerance on the residual
72-
objective = NonlinearSolve.DEFAULT_NORM(resid)
73+
objective = maximum(abs, resid)
7374
retcode = ifelse(objective abstol, ReturnCode.Success, ReturnCode.Failure)
7475

75-
return SciMLBase.build_solution(prob, alg, u, resid; retcode, original)
76+
# These are only meaningful if `tracing = true`
77+
stats = SciMLBase.NLStats(original.trace.f_calls, original.trace.g_calls,
78+
original.trace.g_calls, original.trace.g_calls, -1)
79+
80+
return SciMLBase.build_solution(prob, alg, u, resid; stats, retcode, original)
7681
end
7782

7883
end

ext/NonlinearSolveNLsolveExt.jl

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@ import UnPack: @unpack
55

66
function SciMLBase.__solve(prob::NonlinearProblem, alg::NLsolveJL, args...; abstol = 1e-6,
77
maxiters = 1000, alias_u0::Bool = false, termination_condition = nothing, kwargs...)
8-
@assert termination_condition===nothing "NLsolveJL does not support termination conditions!"
8+
@assert (termination_condition ===
9+
nothing)||(termination_condition isa AbsNormTerminationMode) "NLsolveJL does not support termination conditions!"
910

1011
if typeof(prob.u0) <: Number
1112
u0 = [prob.u0]
@@ -59,19 +60,20 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::NLsolveJL, args...; abst
5960
end
6061
if prob.f.jac_prototype !== nothing
6162
J = zero(prob.f.jac_prototype)
62-
df = OnceDifferentiable(f!, g!, u0, resid, J)
63+
df = OnceDifferentiable(f!, g!, vec(u0), vec(resid), J)
6364
else
64-
df = OnceDifferentiable(f!, g!, u0, resid)
65+
df = OnceDifferentiable(f!, g!, vec(u0), vec(resid))
6566
end
6667
else
67-
df = OnceDifferentiable(f!, u0, resid; autodiff)
68+
df = OnceDifferentiable(f!, vec(u0), vec(resid); autodiff)
6869
end
6970

70-
original = nlsolve(df, u0; ftol = abstol, iterations = maxiters, method, store_trace,
71-
extended_trace, linesearch, linsolve, factor, autoscale, m, beta, show_trace)
71+
original = nlsolve(df, vec(u0); ftol = abstol, iterations = maxiters, method,
72+
store_trace, extended_trace, linesearch, linsolve, factor, autoscale, m, beta,
73+
show_trace)
7274

7375
u = reshape(original.zero, size(u0))
74-
f!(resid, u)
76+
f!(vec(resid), vec(u))
7577
retcode = original.x_converged || original.f_converged ? ReturnCode.Success :
7678
ReturnCode.Failure
7779
stats = SciMLBase.NLStats(original.f_calls, original.g_calls, original.g_calls,

0 commit comments

Comments
 (0)