Skip to content

Commit 630aad4

Browse files
authored
Merge pull request #437 from SciML/ap/consolidate_stats
Consolidate Stats Handling
2 parents 6f65ac4 + 50aa5c4 commit 630aad4

21 files changed

+161
-247
lines changed

Project.toml

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
3232

3333
[weakdeps]
3434
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
35-
Enlsip = "d5306a6b-d590-428d-a53a-eb3bb2d36f2d"
3635
FastLevenbergMarquardt = "7a0df574-e128-4d35-8cbd-3d84502bf7ce"
3736
FixedPointAcceleration = "817d07cb-a79a-5c30-9a31-890123675176"
3837
LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891"
@@ -46,7 +45,6 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
4645

4746
[extensions]
4847
NonlinearSolveBandedMatricesExt = "BandedMatrices"
49-
NonlinearSolveEnlsipExt = "Enlsip"
5048
NonlinearSolveFastLevenbergMarquardtExt = "FastLevenbergMarquardt"
5149
NonlinearSolveFixedPointAccelerationExt = "FixedPointAcceleration"
5250
NonlinearSolveLeastSquaresOptimExt = "LeastSquaresOptim"
@@ -67,7 +65,6 @@ BenchmarkTools = "1.4"
6765
CUDA = "5.2"
6866
ConcreteStructs = "0.2.3"
6967
DiffEqBase = "6.149.0"
70-
Enlsip = "0.9"
7168
Enzyme = "0.12"
7269
ExplicitImports = "1.4.4"
7370
FastBroadcast = "0.2.8, 0.3"
@@ -119,7 +116,6 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
119116
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
120117
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
121118
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
122-
Enlsip = "d5306a6b-d590-428d-a53a-eb3bb2d36f2d"
123119
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
124120
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
125121
FastLevenbergMarquardt = "7a0df574-e128-4d35-8cbd-3d84502bf7ce"
@@ -145,4 +141,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
145141
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
146142

147143
[targets]
148-
test = ["Aqua", "BandedMatrices", "BenchmarkTools", "CUDA", "Enlsip", "Enzyme", "ExplicitImports", "FastLevenbergMarquardt", "FixedPointAcceleration", "LeastSquaresOptim", "MINPACK", "ModelingToolkit", "NLSolvers", "NLsolve", "NaNMath", "NonlinearProblemLibrary", "OrdinaryDiffEq", "Pkg", "Random", "ReTestItems", "SIAMFANLEquations", "SpeedMapping", "StableRNGs", "StaticArrays", "Sundials", "Symbolics", "Test", "Zygote"]
144+
test = ["Aqua", "BandedMatrices", "BenchmarkTools", "CUDA", "Enzyme", "ExplicitImports", "FastLevenbergMarquardt", "FixedPointAcceleration", "LeastSquaresOptim", "MINPACK", "ModelingToolkit", "NLSolvers", "NLsolve", "NaNMath", "NonlinearProblemLibrary", "OrdinaryDiffEq", "Pkg", "Random", "ReTestItems", "SIAMFANLEquations", "SpeedMapping", "StableRNGs", "StaticArrays", "Sundials", "Symbolics", "Test", "Zygote"]

docs/src/basics/nonlinear_solution.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ SciMLBase.NonlinearSolution
99

1010
```@docs
1111
SciMLBase.NLStats
12-
NonlinearSolve.ImmutableNLStats
1312
```
1413

1514
## Return Code

src/NonlinearSolve.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,14 @@ using PrecompileTools: @recompile_invalidations, @compile_workload, @setup_workl
3434
istril, istriu, lu, mul!, norm, pinv, tril!, triu!
3535
using LineSearches: LineSearches
3636
using LinearSolve: LinearSolve, LUFactorization, QRFactorization, ComposePreconditioner,
37-
InvPreconditioner, needs_concrete_A
37+
InvPreconditioner, needs_concrete_A, AbstractFactorization,
38+
DefaultAlgorithmChoice, DefaultLinearSolver
3839
using MaybeInplace: @bb
3940
using Printf: @printf
4041
using Preferences: Preferences, @load_preference, @set_preferences!
4142
using RecursiveArrayTools: recursivecopy!, recursivefill!
4243
using SciMLBase: AbstractNonlinearAlgorithm, JacobianWrapper, AbstractNonlinearProblem,
43-
AbstractSciMLOperator, _unwrap_val, has_jac, isinplace
44+
AbstractSciMLOperator, _unwrap_val, has_jac, isinplace, NLStats
4445
using SparseArrays: AbstractSparseMatrix, SparseMatrixCSC
4546
using SparseDiffTools: SparseDiffTools, AbstractSparsityDetection,
4647
ApproximateJacobianSparsity, JacPrototypeSparsityDetection,

