Skip to content

Commit ff42d8b

Browse files
Merge pull request #512 from pepijndevos/pv/enstats
add stats to ensemble solution
2 parents a29ecf1 + 35f5f06 commit ff42d8b

File tree

9 files changed

+149
-13
lines changed

9 files changed

+149
-13
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
2020
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
2121
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
2222
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
23+
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
2324
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
2425
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
2526
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"

ext/SciMLBaseZygoteExt.jl

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@ module SciMLBaseZygoteExt
22

33
using Zygote: pullback
44
using ZygoteRules: @adjoint
5-
using SciMLBase: ODESolution, issymbollike, sym_to_index, remake, getobserved
5+
import ZygoteRules
6+
using SciMLBase: EnsembleSolution, ODESolution, issymbollike, sym_to_index, remake, getobserved
67

78
# This method resolves the ambiguity with the pullback defined in
89
# RecursiveArrayToolsZygoteExt
@@ -55,4 +56,25 @@ end
5556
VA[sym, j], ODESolution_getindex_pullback
5657
end
5758

59+
ZygoteRules.@adjoint function EnsembleSolution(sim, time, converged, stats)
60+
out = EnsembleSolution(sim, time, converged)
61+
function EnsembleSolution_adjoint(p̄::AbstractArray{T, N}) where {T, N}
62+
arrarr = [[p̄[ntuple(x -> Colon(), Val(N - 2))..., j, i]
63+
for j in 1:size(p̄)[end - 1]] for i in 1:size(p̄)[end]]
64+
(EnsembleSolution(arrarr, 0.0, true), nothing, nothing, nothing)
65+
end
66+
function EnsembleSolution_adjoint(p̄::AbstractArray{<:AbstractArray, 1})
67+
(EnsembleSolution(p̄, 0.0, true), nothing, nothing, nothing)
68+
end
69+
function EnsembleSolution_adjoint(p̄::EnsembleSolution)
70+
(p̄, nothing, nothing, nothing)
71+
end
72+
out, EnsembleSolution_adjoint
73+
end
74+
75+
ZygoteRules.@adjoint function ZygoteRules.literal_getproperty(sim::EnsembleSolution,
76+
::Val{:u})
77+
sim.u, p̄ -> (EnsembleSolution(p̄, 0.0, true),)
78+
end
79+
5880
end

src/SciMLBase.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ using LinearAlgebra
1111
using Statistics
1212
using Distributed
1313
using Markdown
14+
using Printf
1415
import Preferences
1516

1617
import Logging, ArrayInterface

