Skip to content

Commit 267dd3f

Browse files
committed
Track down performance pitfalls and purge dynamic dispatches and allocations
1 parent 808512b commit 267dd3f

20 files changed

+385
-306
lines changed

Manifest.toml

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
julia_version = "1.10.0"
44
manifest_format = "2.0"
5-
project_hash = "2fa62d6199f8a6cecd1dab4dd969e2f9c3e4eb5d"
5+
project_hash = "ee8f38812d75ecf5b51425c9f9559c9e53418c46"
66

77
[[deps.ADTypes]]
88
git-tree-sha1 = "41c37aa88889c171f1300ceac1313c06e891d245"
@@ -894,12 +894,6 @@ deps = ["Artifacts", "Libdl", "libblastrampoline_jll"]
894894
uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c"
895895
version = "7.2.1+1"
896896

897-
[[deps.SumTypes]]
898-
deps = ["MacroTools"]
899-
git-tree-sha1 = "dc8ae794496a9f04e16393612511223750291547"
900-
uuid = "8e1ec7a9-0e02-4297-b0fe-6433085c89f2"
901-
version = "0.5.1"
902-
903897
[[deps.SymbolicIndexingInterface]]
904898
git-tree-sha1 = "74502f408d99fc217a9d7cd901d9ffe45af892b1"
905899
uuid = "2efcf032-c050-4f8e-a9bb-153293bab1f5"

Project.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1818
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
1919
MaybeInplace = "bb5d69b7-63fc-4a16-80bd-7e42200c7bdb"
2020
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
21+
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
2122
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
2223
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
2324
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
@@ -26,7 +27,6 @@ SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
2627
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2728
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
2829
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
29-
SumTypes = "8e1ec7a9-0e02-4297-b0fe-6433085c89f2"
3030
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
3131

3232
[weakdeps]
@@ -95,7 +95,6 @@ SpeedMapping = "0.3"
9595
StableRNGs = "1"
9696
StaticArrays = "1.7"
9797
StaticArraysCore = "1.4"
98-
SumTypes = "0.5"
9998
Sundials = "4.23.1"
10099
Symbolics = "5.13"
101100
Test = "1"

docs/src/basics/solve.md

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,13 @@ solve(prob::SciMLBase.NonlinearProblem, args...; kwargs...)
1414
## Iteration Controls
1515

1616
- `maxiters::Int`: The maximum number of iterations to perform. Defaults to `1000`.
17-
- `maxtime`: The maximum time for solving the nonlinear system of equations. Defaults to `Inf`.
18-
- `abstol::Number`: The absolute tolerance. Defaults to `real(oneunit(T)) * (eps(real(one(T))))^(4 // 5)`.
19-
- `reltol::Number`: The relative tolerance. Defaults to `real(oneunit(T)) * (eps(real(one(T))))^(4 // 5)`.
17+
- `maxtime`: The maximum time for solving the nonlinear system of equations. Defaults to
18+
`nothing`, which means no time limit. Note that setting a time limit does have a small
19+
overhead.
20+
- `abstol::Number`: The absolute tolerance. Defaults to
21+
`real(oneunit(T)) * (eps(real(one(T))))^(4 // 5)`.
22+
- `reltol::Number`: The relative tolerance. Defaults to
23+
`real(oneunit(T)) * (eps(real(one(T))))^(4 // 5)`.
2024
- `termination_condition`: Termination Condition from DiffEqBase. Defaults to
2125
`AbsSafeBestTerminationMode()` for `NonlinearSolve.jl` and `AbsTerminateMode()` for
2226
`SimpleNonlinearSolve.jl`.

docs/src/tutorials/code_optimization.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ to normal array expressions, for example:
115115
```@example small_opt
116116
using StaticArrays
117117
A = SA[2.0, 3.0, 5.0]
118-
typeof(A) # SVector{3, Float64} (alias for SArray{Tuple{3}, Float64, 1, 3})
118+
typeof(A)
119119
```
120120

121121
Notice that the `3` after `SVector` gives the size of the `SVector`. It cannot be changed.

docs/src/tutorials/getting_started.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -194,8 +194,8 @@ solve(prob, GaussNewton(), reltol = 1e-12, abstol = 1e-12)
194194

195195
## Going Beyond the Basics: How to use the Documentation
196196

197-
Congrats, you now know how to use the basics of NonlinearSolve.jl! However, there is so much more to
198-
see. Next check out:
197+
Congrats, you now know how to use the basics of NonlinearSolve.jl! However, there is so much
198+
more to see. Next check out:
199199

