Skip to content

Commit 5797257

Browse files
Merge pull request #514 from pepijndevos/pv/progress
progress bars for EnsembleProblem
2 parents 06d5c2c + a1370a0 commit 5797257

File tree

2 files changed

+111
-30
lines changed

2 files changed

+111
-30
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
3333
TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"
3434

3535
[weakdeps]
36+
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3637
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3738
PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b"
3839
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
@@ -51,6 +52,7 @@ SciMLBaseZygoteExt = "Zygote"
5152
[compat]
5253
ADTypes = "0.1.3, 0.2"
5354
ArrayInterface = "6, 7"
55+
ChainRules = "1.57.0"
5456
ChainRulesCore = "1.16"
5557
CommonSolve = "0.2.4"
5658
ConstructionBase = "1"

src/ensemble/basic_ensemble_solve.jl

Lines changed: 109 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,64 @@ function merge_stats(us)
2929
reduce(merge, st)
3030
end
3131

32+
mutable struct AggregateLogger{T<:Logging.AbstractLogger} <: Logging.AbstractLogger
33+
progress::Dict{Symbol, Float64}
34+
done_counter::Int
35+
total::Float64
36+
print_time::Float64
37+
lock::ReentrantLock
38+
logger::T
39+
end
40+
AggregateLogger(logger::Logging.AbstractLogger) = AggregateLogger(Dict{Symbol, Float64}(),0 , 0.0, 0.0, ReentrantLock(), logger)
41+
42+
function Logging.handle_message(l::AggregateLogger, level, message, _module, group, id, file, line; kwargs...)
43+
if convert(Logging.LogLevel, level) == Logging.LogLevel(-1) && haskey(kwargs, :progress)
44+
pr = kwargs[:progress]
45+
if trylock(l.lock) || (pr == "done" && lock(l.lock)===nothing)
46+
try
47+
if pr == "done"
48+
pr = 1.0
49+
l.done_counter += 1
50+
end
51+
len = length(l.progress)
52+
if haskey(l.progress, id)
53+
l.total += (pr-l.progress[id])/len
54+
else
55+
l.total = l.total*(len/(len+1)) + pr/(len+1)
56+
len += 1
57+
end
58+
l.progress[id] = pr
59+
# validation check (slow)
60+
# tot = sum(values(l.progress))/length(l.progress)
61+
# @show tot l.total l.total ≈ tot
62+
curr_time = time()
63+
if l.done_counter >= len
64+
tot="done"
65+
empty!(l.progress)
66+
l.done_counter = 0
67+
l.print_time = 0.0
68+
elseif curr_time-l.print_time > 0.1
69+
tot = l.total
70+
l.print_time = curr_time
71+
else
72+
return
73+
end
74+
id=:total
75+
message="Total"
76+
kwargs=merge(values(kwargs), (progress=tot,))
77+
finally
78+
unlock(l.lock)
79+
end
80+
else
81+
return
82+
end
83+
end
84+
Logging.handle_message(l.logger, level, message, _module, group, id, file, line; kwargs...)
85+
end
86+
Logging.shouldlog(l::AggregateLogger, args...) = Logging.shouldlog(l.logger, args...)
87+
Logging.min_enabled_level(l::AggregateLogger) = Logging.min_enabled_level(l.logger)
88+
Logging.catch_exceptions(l::AggregateLogger) = Logging.catch_exceptions(l.logger)
89+
3290
function __solve(prob::AbstractEnsembleProblem,
3391
alg::Union{AbstractDEAlgorithm, Nothing};
3492
kwargs...)
@@ -59,51 +117,72 @@ end
59117
function __solve(prob::AbstractEnsembleProblem,
60118
alg::A,
61119
ensemblealg::BasicEnsembleAlgorithm;
62-
trajectories, batch_size = trajectories,
120+
trajectories, batch_size = trajectories, progress_aggregate=true,
63121
pmap_batch_size = batch_size ÷ 100 > 0 ? batch_size ÷ 100 : 1, kwargs...) where {A}
64-
num_batches = trajectories ÷ batch_size
65-
num_batches < 1 &&
66-
error("trajectories ÷ batch_size cannot be less than 1, got $num_batches")
67-
num_batches * batch_size != trajectories && (num_batches += 1)
68-
69-
if num_batches == 1 && prob.reduction === DEFAULT_REDUCTION
70-
elapsed_time = @elapsed u = solve_batch(prob, alg, ensemblealg, 1:trajectories,
71-
pmap_batch_size; kwargs...)
72-
_u = tighten_container_eltype(u)
73-
stats = merge_stats(_u)
74-
return EnsembleSolution(_u, elapsed_time, true, stats)
75-
end
122+
logger = progress_aggregate ? AggregateLogger(Logging.current_logger()) : Logging.current_logger()
123+
124+
Logging.with_logger(logger) do
125+
num_batches = trajectories ÷ batch_size
126+
num_batches < 1 &&
127+
error("trajectories ÷ batch_size cannot be less than 1, got $num_batches")
128+
num_batches * batch_size != trajectories && (num_batches += 1)
76129

