Skip to content

Commit 24793ed

Browse files
committed
Consolidate handling of Stats
1 parent b06cb09 commit 24793ed

18 files changed

+139
-185
lines changed

src/NonlinearSolve.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,14 @@ using PrecompileTools: @recompile_invalidations, @compile_workload, @setup_workl
3434
istril, istriu, lu, mul!, norm, pinv, tril!, triu!
3535
using LineSearches: LineSearches
3636
using LinearSolve: LinearSolve, LUFactorization, QRFactorization, ComposePreconditioner,
37-
InvPreconditioner, needs_concrete_A
37+
InvPreconditioner, needs_concrete_A, AbstractFactorization,
38+
DefaultAlgorithmChoice, DefaultLinearSolver
3839
using MaybeInplace: @bb
3940
using Printf: @printf
4041
using Preferences: Preferences, @load_preference, @set_preferences!
4142
using RecursiveArrayTools: recursivecopy!, recursivefill!
4243
using SciMLBase: AbstractNonlinearAlgorithm, JacobianWrapper, AbstractNonlinearProblem,
43-
AbstractSciMLOperator, _unwrap_val, has_jac, isinplace
44+
AbstractSciMLOperator, _unwrap_val, has_jac, isinplace, NLStats
4445
using SparseArrays: AbstractSparseMatrix, SparseMatrixCSC
4546
using SparseDiffTools: SparseDiffTools, AbstractSparsityDetection,
4647
ApproximateJacobianSparsity, JacPrototypeSparsityDetection,

src/abstract_types.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,6 @@ abstract type AbstractNonlinearSolveLineSearchCache end
142142

143143
function reinit_cache!(
144144
cache::AbstractNonlinearSolveLineSearchCache, args...; p = cache.p, kwargs...)
145-
cache.nf[] = 0
146145
cache.p = p
147146
end
148147

@@ -235,7 +234,7 @@ function __show_cache(io::IO, cache::AbstractNonlinearSolveCache, indent = 0)
235234
println(io, (" "^(indent + 4)) * "u = ", get_u(cache), ",")
236235
println(io, (" "^(indent + 4)) * "residual = ", get_fu(cache), ",")
237236
println(io, (" "^(indent + 4)) * "inf-norm(residual) = ", norm(get_fu(cache), Inf), ",")
238-
println(io, " "^(indent + 4) * "nsteps = ", get_nsteps(cache), ",")
237+
println(io, " "^(indent + 4) * "nsteps = ", cache.stats.nsteps, ",")
239238
println(io, " "^(indent + 4) * "retcode = ", cache.retcode)
240239
print(io, " "^(indent) * ")")
241240
end

src/core/approximate_jacobian.jl

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ end
9595
inv_workspace
9696

