diff --git a/Project.toml b/Project.toml index 77de870..213f92a 100644 --- a/Project.toml +++ b/Project.toml @@ -18,6 +18,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed" Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999" +UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" [compat] BangBang = "0.3.19, 0.4" @@ -29,6 +30,7 @@ ProgressLogging = "0.1" StatsBase = "0.32, 0.33, 0.34" TerminalLoggers = "0.1" Transducers = "0.4.30" +UUIDs = "1.11.0" julia = "1.6" [extras] diff --git a/src/logging.jl b/src/logging.jl index 04c4118..7abeec7 100644 --- a/src/logging.jl +++ b/src/logging.jl @@ -4,7 +4,7 @@ macro ifwithprogresslogger(progress, exprs...) return esc( quote - if $progress + if $progress == true if $hasprogresslevel($Logging.current_logger()) $ProgressLogging.@withprogress $(exprs...) else diff --git a/src/sample.jl b/src/sample.jl index 01a2006..10da7e6 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -1,3 +1,5 @@ +using UUIDs: uuid4 + # Default implementations of `sample`. const PROGRESS = Ref(true) @@ -144,11 +146,23 @@ function mcmcsample( @ifwithprogresslogger progress name = progressname begin # Determine threshold values for progress logging # (one update per 0.5% of progress) - if progress + if !(progress == false) threshold = Ntotal รท 200 next_update = threshold end + # Ugly hacky code to reset the start timer if called from a multi-chain + # sampling process + # TODO: How to make this better? + if progress isa ProgressLogging.Progress + try + bartrees = Logging.current_logger().loggers[1].logger.bartrees + bar = TerminalLoggers.findbar(bartrees, progress.id).data + bar.tfirst = time() + catch + end + end + # Obtain the initial sample and state. sample, state = if num_warmup > 0 if initial_state === nothing @@ -166,8 +180,13 @@ function mcmcsample( # Update the progress bar. itotal = 1 - if progress && itotal >= next_update - ProgressLogging.@logprogress itotal / Ntotal + if !(progress == false) && itotal >= next_update + if progress == true + ProgressLogging.@logprogress itotal / Ntotal + else + ProgressLogging.@logprogress name = progressname itotal / Ntotal _id = + progress.id + end next_update = itotal + threshold end @@ -181,8 +200,13 @@ function mcmcsample( end # Update the progress bar. - if progress && (itotal += 1) >= next_update - ProgressLogging.@logprogress itotal / Ntotal + if !(progress == false) && (itotal += 1) >= next_update + if progress == true + ProgressLogging.@logprogress itotal / Ntotal + else + ProgressLogging.@logprogress name = progressname itotal / Ntotal _id = + progress.id + end next_update = itotal + threshold end end @@ -206,8 +230,13 @@ function mcmcsample( end # Update progress bar. - if progress && (itotal += 1) >= next_update - ProgressLogging.@logprogress itotal / Ntotal + if !(progress == false) && (itotal += 1) >= next_update + if progress == true + ProgressLogging.@logprogress itotal / Ntotal + else + ProgressLogging.@logprogress name = progressname itotal / Ntotal _id = + progress.id + end next_update = itotal + threshold end end @@ -227,8 +256,13 @@ function mcmcsample( samples = save!!(samples, sample, i, model, sampler, N; kwargs...) # Update the progress bar. - if progress && (itotal += 1) >= next_update - ProgressLogging.@logprogress itotal / Ntotal + if !(progress == false) && (itotal += 1) >= next_update + if progress == true + ProgressLogging.@logprogress itotal / Ntotal + else + ProgressLogging.@logprogress name = progressname itotal / Ntotal _id = + progress.id + end next_update = itotal + threshold end end @@ -416,6 +450,15 @@ function mcmcsample( if progress channel = Channel{Bool}(length(interval)) end + # Generate nchains independent UUIDs for each progress bar + uuids = [uuid4() for _ in 1:nchains] + # Start the progress bars (but in reverse order, because + # ProgressLogging prints from the bottom up, and we want chain 1 to + # show up at the top) + for (i, uuid) in enumerate(reverse(uuids)) + ProgressLogging.@logprogress name = "Chain $(nchains-i+1)/$nchains" nothing _id = + uuid + end Distributed.@sync begin if progress @@ -456,12 +499,21 @@ function mcmcsample( Random.seed!(_rng, seeds[chainidx]) # Sample a chain and save it to the vector. + child_progressname = "Chain $chainidx/$nchains" + child_progress = if progress == false + false + else + ProgressLogging.Progress( + uuids[chainidx]; name=child_progressname + ) + end chains[chainidx] = StatsBase.sample( _rng, _model, _sampler, N; - progress=false, + progress=child_progress, + progressname=child_progressname, initial_params=if initial_params === nothing nothing else @@ -475,13 +527,20 @@ function mcmcsample( kwargs..., ) + ProgressLogging.@logprogress name = child_progressname "done" _id = uuids[chainidx] + # Update the progress bar. progress && put!(channel, true) end end finally - # Stop updating the progress bar. + # Stop updating the progress bars (either if sampling is done, or if + # an error occurs). progress && put!(channel, false) + for (i, uuid) in enumerate(reverse(uuids)) + ProgressLogging.@logprogress name = "Chain $(nchains-i+1)/$nchains" "done" _id = + uuid + end end end end