Skip to content

Commit 1ecae93

Browse files
author
Pepijn de Vos
committed
progress bars for EnsembleProblem
ODE->Ensemble; optionally aggregate progress bars; handle integer loglevel (sundials); avoid lock contention; only log significant progress; improve progress performance; add version constraint; Update Project.toml; Update Project.toml; ignore derivatives of logging; more AD fixing attempts; only pass progress_id if needed; use ignore_derivatives and fix rrule for with_logger; delete rules moved to ChainRules; remove using; import as opposed to using; more import fixes; more missing Logging qualifiers
1 parent 06d5c2c commit 1ecae93

File tree

1 file changed

+108
-30
lines changed

1 file changed

+108
-30
lines changed

src/ensemble/basic_ensemble_solve.jl

Lines changed: 108 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,64 @@ function merge_stats(us)
2929
reduce(merge, st)
3030
end
3131

32+
"""
    AggregateLogger(logger::Logging.AbstractLogger)

Logger that wraps `logger` and keeps the bookkeeping needed to condense many
per-id progress messages into one aggregate progress message.

# Fields
- `progress`: last recorded progress value for each message id.
- `done_counter`: number of `"done"` progress messages counted so far.
- `total`: running mean of the values stored in `progress`.
- `print_time`: `time()` stamp of the last aggregate message that was forwarded.
- `lock`: guards the mutable fields above.
- `logger`: the wrapped logger that ultimately receives every forwarded message.
"""
mutable struct AggregateLogger{T <: Logging.AbstractLogger} <: Logging.AbstractLogger
    progress::Dict{Symbol, Float64}
    done_counter::Int
    total::Float64
    print_time::Float64
    lock::ReentrantLock
    logger::T
end

# Convenience constructor: start from an empty progress table and zeroed counters.
function AggregateLogger(logger::Logging.AbstractLogger)
    return AggregateLogger(Dict{Symbol, Float64}(), 0, 0.0, 0.0, ReentrantLock(), logger)
end
41+
42+
"""
    Logging.handle_message(l::AggregateLogger, level, message, _module, group, id,
                           file, line; kwargs...)

Forward a log record to the wrapped logger, replacing individual progress
records by a single aggregate one.

A record is treated as a progress record when its level converts to
`Logging.LogLevel(-1)` and it carries a `progress` keyword. Every other record
is forwarded unchanged.

For a progress record the per-id value is stored and the running mean is
updated. The record is then either dropped (returning `nothing`) or rewritten
with `id = :total`, `message = "Total"` and `progress` set to the running mean,
or to `"done"` once the number of `"done"` records reaches the number of known
ids. Aggregate records are forwarded only when more than 0.1 seconds have passed
since the previous one. If the lock is busy the update is dropped, except for
`"done"` records, which wait for the lock.
"""
function Logging.handle_message(l::AggregateLogger, level, message, _module, group, id,
        file, line; kwargs...)
    is_progress = convert(Logging.LogLevel, level) == Logging.LogLevel(-1) &&
                  haskey(kwargs, :progress)
    if is_progress
        share = kwargs[:progress]
        # Never block on an ordinary update; only a completion record is worth
        # waiting for, because losing it would leave the aggregate unfinished.
        if !trylock(l.lock)
            share == "done" || return
            lock(l.lock)
        end
        try
            if share == "done"
                share = 1.0
                l.done_counter += 1
            end
            known = length(l.progress)
            if haskey(l.progress, id)
                # Existing id: shift the mean by the change in this entry.
                l.total += (share - l.progress[id]) / known
            else
                # New id: re-weight the old mean and fold in the new entry.
                l.total = l.total * (known / (known + 1)) + share / (known + 1)
                known += 1
            end
            l.progress[id] = share
            stamp = time()
            if l.done_counter >= known
                # Everything reported done: emit the final record and reset state.
                aggregate = "done"
                empty!(l.progress)
                l.done_counter = 0
                l.print_time = 0.0
            elseif stamp - l.print_time > 0.1
                aggregate = l.total
                l.print_time = stamp
            else
                # Throttled: too soon since the last aggregate record.
                return
            end
            id = :total
            message = "Total"
            kwargs = merge(values(kwargs), (progress = aggregate,))
        finally
            unlock(l.lock)
        end
    end
    return Logging.handle_message(l.logger, level, message, _module, group, id, file,
        line; kwargs...)
end
86+
# The remaining AbstractLogger interface is delegated to the wrapped logger, so
# filtering, minimum level and exception policy are exactly those of `l.logger`.
function Logging.shouldlog(l::AggregateLogger, args...)
    return Logging.shouldlog(l.logger, args...)
end

function Logging.min_enabled_level(l::AggregateLogger)
    return Logging.min_enabled_level(l.logger)
end

function Logging.catch_exceptions(l::AggregateLogger)
    return Logging.catch_exceptions(l.logger)