src/ensemble/basic_ensemble_solve.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,12 @@ $(TYPEDEF)
2323
"""
2424
struct EnsembleSerial <: BasicEnsembleAlgorithm end
2525

26+
function merge_stats(us)
27+
st = Iterators.filter(!isnothing, (hasproperty(x, :stats) ? x.stats : nothing for x in us))
28+
isempty(st) && return nothing
29+
reduce(merge, st)
30+
end
31+
2632
function __solve(prob::AbstractEnsembleProblem,
2733
alg::Union{AbstractDEAlgorithm, Nothing};
2834
kwargs...)
@@ -64,7 +70,8 @@ function __solve(prob::AbstractEnsembleProblem,
6470
elapsed_time = @elapsed u = solve_batch(prob, alg, ensemblealg, 1:trajectories,
6571
pmap_batch_size; kwargs...)
6672
_u = tighten_container_eltype(u)
67-
return EnsembleSolution(_u, elapsed_time, true)
73+
stats = merge_stats(_u)
74+
return EnsembleSolution(_u, elapsed_time, true, stats)
6875
end
6976

7077
converged::Bool = false
@@ -88,8 +95,8 @@ function __solve(prob::AbstractEnsembleProblem,
8895
end
8996
end
9097
_u = tighten_container_eltype(u)
91-
92-
return EnsembleSolution(_u, elapsed_time, converged)
98+
stats = merge_stats(_u)
99+
return EnsembleSolution(_u, elapsed_time, converged, stats)
93100
end
94101

95102
function batch_func(i, prob, alg; kwargs...)

src/ensemble/ensemble_solutions.jl

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,21 +29,23 @@ struct EnsembleSolution{T, N, S} <: AbstractEnsembleSolution{T, N, S}
2929
u::S
3030
elapsedTime::Float64
3131
converged::Bool
32+
stats
3233
end
33-
function EnsembleSolution(sim, dims::NTuple{N}, elapsedTime, converged) where {N}
34-
EnsembleSolution{eltype(eltype(sim)), N, typeof(sim)}(sim, elapsedTime, converged)
34+
function EnsembleSolution(sim, dims::NTuple{N}, elapsedTime, converged, stats) where {N}
35+
EnsembleSolution{eltype(eltype(sim)), N, typeof(sim)}(sim, elapsedTime, converged, stats)
3536
end
36-
function EnsembleSolution(sim, elapsedTime, converged)
37-
EnsembleSolution(sim, (length(sim),), elapsedTime, converged)
37+
function EnsembleSolution(sim, elapsedTime, converged, stats=nothing)
38+
EnsembleSolution(sim, (length(sim),), elapsedTime, converged, stats)
3839
end # Vector of some type which is not an array
3940
function EnsembleSolution(sim::T, elapsedTime,
40-
converged) where {T <: AbstractVector{T2}
41+
converged, stats=nothing) where {T <: AbstractVector{T2}
4142
} where {T2 <:
4243
AbstractArray}
43-
EnsembleSolution{eltype(eltype(sim)), ndims(sim[1]) + 1,
44-
typeof(sim)}(sim,
44+
EnsembleSolution{eltype(eltype(sim)), ndims(sim[1]) + 1, typeof(sim)}(
45+
sim,
4546
elapsedTime,
46-
converged)
47+
converged,
48+
stats)
4749
end
4850

4951
struct WeightedEnsembleSolution{T1 <: AbstractEnsembleSolution, T2 <: Number}
@@ -56,7 +58,7 @@ struct WeightedEnsembleSolution{T1 <: AbstractEnsembleSolution, T2 <: Number}
5658
end
5759

5860
function Base.reverse(sim::EnsembleSolution)
59-
EnsembleSolution(reverse(sim.u), sim.elapsedTime, sim.converged)
61+
EnsembleSolution(reverse(sim.u), sim.elapsedTime, sim.converged, sim.stats)
6062
end
6163

6264
"""

src/solutions/nonlinear_solutions.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,20 @@ mutable struct NLStats
1717
nsteps::Int
1818
end
1919

