Skip to content

per-chain progress bars (part 2) #166

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed"
Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999"
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"

[compat]
BangBang = "0.3.19, 0.4"
Expand All @@ -29,6 +30,7 @@ ProgressLogging = "0.1"
StatsBase = "0.32, 0.33, 0.34"
TerminalLoggers = "0.1"
Transducers = "0.4.30"
UUIDs = "1.11.0"
julia = "1.6"

[extras]
Expand Down
2 changes: 1 addition & 1 deletion src/logging.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
macro ifwithprogresslogger(progress, exprs...)
return esc(
quote
if $progress
if $progress == true
if $hasprogresslevel($Logging.current_logger())
$ProgressLogging.@withprogress $(exprs...)
else
Expand Down
81 changes: 70 additions & 11 deletions src/sample.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using UUIDs: uuid4

# Default implementations of `sample`.
const PROGRESS = Ref(true)

Expand Down Expand Up @@ -144,11 +146,23 @@ function mcmcsample(
@ifwithprogresslogger progress name = progressname begin
# Determine threshold values for progress logging
# (one update per 0.5% of progress)
if progress
if !(progress == false)
threshold = Ntotal ÷ 200
next_update = threshold
end

# Ugly hacky code to reset the start timer if called from a multi-chain
# sampling process
# TODO: How to make this better?
if progress isa ProgressLogging.Progress
try
bartrees = Logging.current_logger().loggers[1].logger.bartrees
bar = TerminalLoggers.findbar(bartrees, progress.id).data
bar.tfirst = time()
catch
end
end

# Obtain the initial sample and state.
sample, state = if num_warmup > 0
if initial_state === nothing
Expand All @@ -166,8 +180,13 @@ function mcmcsample(

# Update the progress bar.
itotal = 1
if progress && itotal >= next_update
ProgressLogging.@logprogress itotal / Ntotal
if !(progress == false) && itotal >= next_update
if progress == true
ProgressLogging.@logprogress itotal / Ntotal
else
ProgressLogging.@logprogress name = progressname itotal / Ntotal _id =
progress.id
end
next_update = itotal + threshold
end

Expand All @@ -181,8 +200,13 @@ function mcmcsample(
end

# Update the progress bar.
if progress && (itotal += 1) >= next_update
ProgressLogging.@logprogress itotal / Ntotal
if !(progress == false) && (itotal += 1) >= next_update
if progress == true
ProgressLogging.@logprogress itotal / Ntotal
else
ProgressLogging.@logprogress name = progressname itotal / Ntotal _id =
progress.id
end
next_update = itotal + threshold
end
end
Expand All @@ -206,8 +230,13 @@ function mcmcsample(
end

# Update progress bar.
if progress && (itotal += 1) >= next_update
ProgressLogging.@logprogress itotal / Ntotal
if !(progress == false) && (itotal += 1) >= next_update
if progress == true
ProgressLogging.@logprogress itotal / Ntotal
else
ProgressLogging.@logprogress name = progressname itotal / Ntotal _id =
progress.id
end
next_update = itotal + threshold
end
end
Expand All @@ -227,8 +256,13 @@ function mcmcsample(
samples = save!!(samples, sample, i, model, sampler, N; kwargs...)

# Update the progress bar.
if progress && (itotal += 1) >= next_update
ProgressLogging.@logprogress itotal / Ntotal
if !(progress == false) && (itotal += 1) >= next_update
if progress == true
ProgressLogging.@logprogress itotal / Ntotal
else
ProgressLogging.@logprogress name = progressname itotal / Ntotal _id =
progress.id
end
next_update = itotal + threshold
end
end
Expand Down Expand Up @@ -416,6 +450,15 @@ function mcmcsample(
if progress
channel = Channel{Bool}(length(interval))
end
# Generate nchains independent UUIDs for each progress bar
uuids = [uuid4() for _ in 1:nchains]
# Start the progress bars (but in reverse order, because
# ProgressLogging prints from the bottom up, and we want chain 1 to
# show up at the top)
for (i, uuid) in enumerate(reverse(uuids))
ProgressLogging.@logprogress name = "Chain $(nchains-i+1)/$nchains" nothing _id =
uuid
end

Distributed.@sync begin
if progress
Expand Down Expand Up @@ -456,12 +499,21 @@ function mcmcsample(
Random.seed!(_rng, seeds[chainidx])

# Sample a chain and save it to the vector.
child_progressname = "Chain $chainidx/$nchains"
child_progress = if progress == false
false
else
ProgressLogging.Progress(
uuids[chainidx]; name=child_progressname
)
end
chains[chainidx] = StatsBase.sample(
_rng,
_model,
_sampler,
N;
progress=false,
progress=child_progress,
progressname=child_progressname,
initial_params=if initial_params === nothing
nothing
else
Expand All @@ -475,13 +527,20 @@ function mcmcsample(
kwargs...,
)

ProgressLogging.@logprogress name = child_progressname "done" _id = uuids[chainidx]

# Update the progress bar.
progress && put!(channel, true)
end
end
finally
# Stop updating the progress bar.
# Stop updating the progress bars (either if sampling is done, or if
# an error occurs).
progress && put!(channel, false)
for (i, uuid) in enumerate(reverse(uuids))
ProgressLogging.@logprogress name = "Chain $(nchains-i+1)/$nchains" "done" _id =
uuid
end
end
end
end
Expand Down
Loading