end
89+
3290
function __solve(prob::AbstractEnsembleProblem,
3391
alg::Union{AbstractDEAlgorithm, Nothing};
3492
kwargs...)
@@ -59,51 +117,71 @@ end
59117
"""
    __solve(prob::AbstractEnsembleProblem, alg, ensemblealg::BasicEnsembleAlgorithm;
            trajectories, batch_size = trajectories, progress_aggregate = true,
            pmap_batch_size = batch_size ÷ 100 > 0 ? batch_size ÷ 100 : 1, kwargs...)

Solve the ensemble in batches of `batch_size` trajectories via `solve_batch` and
return an `EnsembleSolution` carrying the results, the elapsed time, the
convergence flag and the merged statistics.

# Keywords
- `trajectories`: total number of trajectories to solve (required).
- `batch_size`: trajectories per batch; the last batch takes whatever remains.
- `progress_aggregate`: when `true`, the solve runs under an `AggregateLogger`
  wrapping the current logger; otherwise the current logger is used as is.
- `pmap_batch_size`: forwarded to `solve_batch`.
- `kwargs...`: forwarded to `solve_batch`. If `progress` is `true`, an initial
  `progress = 0` record is logged for every trajectory, named
  `"<progress_name> #<i>"` with `progress_name` defaulting to `"Ensemble"`.

# Throws
Calls `error` when `trajectories ÷ batch_size` is less than 1.
"""
function __solve(prob::AbstractEnsembleProblem,
        alg::A,
        ensemblealg::BasicEnsembleAlgorithm;
        trajectories, batch_size = trajectories, progress_aggregate = true,
        pmap_batch_size = batch_size ÷ 100 > 0 ? batch_size ÷ 100 : 1, kwargs...) where {A}
    logger = progress_aggregate ? AggregateLogger(Logging.current_logger()) :
             Logging.current_logger()

    Logging.with_logger(logger) do
        num_batches = trajectories ÷ batch_size
        num_batches < 1 &&
            error("trajectories ÷ batch_size cannot be less than 1, got $num_batches")
        # A partial trailing batch needs one extra round.
        num_batches * batch_size != trajectories && (num_batches += 1)

        if get(kwargs, :progress, false)
            # Register every trajectory up front so the logger knows how many
            # ids take part before any of them starts reporting.
            name = get(kwargs, :progress_name, "Ensemble")
            for i in 1:trajectories
                msg = "$name #$i"
                Logging.@logmsg(Logging.LogLevel(-1), msg,
                    _id=Symbol("SciMLBase_$i"), progress=0)
            end
        end

        if num_batches == 1 && prob.reduction === DEFAULT_REDUCTION
            # Fast path: one batch and the default reduction, so no reduction loop.
            elapsed_time = @elapsed u = solve_batch(prob, alg, ensemblealg, 1:trajectories,
                pmap_batch_size; kwargs...)
            _u = tighten_container_eltype(u)
            # Keep the merged statistics here as well, matching the general path
            # below; they were dropped when this branch moved into `with_logger`.
            stats = merge_stats(_u)
            return EnsembleSolution(_u, elapsed_time, true, stats)
        end

        converged::Bool = false
        elapsed_time = @elapsed begin
            i = 1
            II = (batch_size * (i - 1) + 1):(batch_size * i)

            batch_data = solve_batch(prob, alg, ensemblealg, II, pmap_batch_size; kwargs...)

            u = prob.u_init === nothing ? similar(batch_data, 0) : prob.u_init
            u, converged = prob.reduction(u, batch_data, II)
            for i in 2:num_batches
                # The reduction may signal that no further batches are needed.
                converged && break
                if i == num_batches
                    II = (batch_size * (i - 1) + 1):trajectories
                else
                    II = (batch_size * (i - 1) + 1):(batch_size * i)
                end
                batch_data = solve_batch(prob, alg, ensemblealg, II, pmap_batch_size;
                    kwargs...)
                u, converged = prob.reduction(u, batch_data, II)
            end
        end
        _u = tighten_container_eltype(u)
        stats = merge_stats(_u)
        return EnsembleSolution(_u, elapsed_time, converged, stats)
    end
end
101171

102172
function batch_func(i, prob, alg; kwargs...)
103173
iter = 1
104174
_prob = prob.safetycopy ? deepcopy(prob.prob) : prob.prob
105175
new_prob = prob.prob_func(_prob, i, iter)
106176
rerun = true
177+
178+
progress = get(kwargs, :progress, false)
179+
if progress
180+
name = get(kwargs, :progress_name, "Ensemble")
181+
progress_name = "$name #$i"
182+
progress_id = Symbol("SciMLBase_$i")
183+
kwargs = (kwargs..., progress_name=progress_name, progress_id=progress_id)
184+
end
107185
x = prob.output_func(solve(new_prob, alg; kwargs...), i)
108186
if !(x isa Tuple)
109187
rerun_warn()

0 commit comments

Comments
 (0)