src/abstract_types.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,6 @@ abstract type AbstractNonlinearSolveLineSearchCache end
142142

143143
function reinit_cache!(
144144
cache::AbstractNonlinearSolveLineSearchCache, args...; p = cache.p, kwargs...)
145-
cache.nf[] = 0
146145
cache.p = p
147146
end
148147

@@ -235,7 +234,7 @@ function __show_cache(io::IO, cache::AbstractNonlinearSolveCache, indent = 0)
235234
println(io, (" "^(indent + 4)) * "u = ", get_u(cache), ",")
236235
println(io, (" "^(indent + 4)) * "residual = ", get_fu(cache), ",")
237236
println(io, (" "^(indent + 4)) * "inf-norm(residual) = ", norm(get_fu(cache), Inf), ",")
238-
println(io, " "^(indent + 4) * "nsteps = ", get_nsteps(cache), ",")
237+
println(io, " "^(indent + 4) * "nsteps = ", cache.stats.nsteps, ",")
239238
println(io, " "^(indent + 4) * "retcode = ", cache.retcode)
240239
print(io, " "^(indent) * ")")
241240
end

src/core/approximate_jacobian.jl

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ end
9595
inv_workspace
9696

9797
# Counters
98-
nf::Int
98+
stats::NLStats
9999
nsteps::Int
100100
nresets::Int
101101
max_resets::Int
@@ -131,7 +131,7 @@ function __reinit_internal!(cache::ApproximateJacobianSolveCache{INV, GB, iip},
131131
end
132132
cache.p = p
133133

134-
cache.nf = 1
134+
__reinit_internal!(cache.stats)
135135
cache.nsteps = 0
136136
cache.nresets = 0
137137
cache.steps_since_last_reset = 0
@@ -151,8 +151,9 @@ end
151151

152152
function SciMLBase.__init(
153153
prob::AbstractNonlinearProblem{uType, iip}, alg::ApproximateJacobianSolveAlgorithm,
154-
args...; alias_u0 = false, maxtime = nothing, maxiters = 1000, abstol = nothing,
155-
reltol = nothing, linsolve_kwargs = (;), termination_condition = nothing,
154+
args...; stats = empty_nlstats(), alias_u0 = false, maxtime = nothing,
155+
maxiters = 1000, abstol = nothing, reltol = nothing,
156+
linsolve_kwargs = (;), termination_condition = nothing,
156157
internalnorm::F = DEFAULT_NORM, kwargs...) where {uType, iip, F}
157158
timer = get_timer_output()
158159
@static_timeit timer "cache construction" begin
@@ -164,18 +165,17 @@ function SciMLBase.__init(
164165
INV = store_inverse_jacobian(alg.update_rule)
165166

166167
linsolve = get_linear_solver(alg.descent)
167-
initialization_cache = __internal_init(
168-
prob, alg.initialization, alg, f, fu, u, p; linsolve, maxiters, internalnorm)
168+
initialization_cache = __internal_init(prob, alg.initialization, alg, f, fu, u, p;
169+
stats, linsolve, maxiters, internalnorm)
169170

170171
abstol, reltol, termination_cache = init_termination_cache(
171172
prob, abstol, reltol, fu, u, termination_condition)
172173
linsolve_kwargs = merge((; abstol, reltol), linsolve_kwargs)
173174

174175
J = initialization_cache(nothing)
175176
inv_workspace, J = INV ? __safe_inv_workspace(J) : (nothing, J)
176-
descent_cache = __internal_init(
177-
prob, alg.descent, J, fu, u; abstol, reltol, internalnorm,
178-
linsolve_kwargs, pre_inverted = Val(INV), timer)
177+
descent_cache = __internal_init(prob, alg.descent, J, fu, u; stats, abstol, reltol,
178+
internalnorm, linsolve_kwargs, pre_inverted = Val(INV), timer)
179179
du = get_du(descent_cache)
180180

181181
reinit_rule_cache = __internal_init(alg.reinit_rule, J, fu, u, du)
@@ -192,28 +192,28 @@ function SciMLBase.__init(
192192
supports_trust_region(alg.descent) || error("Trust Region not supported by \
193193
$(alg.descent).")
194194
trustregion_cache = __internal_init(
195-
prob, alg.trustregion, f, fu, u, p; internalnorm, kwargs...)
195+
prob, alg.trustregion, f, fu, u, p; stats, internalnorm, kwargs...)
196196
GB = :TrustRegion
197197
end
198198

199199
if alg.linesearch !== missing
200200
supports_line_search(alg.descent) || error("Line Search not supported by \
201201
$(alg.descent).")
202202
linesearch_cache = __internal_init(
203-
prob, alg.linesearch, f, fu, u, p; internalnorm, kwargs...)
203+
prob, alg.linesearch, f, fu, u, p; stats, internalnorm, kwargs...)
204204
GB = :LineSearch
205205
end
206206

207207
update_rule_cache = __internal_init(
208-
prob, alg.update_rule, J, fu, u, du; internalnorm)
208+
prob, alg.update_rule, J, fu, u, du; stats, internalnorm)
209209

210210
trace = init_nonlinearsolve_trace(prob, alg, u, fu, ApplyArray(__zero, J), du;
211211
uses_jacobian_inverse = Val(INV), kwargs...)
212212

213213
return ApproximateJacobianSolveCache{INV, GB, iip, maxtime !== nothing}(
214214
fu, u, u_cache, p, du, J, alg, prob, initialization_cache,
215215
descent_cache, linesearch_cache, trustregion_cache, update_rule_cache,
216-
reinit_rule_cache, inv_workspace, 0, 0, 0, alg.max_resets,
216+
reinit_rule_cache, inv_workspace, stats, 0, 0, alg.max_resets,
217217
maxiters, maxtime, alg.max_shrink_times, 0, timer, 0.0,
218218
termination_cache, trace, ReturnCode.Default, false, false, kwargs)
219219
end
@@ -223,7 +223,7 @@ function __step!(cache::ApproximateJacobianSolveCache{INV, GB, iip};
223223
recompute_jacobian::Union{Nothing, Bool} = nothing) where {INV, GB, iip}
224224
new_jacobian = true
225225
@static_timeit cache.timer "jacobian init/reinit" begin
226-
if get_nsteps(cache) == 0 # First Step is special ignore kwargs
226+
if cache.nsteps == 0 # First Step is special ignore kwargs
227227
J_init = __internal_solve!(
228228
cache.initialization_cache, cache.fu, cache.u, Val(false))
229229
if INV

src/core/generalized_first_order.jl

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ concrete_jac(::GeneralizedFirstOrderAlgorithm{CJ}) where {CJ} = CJ
9797
trustregion_cache
9898

9999
# Counters
100-
nf::Int
100+
stats::NLStats
101101
nsteps::Int
102102
maxiters::Int
103103
maxtime
@@ -135,7 +135,7 @@ function __reinit_internal!(
135135
end
136136
cache.p = p
137137

138-
cache.nf = 1
138+
__reinit_internal!(cache.stats)
139139
cache.nsteps = 0
140140
cache.maxiters = maxiters
141141
cache.maxtime = maxtime
@@ -153,9 +153,10 @@ end
153153

154154
function SciMLBase.__init(
155155
prob::AbstractNonlinearProblem{uType, iip}, alg::GeneralizedFirstOrderAlgorithm,
156-
args...; alias_u0 = false, maxiters = 1000, abstol = nothing,
157-
reltol = nothing, maxtime = nothing, termination_condition = nothing,
158-
internalnorm = DEFAULT_NORM, linsolve_kwargs = (;), kwargs...) where {uType, iip}
156+
args...; stats = empty_nlstats(), alias_u0 = false, maxiters = 1000,
157+
abstol = nothing, reltol = nothing, maxtime = nothing,
158+
termination_condition = nothing, internalnorm = DEFAULT_NORM,
159+
linsolve_kwargs = (;), kwargs...) where {uType, iip}
159160
timer = get_timer_output()
160161
@static_timeit timer "cache construction" begin
161162
(; f, u0, p) = prob
@@ -170,11 +171,11 @@ function SciMLBase.__init(
170171
linsolve_kwargs = merge((; abstol, reltol), linsolve_kwargs)
171172

172173
jac_cache = JacobianCache(
173-
prob, alg, f, fu, u, p; autodiff = alg.jacobian_ad, linsolve,
174+
prob, alg, f, fu, u, p; stats, autodiff = alg.jacobian_ad, linsolve,
174175
jvp_autodiff = alg.forward_ad, vjp_autodiff = alg.reverse_ad)
175176
J = jac_cache(nothing)
176-
descent_cache = __internal_init(prob, alg.descent, J, fu, u; abstol, reltol,
177-
internalnorm, linsolve_kwargs, timer)
177+
descent_cache = __internal_init(prob, alg.descent, J, fu, u; stats, abstol,
178+
reltol, internalnorm, linsolve_kwargs, timer)
178179
du = get_du(descent_cache)
179180

180181
if alg.trustregion !== missing && alg.linesearch !== missing
@@ -189,15 +190,15 @@ function SciMLBase.__init(
189190
supports_trust_region(alg.descent) || error("Trust Region not supported by \
190191
$(alg.descent).")
191192
trustregion_cache = __internal_init(
192-
prob, alg.trustregion, f, fu, u, p; internalnorm, kwargs...)
193+
prob, alg.trustregion, f, fu, u, p; stats, internalnorm, kwargs...)
193194
GB = :TrustRegion
194195
end
195196

196197
if alg.linesearch !== missing
197198
supports_line_search(alg.descent) || error("Line Search not supported by \
198199
$(alg.descent).")
199200
linesearch_cache = __internal_init(
200-
prob, alg.linesearch, f, fu, u, p; internalnorm, kwargs...)
201+
prob, alg.linesearch, f, fu, u, p; stats, internalnorm, kwargs...)
201202
GB = :LineSearch
202203
end
203204

@@ -206,7 +207,7 @@ function SciMLBase.__init(
206207

207208
return GeneralizedFirstOrderAlgorithmCache{iip, GB, maxtime !== nothing}(
208209
fu, u, u_cache, p, du, J, alg, prob, jac_cache, descent_cache, linesearch_cache,
209-
trustregion_cache, 0, 0, maxiters, maxtime, alg.max_shrink_times, timer,
210+
trustregion_cache, stats, 0, maxiters, maxtime, alg.max_shrink_times, timer,
210211
0.0, true, termination_cache, trace, ReturnCode.Default, false, kwargs)
211212
end
212213
end

src/core/generic.jl

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
function SciMLBase.__solve(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem},
2-
alg::AbstractNonlinearSolveAlgorithm, args...; kwargs...)
3-
cache = init(prob, alg, args...; kwargs...)
2+
alg::AbstractNonlinearSolveAlgorithm, args...; stats = empty_nlstats(), kwargs...)
3+
cache = SciMLBase.__init(prob, alg, args...; stats, kwargs...)
44
return solve!(cache)
55
end
66

77
function not_terminated(cache::AbstractNonlinearSolveCache)
8-
return !cache.force_stop && get_nsteps(cache) < cache.maxiters
8+
return !cache.force_stop && cache.nsteps < cache.maxiters
99
end
1010

1111
function SciMLBase.solve!(cache::AbstractNonlinearSolveCache)
@@ -16,21 +16,16 @@ function SciMLBase.solve!(cache::AbstractNonlinearSolveCache)
1616
# The solver might have set a different `retcode`
1717
if cache.retcode == ReturnCode.Default
1818
cache.retcode = ifelse(
19-
get_nsteps(cache) cache.maxiters, ReturnCode.MaxIters, ReturnCode.Success)
19+
cache.nsteps cache.maxiters, ReturnCode.MaxIters, ReturnCode.Success)
2020
end
2121

2222
update_from_termination_cache!(cache.termination_cache, cache)
2323

24-
update_trace!(cache.trace, get_nsteps(cache), get_u(cache),
24+
update_trace!(cache.trace, cache.nsteps, get_u(cache),
2525
get_fu(cache), nothing, nothing, nothing; last = True)
2626

2727
return SciMLBase.build_solution(cache.prob, cache.alg, get_u(cache), get_fu(cache);
28-
cache.retcode, stats = __compile_stats(cache), cache.trace)
29-
end
30-
31-
function __compile_stats(cache::AbstractNonlinearSolveCache)
32-
return ImmutableNLStats(get_nf(cache), get_njacs(cache), get_nfactors(cache),
33-
get_nsolve(cache), get_nsteps(cache))
28+
cache.retcode, cache.stats, cache.trace)
3429
end
3530

3631
"""
@@ -55,7 +50,8 @@ function SciMLBase.step!(cache::AbstractNonlinearSolveCache{iip, timeit},
5550
__step!(cache, args...; kwargs...)
5651
end
5752

58-
hasfield(typeof(cache), :nsteps) && (cache.nsteps += 1)
53+
cache.stats.nsteps += 1
54+
cache.nsteps += 1
5955

6056
if timeit
6157
cache.total_time += time() - time_start

src/core/spectral_methods.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ concrete_jac(::GeneralizedDFSane) = nothing
6060
linesearch_cache
6161

6262
# Counters
63-
nf::Int
63+
stats::NLStats
6464
nsteps::Int
6565
maxiters::Int
6666
maxtime
@@ -106,7 +106,7 @@ function __reinit_internal!(
106106

107107
reset!(cache.trace)
108108
reinit!(cache.termination_cache, get_fu(cache), get_u(cache); kwargs...)
109-
cache.nf = 1
109+
__reinit_internal!(cache.stats)
110110
cache.nsteps = 0
111111
cache.maxiters = maxiters
112112
cache.maxtime = maxtime
@@ -116,9 +116,9 @@ end
116116

117117
@internal_caches GeneralizedDFSaneCache :linesearch_cache
118118

119-
function SciMLBase.__init(prob::AbstractNonlinearProblem, alg::GeneralizedDFSane,
120-
args...; alias_u0 = false, maxiters = 1000, abstol = nothing,
121-
reltol = nothing, termination_condition = nothing,
119+
function SciMLBase.__init(prob::AbstractNonlinearProblem, alg::GeneralizedDFSane, args...;
120+
stats = empty_nlstats(), alias_u0 = false, maxiters = 1000,
121+
abstol = nothing, reltol = nothing, termination_condition = nothing,
122122
internalnorm::F = DEFAULT_NORM, maxtime = nothing, kwargs...) where {F}
123123
timer = get_timer_output()
124124
@static_timeit timer "cache construction" begin
@@ -130,8 +130,8 @@ function SciMLBase.__init(prob::AbstractNonlinearProblem, alg::GeneralizedDFSane
130130
fu = evaluate_f(prob, u)
131131
@bb fu_cache = copy(fu)
132132

133-
linesearch_cache = __internal_init(
134-
prob, alg.linesearch, prob.f, fu, u, prob.p; maxiters, internalnorm, kwargs...)
133+
linesearch_cache = __internal_init(prob, alg.linesearch, prob.f, fu, u, prob.p;
134+
stats, maxiters, internalnorm, kwargs...)
135135

136136
abstol, reltol, tc_cache = init_termination_cache(
137137
prob, abstol, reltol, fu, u_cache, termination_condition)
@@ -150,7 +150,7 @@ function SciMLBase.__init(prob::AbstractNonlinearProblem, alg::GeneralizedDFSane
150150

151151
return GeneralizedDFSaneCache{isinplace(prob), maxtime !== nothing}(
152152
fu, fu_cache, u, u_cache, prob.p, du, alg, prob, σ_n, T(alg.σ_min),
153-
T(alg.σ_max), linesearch_cache, 0, 0, maxiters, maxtime,
153+
T(alg.σ_max), linesearch_cache, stats, 0, maxiters, maxtime,
154154
timer, 0.0, tc_cache, trace, ReturnCode.Default, false, kwargs)
155155
end
156156
end

0 commit comments

Comments
 (0)