9797
# Counters
98-
nf::Int
98+
stats::NLStats
9999
nsteps::Int
100100
nresets::Int
101101
max_resets::Int
@@ -131,7 +131,7 @@ function __reinit_internal!(cache::ApproximateJacobianSolveCache{INV, GB, iip},
131131
end
132132
cache.p = p
133133

134-
cache.nf = 1
134+
__reinit_internal!(cache.stats)
135135
cache.nsteps = 0
136136
cache.nresets = 0
137137
cache.steps_since_last_reset = 0
@@ -151,8 +151,9 @@ end
151151

152152
function SciMLBase.__init(
153153
prob::AbstractNonlinearProblem{uType, iip}, alg::ApproximateJacobianSolveAlgorithm,
154-
args...; alias_u0 = false, maxtime = nothing, maxiters = 1000, abstol = nothing,
155-
reltol = nothing, linsolve_kwargs = (;), termination_condition = nothing,
154+
args...; stats = empty_nlstats(), alias_u0 = false, maxtime = nothing,
155+
maxiters = 1000, abstol = nothing, reltol = nothing,
156+
linsolve_kwargs = (;), termination_condition = nothing,
156157
internalnorm::F = DEFAULT_NORM, kwargs...) where {uType, iip, F}
157158
timer = get_timer_output()
158159
@static_timeit timer "cache construction" begin
@@ -165,7 +166,7 @@ function SciMLBase.__init(
165166

166167
linsolve = get_linear_solver(alg.descent)
167168
initialization_cache = __internal_init(
168-
prob, alg.initialization, alg, f, fu, u, p; linsolve, maxiters, internalnorm)
169+
prob, alg.initialization, alg, f, fu, u, p; stats, linsolve, maxiters, internalnorm)
169170

170171
abstol, reltol, termination_cache = init_termination_cache(
171172
prob, abstol, reltol, fu, u, termination_condition)
@@ -174,7 +175,7 @@ function SciMLBase.__init(
174175
J = initialization_cache(nothing)
175176
inv_workspace, J = INV ? __safe_inv_workspace(J) : (nothing, J)
176177
descent_cache = __internal_init(
177-
prob, alg.descent, J, fu, u; abstol, reltol, internalnorm,
178+
prob, alg.descent, J, fu, u; stats, abstol, reltol, internalnorm,
178179
linsolve_kwargs, pre_inverted = Val(INV), timer)
179180
du = get_du(descent_cache)
180181

@@ -192,28 +193,28 @@ function SciMLBase.__init(
192193
supports_trust_region(alg.descent) || error("Trust Region not supported by \
193194
$(alg.descent).")
194195
trustregion_cache = __internal_init(
195-
prob, alg.trustregion, f, fu, u, p; internalnorm, kwargs...)
196+
prob, alg.trustregion, f, fu, u, p; stats, internalnorm, kwargs...)
196197
GB = :TrustRegion
197198
end
198199

199200
if alg.linesearch !== missing
200201
supports_line_search(alg.descent) || error("Line Search not supported by \
201202
$(alg.descent).")
202203
linesearch_cache = __internal_init(
203-
prob, alg.linesearch, f, fu, u, p; internalnorm, kwargs...)
204+
prob, alg.linesearch, f, fu, u, p; stats, internalnorm, kwargs...)
204205
GB = :LineSearch
205206
end
206207

207208
update_rule_cache = __internal_init(
208-
prob, alg.update_rule, J, fu, u, du; internalnorm)
209+
prob, alg.update_rule, J, fu, u, du; stats, internalnorm)
209210

210211
trace = init_nonlinearsolve_trace(prob, alg, u, fu, ApplyArray(__zero, J), du;
211212
uses_jacobian_inverse = Val(INV), kwargs...)
212213

213214
return ApproximateJacobianSolveCache{INV, GB, iip, maxtime !== nothing}(
214215
fu, u, u_cache, p, du, J, alg, prob, initialization_cache,
215216
descent_cache, linesearch_cache, trustregion_cache, update_rule_cache,
216-
reinit_rule_cache, inv_workspace, 0, 0, 0, alg.max_resets,
217+
reinit_rule_cache, inv_workspace, stats, 0, 0, alg.max_resets,
217218
maxiters, maxtime, alg.max_shrink_times, 0, timer, 0.0,
218219
termination_cache, trace, ReturnCode.Default, false, false, kwargs)
219220
end
@@ -223,7 +224,7 @@ function __step!(cache::ApproximateJacobianSolveCache{INV, GB, iip};
223224
recompute_jacobian::Union{Nothing, Bool} = nothing) where {INV, GB, iip}
224225
new_jacobian = true
225226
@static_timeit cache.timer "jacobian init/reinit" begin
226-
if get_nsteps(cache) == 0 # First Step is special ignore kwargs
227+
if cache.nsteps == 0 # First Step is special ignore kwargs
227228
J_init = __internal_solve!(
228229
cache.initialization_cache, cache.fu, cache.u, Val(false))
229230
if INV

src/core/generalized_first_order.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ concrete_jac(::GeneralizedFirstOrderAlgorithm{CJ}) where {CJ} = CJ
9797
trustregion_cache
9898

9999
# Counters
100-
nf::Int
100+
stats::NLStats
101101
nsteps::Int
102102
maxiters::Int
103103
maxtime
@@ -135,7 +135,7 @@ function __reinit_internal!(
135135
end
136136
cache.p = p
137137

138-
cache.nf = 1
138+
__reinit_internal!(cache.stats)
139139
cache.nsteps = 0
140140
cache.maxiters = maxiters
141141
cache.maxtime = maxtime
@@ -153,7 +153,7 @@ end
153153

154154
function SciMLBase.__init(
155155
prob::AbstractNonlinearProblem{uType, iip}, alg::GeneralizedFirstOrderAlgorithm,
156-
args...; alias_u0 = false, maxiters = 1000, abstol = nothing,
156+
args...; stats=empty_nlstats(), alias_u0 = false, maxiters = 1000, abstol = nothing,
157157
reltol = nothing, maxtime = nothing, termination_condition = nothing,
158158
internalnorm = DEFAULT_NORM, linsolve_kwargs = (;), kwargs...) where {uType, iip}
159159
timer = get_timer_output()
@@ -170,10 +170,10 @@ function SciMLBase.__init(
170170
linsolve_kwargs = merge((; abstol, reltol), linsolve_kwargs)
171171

172172
jac_cache = JacobianCache(
173-
prob, alg, f, fu, u, p; autodiff = alg.jacobian_ad, linsolve,
173+
prob, alg, f, fu, u, p; stats, autodiff = alg.jacobian_ad, linsolve,
174174
jvp_autodiff = alg.forward_ad, vjp_autodiff = alg.reverse_ad)
175175
J = jac_cache(nothing)
176-
descent_cache = __internal_init(prob, alg.descent, J, fu, u; abstol, reltol,
176+
descent_cache = __internal_init(prob, alg.descent, J, fu, u; stats, abstol, reltol,
177177
internalnorm, linsolve_kwargs, timer)
178178
du = get_du(descent_cache)
179179

@@ -189,15 +189,15 @@ function SciMLBase.__init(
189189
supports_trust_region(alg.descent) || error("Trust Region not supported by \
190190
$(alg.descent).")
191191
trustregion_cache = __internal_init(
192-
prob, alg.trustregion, f, fu, u, p; internalnorm, kwargs...)
192+
prob, alg.trustregion, f, fu, u, p; stats, internalnorm, kwargs...)
193193
GB = :TrustRegion
194194
end
195195

196196
if alg.linesearch !== missing
197197
supports_line_search(alg.descent) || error("Line Search not supported by \
198198
$(alg.descent).")
199199
linesearch_cache = __internal_init(
200-
prob, alg.linesearch, f, fu, u, p; internalnorm, kwargs...)
200+
prob, alg.linesearch, f, fu, u, p; stats, internalnorm, kwargs...)
201201
GB = :LineSearch
202202
end
203203

@@ -206,7 +206,7 @@ function SciMLBase.__init(
206206

207207
return GeneralizedFirstOrderAlgorithmCache{iip, GB, maxtime !== nothing}(
208208
fu, u, u_cache, p, du, J, alg, prob, jac_cache, descent_cache, linesearch_cache,
209-
trustregion_cache, 0, 0, maxiters, maxtime, alg.max_shrink_times, timer,
209+
trustregion_cache, stats, 0, maxiters, maxtime, alg.max_shrink_times, timer,
210210
0.0, true, termination_cache, trace, ReturnCode.Default, false, kwargs)
211211
end
212212
end

src/core/generic.jl

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
function SciMLBase.__solve(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem},
2-
alg::AbstractNonlinearSolveAlgorithm, args...; kwargs...)
3-
cache = init(prob, alg, args...; kwargs...)
2+
alg::AbstractNonlinearSolveAlgorithm, args...; stats=empty_nlstats(), kwargs...)
3+
cache = SciMLBase.__init(prob, alg, args...; stats, kwargs...)
44
return solve!(cache)
55
end
66

77
function not_terminated(cache::AbstractNonlinearSolveCache)
8-
return !cache.force_stop && get_nsteps(cache) < cache.maxiters
8+
return !cache.force_stop && cache.nsteps < cache.maxiters
99
end
1010

1111
function SciMLBase.solve!(cache::AbstractNonlinearSolveCache)
@@ -16,21 +16,16 @@ function SciMLBase.solve!(cache::AbstractNonlinearSolveCache)
1616
# The solver might have set a different `retcode`
1717
if cache.retcode == ReturnCode.Default
1818
cache.retcode = ifelse(
19-
get_nsteps(cache) cache.maxiters, ReturnCode.MaxIters, ReturnCode.Success)
19+
cache.nsteps cache.maxiters, ReturnCode.MaxIters, ReturnCode.Success)
2020
end
2121

2222
update_from_termination_cache!(cache.termination_cache, cache)
2323

24-
update_trace!(cache.trace, get_nsteps(cache), get_u(cache),
24+
update_trace!(cache.trace, cache.nsteps, get_u(cache),
2525
get_fu(cache), nothing, nothing, nothing; last = True)
2626

2727
return SciMLBase.build_solution(cache.prob, cache.alg, get_u(cache), get_fu(cache);
28-
cache.retcode, stats = __compile_stats(cache), cache.trace)
29-
end
30-
31-
function __compile_stats(cache::AbstractNonlinearSolveCache)
32-
return SciMLBase.NLStats(get_nf(cache), get_njacs(cache), get_nfactors(cache),
33-
get_nsolve(cache), get_nsteps(cache))
28+
cache.retcode, cache.stats, cache.trace)
3429
end
3530

3631
"""
@@ -55,7 +50,8 @@ function SciMLBase.step!(cache::AbstractNonlinearSolveCache{iip, timeit},
5550
__step!(cache, args...; kwargs...)
5651
end
5752

58-
hasfield(typeof(cache), :nsteps) && (cache.nsteps += 1)
53+
cache.stats.nsteps += 1
54+
cache.nsteps += 1
5955

6056
if timeit
6157
cache.total_time += time() - time_start

src/core/spectral_methods.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ concrete_jac(::GeneralizedDFSane) = nothing
6060
linesearch_cache
6161

6262
# Counters
63-
nf::Int
63+
stats::NLStats
6464
nsteps::Int
6565
maxiters::Int
6666
maxtime
@@ -106,7 +106,7 @@ function __reinit_internal!(
106106

107107
reset!(cache.trace)
108108
reinit!(cache.termination_cache, get_fu(cache), get_u(cache); kwargs...)
109-
cache.nf = 1
109+
__reinit_internal!(cache.stats)
110110
cache.nsteps = 0
111111
cache.maxiters = maxiters
112112
cache.maxtime = maxtime
@@ -116,9 +116,9 @@ end
116116

117117
@internal_caches GeneralizedDFSaneCache :linesearch_cache
118118

119-
function SciMLBase.__init(prob::AbstractNonlinearProblem, alg::GeneralizedDFSane,
120-
args...; alias_u0 = false, maxiters = 1000, abstol = nothing,
121-
reltol = nothing, termination_condition = nothing,
119+
function SciMLBase.__init(prob::AbstractNonlinearProblem, alg::GeneralizedDFSane, args...;
120+
stats = empty_nlstats(), alias_u0 = false, maxiters = 1000,
121+
abstol = nothing, reltol = nothing, termination_condition = nothing,
122122
internalnorm::F = DEFAULT_NORM, maxtime = nothing, kwargs...) where {F}
123123
timer = get_timer_output()
124124
@static_timeit timer "cache construction" begin
@@ -130,8 +130,8 @@ function SciMLBase.__init(prob::AbstractNonlinearProblem, alg::GeneralizedDFSane
130130
fu = evaluate_f(prob, u)
131131
@bb fu_cache = copy(fu)
132132

133-
linesearch_cache = __internal_init(
134-
prob, alg.linesearch, prob.f, fu, u, prob.p; maxiters, internalnorm, kwargs...)
133+
linesearch_cache = __internal_init(prob, alg.linesearch, prob.f, fu, u, prob.p;
134+
stats, maxiters, internalnorm, kwargs...)
135135

136136
abstol, reltol, tc_cache = init_termination_cache(
137137
prob, abstol, reltol, fu, u_cache, termination_condition)
@@ -150,7 +150,7 @@ function SciMLBase.__init(prob::AbstractNonlinearProblem, alg::GeneralizedDFSane
150150

151151
return GeneralizedDFSaneCache{isinplace(prob), maxtime !== nothing}(
152152
fu, fu_cache, u, u_cache, prob.p, du, alg, prob, σ_n, T(alg.σ_min),
153-
T(alg.σ_max), linesearch_cache, 0, 0, maxiters, maxtime,
153+
T(alg.σ_max), linesearch_cache, stats, 0, maxiters, maxtime,
154154
timer, 0.0, tc_cache, trace, ReturnCode.Default, false, kwargs)
155155
end
156156
end

src/default.jl

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ end
5959
best::Int
6060
current::Int
6161
nsteps::Int
62+
stats::NLStats
6263
total_time::Float64
6364
maxtime
6465
retcode::ReturnCode.T
@@ -90,6 +91,7 @@ end
9091
function reinit_cache!(cache::NonlinearSolvePolyAlgorithmCache, args...; kwargs...)
9192
foreach(c -> reinit_cache!(c, args...; kwargs...), cache.caches)
9293
cache.current = cache.alg.start_index
94+
__reinit_internal!(cache.stats)
9395
cache.nsteps = 0
9496
cache.total_time = 0.0
9597
end
@@ -98,8 +100,8 @@ for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProb
98100
algType = NonlinearSolvePolyAlgorithm{pType}
99101
@eval begin
100102
function SciMLBase.__init(
101-
prob::$probType, alg::$algType{N}, args...; maxtime = nothing,
102-
maxiters = 1000, internalnorm = DEFAULT_NORM,
103+
prob::$probType, alg::$algType{N}, args...; stats = empty_nlstats(),
104+
maxtime = nothing, maxiters = 1000, internalnorm = DEFAULT_NORM,
103105
alias_u0 = false, verbose = true, kwargs...) where {N}
104106
if (alias_u0 && !ismutable(prob.u0))
105107
verbose && @warn "`alias_u0` has been set to `true`, but `u0` is \
@@ -115,13 +117,14 @@ for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProb
115117
alias_u0 && (prob = remake(prob; u0 = u0_aliased))
116118
return NonlinearSolvePolyAlgorithmCache{isinplace(prob), N, maxtime !== nothing}(
117119
map(
118-
solver -> SciMLBase.__init(prob, solver, args...; maxtime,
120+
solver -> SciMLBase.__init(prob, solver, args...; stats, maxtime,
119121
internalnorm, alias_u0, verbose, kwargs...),
120122
alg.algs),
121123
alg,
122124
-1,
123125
alg.start_index,
124126
0,
127+
stats,
125128
0.0,
126129
maxtime,
127130
ReturnCode.Default,
@@ -181,7 +184,6 @@ end
181184
push!(calls, quote
182185
fus = tuple($(Tuple(resids)...))
183186
minfu, idx = __findmin(cache.internalnorm, fus)
184-
stats = __compile_stats(cache.caches[idx])
185187
end)
186188
for i in 1:N
187189
push!(calls, quote
@@ -203,7 +205,7 @@ end
203205
end
204206
return __build_solution_less_specialize(
205207
cache.caches[idx].prob, cache.alg, u, fus[idx];
206-
retcode, stats, cache.caches[idx].trace)
208+
retcode, stats = cache.stats, cache.caches[idx].trace)
207209
end)
208210

209211
return Expr(:block, calls...)
@@ -250,7 +252,8 @@ end
250252
for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProblem, :NLLS))
251253
algType = NonlinearSolvePolyAlgorithm{pType}
252254
@eval begin
253-
@generated function SciMLBase.__solve(prob::$probType, alg::$algType{N}, args...;
255+
@generated function SciMLBase.__solve(
256+
prob::$probType, alg::$algType{N}, args...; stats = empty_nlstats(),
254257
alias_u0 = false, verbose = true, kwargs...) where {N}
255258
sol_syms = [gensym("sol") for _ in 1:N]
256259
prob_syms = [gensym("prob") for _ in 1:N]
@@ -280,8 +283,9 @@ for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProb
280283
else
281284
$(prob_syms[i]) = prob
282285
end
283-
$(cur_sol) = SciMLBase.__solve($(prob_syms[i]), alg.algs[$(i)],
284-
args...; alias_u0, verbose, kwargs...)
286+
$(cur_sol) = SciMLBase.__solve(
287+
$(prob_syms[i]), alg.algs[$(i)], args...;
288+
stats, alias_u0, verbose, kwargs...)
285289
if SciMLBase.successful_retcode($(cur_sol))
286290
if alias_u0
287291
copyto!(u0, $(cur_sol).u)

0 commit comments

Comments
 (0)