20+
function Base.show(io::IO, ::MIME"text/plain", s::NLStats)
21+
println(io, summary(s))
22+
@printf io "%-50s %-d\n" "Number of function evaluations:" s.nf
23+
@printf io "%-50s %-d\n" "Number of Jacobians created:" s.njacs
24+
@printf io "%-50s %-d\n" "Number of factorizations:" s.nfactors
25+
@printf io "%-50s %-d\n" "Number of linear solves:" s.nsolve
26+
@printf io "%-50s %-d" "Number of nonlinear solver iterations:" s.nsteps
27+
end
28+
29+
function Base.merge(s1::NLStats, s2::NLStats)
30+
NLStats(s1.nf + s2.nf, s1.njacs + s2.njacs, s1.nfactors + s2.nfactors,
31+
s1.nsolve + s2.nsolve, s1.nsteps + s2.nsteps)
32+
end
33+
2034
"""
2135
$(TYPEDEF)
2236

src/solutions/ode_solutions.jl

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,79 @@
11
"""
22
$(TYPEDEF)
33
4+
Statistics from the differential equation solver about the solution process.
5+
6+
## Fields
7+
8+
- nf: Number of function evaluations. If the differential equation is a split function,
9+
such as a `SplitFunction` for implicit-explicit (IMEX) integration, then `nf` is the
10+
number of function evaluations for the first function (the implicit function)
11+
- nf2: If the differential equation is a split function, such as a `SplitFunction`
12+
for implicit-explicit (IMEX) integration, then `nf2` is the number of function
13+
evaluations for the second function, i.e. the function treated explicitly. Otherwise
14+
it is zero.
15+
- nw: The number of W=I-gamma*J (or W=I/gamma-J) matrices constructed during the solving
16+
process.
17+
- nsolve: The number of linear solves `W\b` required for the integration.
18+
- njacs: Number of Jacobians calculated during the integration.
19+
- nnonliniter: Total number of iterations for the nonlinear solvers.
20+
- nnonlinconvfail: Number of nonlinear solver convergence failures.
21+
- ncondition: Number of calls to the condition function for callbacks.
22+
- naccept: Number of accepted steps.
23+
- nreject: Number of rejected steps.
24+
- maxeig: Maximum eigenvalue over the solution. This is only computed if the
25+
method is an auto-switching algorithm.
26+
"""
27+
mutable struct DEStats
28+
nf::Int
29+
nf2::Int
30+
nw::Int
31+
nsolve::Int
32+
njacs::Int
33+
nnonliniter::Int
34+
nnonlinconvfail::Int
35+
ncondition::Int
36+
naccept::Int
37+
nreject::Int
38+
maxeig::Float64
39+
end
40+
41+
DEStats(x::Int = -1) = DEStats(x, x, x, x, x, x, x, x, x, x, 0.0)
42+
43+
function Base.show(io::IO, ::MIME"text/plain", s::DEStats)
44+
println(io, summary(s))
45+
@printf io "%-50s %-d\n" "Number of function 1 evaluations:" s.nf
46+
@printf io "%-50s %-d\n" "Number of function 2 evaluations:" s.nf2
47+
@printf io "%-50s %-d\n" "Number of W matrix evaluations:" s.nw
48+
@printf io "%-50s %-d\n" "Number of linear solves:" s.nsolve
49+
@printf io "%-50s %-d\n" "Number of Jacobians created:" s.njacs
50+
@printf io "%-50s %-d\n" "Number of nonlinear solver iterations:" s.nnonliniter
51+
@printf io "%-50s %-d\n" "Number of nonlinear solver convergence failures:" s.nnonlinconvfail
52+
@printf io "%-50s %-d\n" "Number of rootfind condition calls:" s.ncondition
53+
@printf io "%-50s %-d\n" "Number of accepted steps:" s.naccept
54+
@printf io "%-50s %-d" "Number of rejected steps:" s.nreject
55+
iszero(s.maxeig) || @printf io "\n%-50s %-d" "Maximum eigenvalue recorded:" s.maxeig
56+
end
57+
58+
function Base.merge(a::DEStats, b::DEStats)
59+
DEStats(
60+
a.nf + b.nf,
61+
a.nf2 + b.nf2,
62+
a.nw + b.nw,
63+
a.nsolve + b.nsolve,
64+
a.njacs + b.njacs,
65+
a.nnonliniter + b.nnonliniter,
66+
a.nnonlinconvfail + b.nnonlinconvfail,
67+
a.ncondition + b.ncondition,
68+
a.naccept + b.naccept,
69+
a.nreject + b.nreject,
70+
max(a.maxeig, b.maxeig),
71+
)
72+
end
73+
74+
"""
75+
$(TYPEDEF)
76+
477
Representation of the solution to an ordinary differential equation defined by an ODEProblem.
578
679
## DESolution Interface

test/downstream/ensemble_stats.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
using OrdinaryDiffEq
2+
using Test
3+
4+
f(u,p,t) = 1.01*u
5+
u0=1/2
6+
tspan = (0.0,1.0)
7+
prob = ODEProblem(f,u0,tspan)
8+
function prob_func(prob, i, repeat)
9+
remake(prob, u0 = rand() * prob.u0)
10+
end
11+
ensemble_prob = EnsembleProblem(prob, prob_func = prob_func)
12+
sim = solve(ensemble_prob, Tsit5(), EnsembleThreads(), trajectories = 10)
13+
@test sim.stats.nf == mapreduce(x -> x.stats.nf, +, sim.u)

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,9 @@ end
7575
@time @safetestset "solving Ensembles with multiple problems" begin
7676
include("downstream/ensemble_multi_prob.jl")
7777
end
78+
@time @safetestset "Ensemble solution statistics" begin
79+
include("downstream/ensemble_stats.jl")
80+
end
7881
@time @safetestset "Symbol and integer based indexing of interpolated solutions" begin
7982
include("downstream/symbol_indexing.jl")
8083
end

0 commit comments

Comments
 (0)