Skip to content

per-chain (well, not really) progress bars (part 3) #167

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed"
Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999"
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"

[compat]
BangBang = "0.3.19, 0.4"
Expand All @@ -29,6 +30,7 @@ ProgressLogging = "0.1"
StatsBase = "0.32, 0.33, 0.34"
TerminalLoggers = "0.1"
Transducers = "0.4.30"
UUIDs = "1.11.0"
julia = "1.6"

[extras]
Expand Down
2 changes: 1 addition & 1 deletion src/logging.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
macro ifwithprogresslogger(progress, exprs...)
return esc(
quote
if $progress
if $progress == true
if $hasprogresslevel($Logging.current_logger())
$ProgressLogging.@withprogress $(exprs...)
else
Expand Down
91 changes: 72 additions & 19 deletions src/sample.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using UUIDs: uuid4

# Default implementations of `sample`.
const PROGRESS = Ref(true)

Expand Down Expand Up @@ -119,6 +121,7 @@ function mcmcsample(
thinning=1,
chain_type::Type=Any,
initial_state=nothing,
_progress_channel=nothing,
kwargs...,
)
# Check the number of requested samples.
Expand All @@ -144,9 +147,19 @@ function mcmcsample(
@ifwithprogresslogger progress name = progressname begin
# Determine threshold values for progress logging
# (one update per 0.5% of progress)
if progress
threshold = Ntotal ÷ 200
next_update = threshold
threshold = Ntotal ÷ 200
next_update = threshold

# Ugly hacky code to reset the start timer if called from a multi-chain
# sampling process
# TODO: How to make this better?
if progress isa ProgressLogging.Progress
try
bartrees = Logging.current_logger().loggers[1].logger.bartrees
bar = TerminalLoggers.findbar(bartrees, progress.id).data
bar.tfirst = time()
catch
end
end

# Obtain the initial sample and state.
Expand All @@ -166,8 +179,13 @@ function mcmcsample(

# Update the progress bar.
itotal = 1
if progress && itotal >= next_update
ProgressLogging.@logprogress itotal / Ntotal
if itotal >= next_update
if progress == true
ProgressLogging.@logprogress itotal / Ntotal
elseif progress isa ProgressLogging.Progress
ProgressLogging.@logprogress name = progressname itotal / Ntotal _id =
progress.id
end
next_update = itotal + threshold
end

Expand All @@ -181,8 +199,14 @@ function mcmcsample(
end

# Update the progress bar.
if progress && (itotal += 1) >= next_update
ProgressLogging.@logprogress itotal / Ntotal
_progress_channel !== nothing && put!(_progress_channel, true)
if (itotal += 1) >= next_update
if progress == true
ProgressLogging.@logprogress itotal / Ntotal
elseif progress isa ProgressLogging.Progress
ProgressLogging.@logprogress name = progressname itotal / Ntotal _id =
progress.id
end
next_update = itotal + threshold
end
end
Expand All @@ -206,8 +230,13 @@ function mcmcsample(
end

# Update progress bar.
if progress && (itotal += 1) >= next_update
ProgressLogging.@logprogress itotal / Ntotal
if (itotal += 1) >= next_update
if progress == true
ProgressLogging.@logprogress itotal / Ntotal
elseif progress isa ProgressLogging.Progress
ProgressLogging.@logprogress name = progressname itotal / Ntotal _id =
progress.id
end
next_update = itotal + threshold
end
end
Expand All @@ -227,8 +256,14 @@ function mcmcsample(
samples = save!!(samples, sample, i, model, sampler, N; kwargs...)

# Update the progress bar.
if progress && (itotal += 1) >= next_update
ProgressLogging.@logprogress itotal / Ntotal
_progress_channel !== nothing && put!(_progress_channel, true)
if (itotal += 1) >= next_update
if progress == true
ProgressLogging.@logprogress itotal / Ntotal
elseif progress isa ProgressLogging.Progress
ProgressLogging.@logprogress name = progressname itotal / Ntotal _id =
progress.id
end
next_update = itotal + threshold
end
end
Expand Down Expand Up @@ -416,22 +451,32 @@ function mcmcsample(
if progress
channel = Channel{Bool}(length(interval))
end
# Generate nchains independent UUIDs for each progress bar
# uuids = [uuid4() for _ in 1:nchains]
# Start the progress bars (but in reverse order, because
# ProgressLogging prints from the bottom up, and we want chain 1 to
# show up at the top)
# for (i, uuid) in enumerate(reverse(uuids))
# ProgressLogging.@logprogress name = "Chain $(nchains-i+1)/$nchains" nothing _id =
# uuid
# end

Distributed.@sync begin
if progress
# Update the progress bar.
Distributed.@async begin
# Determine threshold values for progress logging
# (one update per 0.5% of progress)
threshold = nchains ÷ 200
nextprogresschains = threshold
nprogupdates = nchains * N
threshold = nprogupdates ÷ 200
counter = 0
next_update = threshold

progresschains = 0
while take!(channel)
progresschains += 1
if progresschains >= nextprogresschains
ProgressLogging.@logprogress progresschains / nchains
nextprogresschains = progresschains + threshold
counter += 1
if counter >= next_update
ProgressLogging.@logprogress counter / nprogupdates
next_update = next_update + threshold
end
end
end
Expand Down Expand Up @@ -472,16 +517,24 @@ function mcmcsample(
else
initial_state[chainidx]
end,
_progress_channel=channel,
kwargs...,
)

# ProgressLogging.@logprogress name = child_progressname "done" _id = uuids[chainidx]
#
# Update the progress bar.
progress && put!(channel, true)
end
end
finally
# Stop updating the progress bar.
# Stop updating the progress bars (either if sampling is done, or if
# an error occurs).
progress && put!(channel, false)
# for (i, uuid) in enumerate(reverse(uuids))
# ProgressLogging.@logprogress name = "Chain $(nchains-i+1)/$nchains" "done" _id =
# uuid
# end
end
end
end
Expand Down
Loading