77-
converged::Bool = false
78-
elapsed_time = @elapsed begin
79-
i = 1
80-
II = (batch_size * (i - 1) + 1):(batch_size * i)
130+
if get(kwargs, :progress, false)
131+
name = get(kwargs, :progress_name, "Ensemble")
132+
for i in 1:trajectories
133+
msg = "$name #$i"
134+
Logging.@logmsg(Logging.LogLevel(-1), msg, _id=Symbol("SciMLBase_$i"), progress=0)
135+
end
136+
end
137+
81138

82-
batch_data = solve_batch(prob, alg, ensemblealg, II, pmap_batch_size; kwargs...)
139+
if num_batches == 1 && prob.reduction === DEFAULT_REDUCTION
140+
elapsed_time = @elapsed u = solve_batch(prob, alg, ensemblealg, 1:trajectories,
141+
pmap_batch_size; kwargs...)
142+
_u = tighten_container_eltype(u)
143+
stats = merge_stats(_u)
144+
return EnsembleSolution(_u, elapsed_time, true, stats)
145+
end
146+
147+
converged::Bool = false
148+
elapsed_time = @elapsed begin
149+
i = 1
150+
II = (batch_size * (i - 1) + 1):(batch_size * i)
83151

84-
u = prob.u_init === nothing ? similar(batch_data, 0) : prob.u_init
85-
u, converged = prob.reduction(u, batch_data, II)
86-
for i in 2:num_batches
87-
converged && break
88-
if i == num_batches
89-
II = (batch_size * (i - 1) + 1):trajectories
90-
else
91-
II = (batch_size * (i - 1) + 1):(batch_size * i)
92-
end
93152
batch_data = solve_batch(prob, alg, ensemblealg, II, pmap_batch_size; kwargs...)
153+
154+
u = prob.u_init === nothing ? similar(batch_data, 0) : prob.u_init
94155
u, converged = prob.reduction(u, batch_data, II)
156+
for i in 2:num_batches
157+
converged && break
158+
if i == num_batches
159+
II = (batch_size * (i - 1) + 1):trajectories
160+
else
161+
II = (batch_size * (i - 1) + 1):(batch_size * i)
162+
end
163+
batch_data = solve_batch(prob, alg, ensemblealg, II, pmap_batch_size; kwargs...)
164+
u, converged = prob.reduction(u, batch_data, II)
165+
end
95166
end
167+
_u = tighten_container_eltype(u)
168+
stats = merge_stats(_u)
169+
return EnsembleSolution(_u, elapsed_time, converged, stats)
96170
end
97-
_u = tighten_container_eltype(u)
98-
stats = merge_stats(_u)
99-
return EnsembleSolution(_u, elapsed_time, converged, stats)
100171
end
101172

102173
function batch_func(i, prob, alg; kwargs...)
103174
iter = 1
104175
_prob = prob.safetycopy ? deepcopy(prob.prob) : prob.prob
105176
new_prob = prob.prob_func(_prob, i, iter)
106177
rerun = true
178+
179+
progress = get(kwargs, :progress, false)
180+
if progress
181+
name = get(kwargs, :progress_name, "Ensemble")
182+
progress_name = "$name #$i"
183+
progress_id = Symbol("SciMLBase_$i")
184+
kwargs = (kwargs..., progress_name=progress_name, progress_id=progress_id)
185+
end
107186
x = prob.output_func(solve(new_prob, alg; kwargs...), i)
108187
if !(x isa Tuple)
109188
rerun_warn()

0 commit comments

Comments
 (0)