200200
- [Some code optimization tricks to know about with NonlinearSolve.jl](@ref code_optimization)
201201
- [An iterator interface which lets you step through the solving process step by step](@ref iterator)

src/NonlinearSolve.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ import PrecompileTools: @recompile_invalidations, @compile_workload, @setup_work
99

1010
@recompile_invalidations begin
1111
using ADTypes, ConcreteStructs, DiffEqBase, FastBroadcast, FastClosures, LazyArrays,
12-
LineSearches, LinearAlgebra, LinearSolve, MaybeInplace, Printf, SciMLBase,
13-
SimpleNonlinearSolve, SparseArrays, SparseDiffTools, SumTypes, TimerOutputs
12+
LineSearches, LinearAlgebra, LinearSolve, MaybeInplace, Preferences, Printf,
13+
SciMLBase, SimpleNonlinearSolve, SparseArrays, SparseDiffTools, TimerOutputs
1414

1515
import ArrayInterface: undefmatrix, can_setindex, restructure, fast_scalar_indexing
1616
import DiffEqBase: AbstractNonlinearTerminationMode,
@@ -40,6 +40,7 @@ const True = Val(true)
4040
const False = Val(false)
4141

4242
include("abstract_types.jl")
43+
include("timer_outputs.jl")
4344
include("internal/helpers.jl")
4445

4546
include("descent/newton.jl")

src/abstract_types.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,11 +140,11 @@ abstract type AbstractNonlinearSolveExtensionAlgorithm <:
140140
AbstractNonlinearSolveAlgorithm{:Extension} end
141141

142142
"""
143-
AbstractNonlinearSolveCache{iip}
143+
AbstractNonlinearSolveCache{iip, timeit}
144144
145145
Abstract Type for all NonlinearSolve.jl Caches.
146146
"""
147-
abstract type AbstractNonlinearSolveCache{iip} end
147+
abstract type AbstractNonlinearSolveCache{iip, timeit} end
148148

149149
SciMLBase.isinplace(::AbstractNonlinearSolveCache{iip}) where {iip} = iip
150150

src/core/approximate_jacobian.jl

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@ end
4646

4747
@inline concrete_jac(::ApproximateJacobianSolveAlgorithm{CJ}) where {CJ} = CJ
4848

49-
@concrete mutable struct ApproximateJacobianSolveCache{INV, GB, iip} <:
50-
AbstractNonlinearSolveCache{iip}
49+
@concrete mutable struct ApproximateJacobianSolveCache{INV, GB, iip, timeit} <:
50+
AbstractNonlinearSolveCache{iip, timeit}
5151
# Basic Requirements
5252
fu
5353
u
@@ -79,7 +79,7 @@ end
7979
steps_since_last_reset::Int
8080

8181
# Timer
82-
timer::TimerOutput
82+
timer
8383
total_time::Float64 # Simple Counter which works even if TimerOutput is disabled
8484

8585
# Termination & Tracking
@@ -93,8 +93,8 @@ end
9393
store_inverse_jacobian(::ApproximateJacobianSolveCache{INV}) where {INV} = INV
9494

9595
function __reinit_internal!(cache::ApproximateJacobianSolveCache{INV, GB, iip}, args...;
96-
p = cache.p, u0 = cache.u, alias_u0::Bool = false, maxiters = 1000, maxtime = Inf,
97-
kwargs...) where {INV, GB, iip}
96+
p = cache.p, u0 = cache.u, alias_u0::Bool = false, maxiters = 1000,
97+
maxtime = nothing, kwargs...) where {INV, GB, iip}
9898
if iip
9999
recursivecopy!(cache.u, u0)
100100
cache.prob.f(cache.fu, cache.u, p)
@@ -123,12 +123,12 @@ end
123123
@internal_caches ApproximateJacobianSolveCache :initialization_cache :descent_cache :linesearch_cache :trustregion_cache :update_rule_cache :reinit_rule_cache
124124

