diff --git a/Project.toml b/Project.toml index 77de8700..213f92a1 100644 --- a/Project.toml +++ b/Project.toml @@ -18,6 +18,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed" Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999" +UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" [compat] BangBang = "0.3.19, 0.4" @@ -29,6 +30,7 @@ ProgressLogging = "0.1" StatsBase = "0.32, 0.33, 0.34" TerminalLoggers = "0.1" Transducers = "0.4.30" +UUIDs = "1.11.0" julia = "1.6" [extras] diff --git a/src/logging.jl b/src/logging.jl index 04c41187..7abeec76 100644 --- a/src/logging.jl +++ b/src/logging.jl @@ -4,7 +4,7 @@ macro ifwithprogresslogger(progress, exprs...) return esc( quote - if $progress + if $progress == true if $hasprogresslevel($Logging.current_logger()) $ProgressLogging.@withprogress $(exprs...) else diff --git a/src/sample.jl b/src/sample.jl index 01a2006a..259f9df4 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -1,3 +1,5 @@ +using UUIDs: uuid4 + # Default implementations of `sample`. const PROGRESS = Ref(true) @@ -119,6 +121,7 @@ function mcmcsample( thinning=1, chain_type::Type=Any, initial_state=nothing, + _progress_channel=nothing, kwargs..., ) # Check the number of requested samples. @@ -144,9 +147,19 @@ function mcmcsample( @ifwithprogresslogger progress name = progressname begin # Determine threshold values for progress logging # (one update per 0.5% of progress) - if progress - threshold = Ntotal ÷ 200 - next_update = threshold + threshold = Ntotal ÷ 200 + next_update = threshold + + # Ugly hacky code to reset the start timer if called from a multi-chain + # sampling process + # TODO: How to make this better? + if progress isa ProgressLogging.Progress + try + bartrees = Logging.current_logger().loggers[1].logger.bartrees + bar = TerminalLoggers.findbar(bartrees, progress.id).data + bar.tfirst = time() + catch + end end # Obtain the initial sample and state. @@ -166,8 +179,13 @@ function mcmcsample( # Update the progress bar. itotal = 1 - if progress && itotal >= next_update - ProgressLogging.@logprogress itotal / Ntotal + if itotal >= next_update + if progress == true + ProgressLogging.@logprogress itotal / Ntotal + elseif progress isa ProgressLogging.Progress + ProgressLogging.@logprogress name = progressname itotal / Ntotal _id = + progress.id + end next_update = itotal + threshold end @@ -181,8 +199,14 @@ function mcmcsample( end # Update the progress bar. - if progress && (itotal += 1) >= next_update - ProgressLogging.@logprogress itotal / Ntotal + _progress_channel !== nothing && put!(_progress_channel, true) + if (itotal += 1) >= next_update + if progress == true + ProgressLogging.@logprogress itotal / Ntotal + elseif progress isa ProgressLogging.Progress + ProgressLogging.@logprogress name = progressname itotal / Ntotal _id = + progress.id + end next_update = itotal + threshold end end @@ -206,8 +230,13 @@ function mcmcsample( end # Update progress bar. - if progress && (itotal += 1) >= next_update - ProgressLogging.@logprogress itotal / Ntotal + if (itotal += 1) >= next_update + if progress == true + ProgressLogging.@logprogress itotal / Ntotal + elseif progress isa ProgressLogging.Progress + ProgressLogging.@logprogress name = progressname itotal / Ntotal _id = + progress.id + end next_update = itotal + threshold end end @@ -227,8 +256,14 @@ function mcmcsample( samples = save!!(samples, sample, i, model, sampler, N; kwargs...) # Update the progress bar. - if progress && (itotal += 1) >= next_update - ProgressLogging.@logprogress itotal / Ntotal + _progress_channel !== nothing && put!(_progress_channel, true) + if (itotal += 1) >= next_update + if progress == true + ProgressLogging.@logprogress itotal / Ntotal + elseif progress isa ProgressLogging.Progress + ProgressLogging.@logprogress name = progressname itotal / Ntotal _id = + progress.id + end next_update = itotal + threshold end end @@ -416,6 +451,15 @@ function mcmcsample( if progress channel = Channel{Bool}(length(interval)) end + # Generate nchains independent UUIDs for each progress bar + # uuids = [uuid4() for _ in 1:nchains] + # Start the progress bars (but in reverse order, because + # ProgressLogging prints from the bottom up, and we want chain 1 to + # show up at the top) + # for (i, uuid) in enumerate(reverse(uuids)) + # ProgressLogging.@logprogress name = "Chain $(nchains-i+1)/$nchains" nothing _id = + # uuid + # end Distributed.@sync begin if progress @@ -423,15 +467,16 @@ function mcmcsample( Distributed.@async begin # Determine threshold values for progress logging # (one update per 0.5% of progress) - threshold = nchains ÷ 200 - nextprogresschains = threshold + nprogupdates = nchains * N + threshold = nprogupdates ÷ 200 + counter = 0 + next_update = threshold - progresschains = 0 while take!(channel) - progresschains += 1 - if progresschains >= nextprogresschains - ProgressLogging.@logprogress progresschains / nchains - nextprogresschains = progresschains + threshold + counter += 1 + if counter >= next_update + ProgressLogging.@logprogress counter / nprogupdates + next_update = next_update + threshold end end end @@ -472,16 +517,24 @@ function mcmcsample( else initial_state[chainidx] end, + _progress_channel=channel, kwargs..., ) + # ProgressLogging.@logprogress name = child_progressname "done" _id = uuids[chainidx] + # # Update the progress bar. progress && put!(channel, true) end end finally - # Stop updating the progress bar. + # Stop updating the progress bars (either if sampling is done, or if + # an error occurs). progress && put!(channel, false) + # for (i, uuid) in enumerate(reverse(uuids)) + # ProgressLogging.@logprogress name = "Chain $(nchains-i+1)/$nchains" "done" _id = + # uuid + # end end end end