125125
function SciMLBase.__init(prob::AbstractNonlinearProblem{uType, iip},
126-
alg::ApproximateJacobianSolveAlgorithm, args...; alias_u0 = false, maxtime = Inf,
127-
maxiters = 1000, abstol = nothing, reltol = nothing, linsolve_kwargs = (;),
128-
termination_condition = nothing, internalnorm::F = DEFAULT_NORM,
129-
kwargs...) where {uType, iip, F}
130-
timer = TimerOutput()
131-
@timeit_debug timer "cache construction" begin
126+
alg::ApproximateJacobianSolveAlgorithm, args...; alias_u0 = false,
127+
maxtime = nothing, maxiters = 1000, abstol = nothing, reltol = nothing,
128+
linsolve_kwargs = (;), termination_condition = nothing,
129+
internalnorm::F = DEFAULT_NORM, kwargs...) where {uType, iip, F}
130+
timer = get_timer_output()
131+
@static_timeit timer "cache construction" begin
132132
(; f, u0, p) = prob
133133
u = __maybe_unaliased(u0, alias_u0)
134134
fu = evaluate_f(prob, u)
@@ -181,18 +181,18 @@ function SciMLBase.__init(prob::AbstractNonlinearProblem{uType, iip},
181181
trace = init_nonlinearsolve_trace(alg, u, fu, ApplyArray(__zero, J), du;
182182
uses_jacobian_inverse = Val(INV), kwargs...)
183183

184-
return ApproximateJacobianSolveCache{INV, GB, iip}(fu, u, u_cache, p, du, J, alg,
185-
prob, initialization_cache, descent_cache, linesearch_cache, trustregion_cache,
186-
update_rule_cache, reinit_rule_cache, inv_workspace, 0, 0, 0, alg.max_resets,
187-
maxiters, maxtime, alg.max_shrink_times, 0, timer, 0.0, termination_cache,
188-
trace, ReturnCode.Default, false, false)
184+
return ApproximateJacobianSolveCache{INV, GB, iip, maxtime !== nothing}(fu, u,
185+
u_cache, p, du, J, alg, prob, initialization_cache, descent_cache,
186+
linesearch_cache, trustregion_cache, update_rule_cache, reinit_rule_cache,
187+
inv_workspace, 0, 0, 0, alg.max_resets, maxiters, maxtime, alg.max_shrink_times,
188+
0, timer, 0.0, termination_cache, trace, ReturnCode.Default, false, false)
189189
end
190190
end
191191

192192
function __step!(cache::ApproximateJacobianSolveCache{INV, GB, iip};
193193
recompute_jacobian::Union{Nothing, Bool} = nothing) where {INV, GB, iip}
194194
new_jacobian = true
195-
@timeit_debug cache.timer "jacobian init/reinit" begin
195+
@static_timeit cache.timer "jacobian init/reinit" begin
196196
if get_nsteps(cache) == 0 # First Step is special ignore kwargs
197197
J_init = solve!(cache.initialization_cache, cache.fu, cache.u, Val(false))
198198
if INV
@@ -248,7 +248,7 @@ function __step!(cache::ApproximateJacobianSolveCache{INV, GB, iip};
248248
end
249249
end
250250

251-
@timeit_debug cache.timer "descent" begin
251+
@static_timeit cache.timer "descent" begin
252252
if cache.trustregion_cache !== nothing &&
253253
hasfield(typeof(cache.trustregion_cache), :trust_region)
254254
δu, descent_success, descent_intermediates = solve!(cache.descent_cache,
@@ -262,19 +262,19 @@ function __step!(cache::ApproximateJacobianSolveCache{INV, GB, iip};
262262

263263
if descent_success
264264
if GB === :LineSearch
265-
@timeit_debug cache.timer "linesearch" begin
265+
@static_timeit cache.timer "linesearch" begin
266266
needs_reset, α = solve!(cache.linesearch_cache, cache.u, δu)
267267
end
268268
if needs_reset && cache.steps_since_last_reset > 5 # Reset after a burn-in period
269269
cache.force_reinit = true
270270
else
271-
@timeit_debug cache.timer "step" begin
271+
@static_timeit cache.timer "step" begin
272272
@bb axpy!(α, δu, cache.u)
273273
evaluate_f!(cache, cache.u, cache.p)
274274
end
275275
end
276276
elseif GB === :TrustRegion
277-
@timeit_debug cache.timer "trustregion" begin
277+
@static_timeit cache.timer "trustregion" begin
278278
tr_accepted, u_new, fu_new = solve!(cache.trustregion_cache, J, cache.fu,
279279
cache.u, δu, descent_intermediates)
280280
if tr_accepted
@@ -289,7 +289,7 @@ function __step!(cache::ApproximateJacobianSolveCache{INV, GB, iip};
289289
end
290290
α = true
291291
elseif GB === :None
292-
@timeit_debug cache.timer "step" begin
292+
@static_timeit cache.timer "step" begin
293293
@bb axpy!(1, δu, cache.u)
294294
evaluate_f!(cache, cache.u, cache.p)
295295
end
@@ -313,7 +313,7 @@ function __step!(cache::ApproximateJacobianSolveCache{INV, GB, iip};
313313
return nothing
314314
end
315315

316-
@timeit_debug cache.timer "jacobian update" begin
316+
@static_timeit cache.timer "jacobian update" begin
317317
cache.J = solve!(cache.update_rule_cache, cache.J, cache.fu, cache.u, δu)
318318
callback_into_cache!(cache)
319319
end

src/core/generalized_first_order.jl

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@ end
4949

5050
concrete_jac(::GeneralizedFirstOrderAlgorithm{CJ}) where {CJ} = CJ
5151

52-
@concrete mutable struct GeneralizedFirstOrderAlgorithmCache{iip, GB} <:
53-
AbstractNonlinearSolveCache{iip}
52+
@concrete mutable struct GeneralizedFirstOrderAlgorithmCache{iip, GB, timeit} <:
53+
AbstractNonlinearSolveCache{iip, timeit}
5454
# Basic Requirements
5555
fu
5656
u
@@ -75,7 +75,7 @@ concrete_jac(::GeneralizedFirstOrderAlgorithm{CJ}) where {CJ} = CJ
7575
max_shrink_times::Int
7676

7777
# Timer
78-
timer::TimerOutput
78+
timer
7979
total_time::Float64 # Simple Counter which works even if TimerOutput is disabled
8080

8181
# State Affect
@@ -89,8 +89,8 @@ concrete_jac(::GeneralizedFirstOrderAlgorithm{CJ}) where {CJ} = CJ
8989
end
9090

9191
function __reinit_internal!(cache::GeneralizedFirstOrderAlgorithmCache{iip}, args...;
92-
p = cache.p, u0 = cache.u, alias_u0::Bool = false, maxiters = 1000, maxtime = Inf,
93-
kwargs...) where {iip}
92+
p = cache.p, u0 = cache.u, alias_u0::Bool = false, maxiters = 1000,
93+
maxtime = nothing, kwargs...) where {iip}
9494
if iip
9595
recursivecopy!(cache.u, u0)
9696
cache.prob.f(cache.fu, cache.u, p)
@@ -118,10 +118,11 @@ end
118118

119119
function SciMLBase.__init(prob::AbstractNonlinearProblem{uType, iip},
120120
alg::GeneralizedFirstOrderAlgorithm, args...; alias_u0 = false, maxiters = 1000,
121-
abstol = nothing, reltol = nothing, maxtime = Inf, termination_condition = nothing,
122-
internalnorm = DEFAULT_NORM, linsolve_kwargs = (;), kwargs...) where {uType, iip}
123-
timer = TimerOutput()
124-
@timeit_debug timer "cache construction" begin
121+
abstol = nothing, reltol = nothing, maxtime = nothing,
122+
termination_condition = nothing, internalnorm = DEFAULT_NORM, linsolve_kwargs = (;),
123+
kwargs...) where {uType, iip}
124+
timer = get_timer_output()
125+
@static_timeit timer "cache construction" begin
125126
(; f, u0, p) = prob
126127
u = __maybe_unaliased(u0, alias_u0)
127128
fu = evaluate_f(prob, u)
@@ -166,16 +167,16 @@ function SciMLBase.__init(prob::AbstractNonlinearProblem{uType, iip},
166167

167168
trace = init_nonlinearsolve_trace(alg, u, fu, ApplyArray(__zero, J), du; kwargs...)
168169

169-
return GeneralizedFirstOrderAlgorithmCache{iip, GB}(fu, u, u_cache, p,
170-
du, J, alg, prob, jac_cache, descent_cache, linesearch_cache,
170+
return GeneralizedFirstOrderAlgorithmCache{iip, GB, maxtime !== nothing}(fu, u,
171+
u_cache, p, du, J, alg, prob, jac_cache, descent_cache, linesearch_cache,
171172
trustregion_cache, 0, 0, maxiters, maxtime, alg.max_shrink_times, timer, 0.0,
172173
true, termination_cache, trace, ReturnCode.Default, false)
173174
end
174175
end
175176

176177
function __step!(cache::GeneralizedFirstOrderAlgorithmCache{iip, GB};
177178
recompute_jacobian::Union{Nothing, Bool} = nothing, kwargs...) where {iip, GB}
178-
@timeit_debug cache.timer "jacobian" begin
179+
@static_timeit cache.timer "jacobian" begin
179180
if (recompute_jacobian === nothing || recompute_jacobian) && cache.make_new_jacobian
180181
J = cache.jac_cache(cache.u)
181182
new_jacobian = true
@@ -185,7 +186,7 @@ function __step!(cache::GeneralizedFirstOrderAlgorithmCache{iip, GB};
185186
end
186187
end
187188

188-
@timeit_debug cache.timer "descent" begin
189+
@static_timeit cache.timer "descent" begin
189190
if cache.trustregion_cache !== nothing &&
190191
hasfield(typeof(cache.trustregion_cache), :trust_region)
191192
δu, descent_success, descent_intermediates = solve!(cache.descent_cache,
@@ -200,19 +201,19 @@ function __step!(cache::GeneralizedFirstOrderAlgorithmCache{iip, GB};
200201
if descent_success
201202
cache.make_new_jacobian = true
202203
if GB === :LineSearch
203-
@timeit_debug cache.timer "linesearch" begin
204+
@static_timeit cache.timer "linesearch" begin
204205
linesearch_failed, α = solve!(cache.linesearch_cache, cache.u, δu)
205206
end
206207
if linesearch_failed
207208
cache.retcode = ReturnCode.InternalLineSearchFailed
208209
cache.force_stop = true
209210
end
210-
@timeit_debug cache.timer "step" begin
211+
@static_timeit cache.timer "step" begin
211212
@bb axpy!(α, δu, cache.u)
212213
evaluate_f!(cache, cache.u, cache.p)
213214
end
214215
elseif GB === :TrustRegion
215-
@timeit_debug cache.timer "trustregion" begin
216+
@static_timeit cache.timer "trustregion" begin
216217
tr_accepted, u_new, fu_new = solve!(cache.trustregion_cache, J, cache.fu,
217218
cache.u, δu, descent_intermediates)
218219
if tr_accepted
@@ -230,7 +231,7 @@ function __step!(cache::GeneralizedFirstOrderAlgorithmCache{iip, GB};
230231
end
231232
end
232233
elseif GB === :None
233-
@timeit_debug cache.timer "step" begin
234+
@static_timeit cache.timer "step" begin
234235
@bb axpy!(1, δu, cache.u)
235236
evaluate_f!(cache, cache.u, cache.p)
236237
end

src/core/generic.jl

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ function SciMLBase.solve!(cache::AbstractNonlinearSolveCache)
2424
update_trace!(cache.trace, get_nsteps(cache), get_u(cache), get_fu(cache), nothing,
2525
nothing, nothing; last = True)
2626

27-
stats = SciMLBase.NLStats(get_nf(cache), get_njacs(cache), get_nfactors(cache),
27+
stats = ImmutableNLStats(get_nf(cache), get_njacs(cache), get_nfactors(cache),
2828
get_nsolve(cache), get_nsteps(cache))
2929

3030
return SciMLBase.build_solution(cache.prob, cache.alg, get_u(cache), get_fu(cache);
@@ -45,18 +45,21 @@ Performs one step of the nonlinear solver.
4545
respectively. For algorithms that don't use jacobian information, this keyword is
4646
ignored with a one-time warning.
4747
"""
48-
function SciMLBase.step!(cache::AbstractNonlinearSolveCache, args...; kwargs...)
49-
time_start = time()
50-
res = @timeit_debug cache.timer "solve" begin
48+
function SciMLBase.step!(cache::AbstractNonlinearSolveCache{iip, timeit}, args...;
49+
kwargs...) where {iip, timeit}
50+
timeit && (time_start = time())
51+
res = @static_timeit cache.timer "solve" begin
5152
__step!(cache, args...; kwargs...)
5253
end
5354
cache.nsteps += 1
54-
cache.total_time += time() - time_start
5555

56-
if !cache.force_stop && cache.retcode == ReturnCode.Default &&
57-
cache.total_time ≥ cache.maxtime
58-
cache.retcode = ReturnCode.MaxTime
59-
cache.force_stop = true
56+
if timeit
57+
cache.total_time += time() - time_start
58+
if !cache.force_stop && cache.retcode == ReturnCode.Default &&
59+
cache.total_time ≥ cache.maxtime
60+
cache.retcode = ReturnCode.MaxTime
61+
cache.force_stop = true
62+
end
6063
end
6164

6265
return res

0 commit comments

Comments
 (0)