From b3434ac69c75b7fb90796d229602c5a0495a5c6d Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 26 Jun 2025 23:34:09 +0100 Subject: [PATCH 01/25] [wip] fix parallel sampling --- src/logging.jl | 2 +- src/sample.jl | 43 ++++++++++++++++++++++++++++++++----------- 2 files changed, 33 insertions(+), 12 deletions(-) diff --git a/src/logging.jl b/src/logging.jl index 04c41187..7abeec76 100644 --- a/src/logging.jl +++ b/src/logging.jl @@ -4,7 +4,7 @@ macro ifwithprogresslogger(progress, exprs...) return esc( quote - if $progress + if $progress == true if $hasprogresslevel($Logging.current_logger()) $ProgressLogging.@withprogress $(exprs...) else diff --git a/src/sample.jl b/src/sample.jl index 01a2006a..74ac71f8 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -144,7 +144,7 @@ function mcmcsample( @ifwithprogresslogger progress name = progressname begin # Determine threshold values for progress logging # (one update per 0.5% of progress) - if progress + if (progress == true || progress === nothing) threshold = Ntotal ÷ 200 next_update = threshold end @@ -166,8 +166,12 @@ function mcmcsample( # Update the progress bar. itotal = 1 - if progress && itotal >= next_update - ProgressLogging.@logprogress itotal / Ntotal + if !(progress == false) && itotal >= next_update + if progress == true + ProgressLogging.@logprogress itotal / Ntotal + else + ProgressLogging.@logprogress itotal / Ntotal _id = "hello" + end next_update = itotal + threshold end @@ -181,8 +185,12 @@ function mcmcsample( end # Update the progress bar. - if progress && (itotal += 1) >= next_update - ProgressLogging.@logprogress itotal / Ntotal + if !(progress == false) && (itotal += 1) >= next_update + if progress == true + ProgressLogging.@logprogress itotal / Ntotal + else + ProgressLogging.@logprogress itotal / Ntotal _id = "hello" + end next_update = itotal + threshold end end @@ -206,8 +214,12 @@ function mcmcsample( end # Update progress bar. - if progress && (itotal += 1) >= next_update - ProgressLogging.@logprogress itotal / Ntotal + if !(progress == false) && (itotal += 1) >= next_update + if progress == true + ProgressLogging.@logprogress itotal / Ntotal + else + ProgressLogging.@logprogress itotal / Ntotal _id = "hello" + end next_update = itotal + threshold end end @@ -227,8 +239,12 @@ function mcmcsample( samples = save!!(samples, sample, i, model, sampler, N; kwargs...) # Update the progress bar. - if progress && (itotal += 1) >= next_update - ProgressLogging.@logprogress itotal / Ntotal + if !(progress == false) && (itotal += 1) >= next_update + if progress == true + ProgressLogging.@logprogress itotal / Ntotal + else + ProgressLogging.@logprogress itotal / Ntotal _id = "hello" + end next_update = itotal + threshold end end @@ -456,12 +472,17 @@ function mcmcsample( Random.seed!(_rng, seeds[chainidx]) # Sample a chain and save it to the vector. - chains[chainidx] = StatsBase.sample( + child_progress = if progress == false + false + else + nothing + end + @ifwithprogresslogger progress chains[chainidx] = StatsBase.sample( _rng, _model, _sampler, N; - progress=false, + progress=child_progress, initial_params=if initial_params === nothing nothing else From 15250c7e6f88c3d3a67c8d4dd85671ebd460b349 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 27 Jun 2025 12:40:55 +0100 Subject: [PATCH 02/25] Parallel sampling with ProgressLogging --- Project.toml | 2 ++ src/sample.jl | 49 ++++++++++++++++++++++++++++++++++++++++++------- 2 files changed, 44 insertions(+), 7 deletions(-) diff --git a/Project.toml b/Project.toml index eef2aaf5..af292342 100644 --- a/Project.toml +++ b/Project.toml @@ -18,6 +18,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed" Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999" +UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" [compat] BangBang = "0.3.19, 0.4" @@ -29,6 +30,7 @@ ProgressLogging = "0.1" StatsBase = "0.32, 0.33, 0.34" TerminalLoggers = "0.1" Transducers = "0.4.30" +UUIDs = "1.11.0" julia = "1.6" [extras] diff --git a/src/sample.jl b/src/sample.jl index 74ac71f8..d90a470a 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -1,3 +1,5 @@ +using UUIDs: uuid4 + # Default implementations of `sample`. const PROGRESS = Ref(true) @@ -144,11 +146,22 @@ function mcmcsample( @ifwithprogresslogger progress name = progressname begin # Determine threshold values for progress logging # (one update per 0.5% of progress) - if (progress == true || progress === nothing) + if !(progress == false) threshold = Ntotal ÷ 200 next_update = threshold end + # Ugly hacky code to reset the start timer if called from a multi-chain + # sampling process + if progress isa ProgressLogging.Progress + try + bartrees = Logging.current_logger().loggers[1].logger.bartrees + bar = TerminalLoggers.findbar(bartrees, progress.id).data + bar.tfirst = time() + catch + end + end + # Obtain the initial sample and state. sample, state = if num_warmup > 0 if initial_state === nothing @@ -170,7 +183,8 @@ function mcmcsample( if progress == true ProgressLogging.@logprogress itotal / Ntotal else - ProgressLogging.@logprogress itotal / Ntotal _id = "hello" + ProgressLogging.@logprogress name = progressname itotal / Ntotal _id = + progress.id end next_update = itotal + threshold end @@ -189,7 +203,8 @@ function mcmcsample( if progress == true ProgressLogging.@logprogress itotal / Ntotal else - ProgressLogging.@logprogress itotal / Ntotal _id = "hello" + ProgressLogging.@logprogress name = progressname itotal / Ntotal _id = + progress.id end next_update = itotal + threshold end @@ -218,7 +233,8 @@ function mcmcsample( if progress == true ProgressLogging.@logprogress itotal / Ntotal else - ProgressLogging.@logprogress itotal / Ntotal _id = "hello" + ProgressLogging.@logprogress name = progressname itotal / Ntotal _id = + progress.id end next_update = itotal + threshold end @@ -243,7 +259,8 @@ function mcmcsample( if progress == true ProgressLogging.@logprogress itotal / Ntotal else - ProgressLogging.@logprogress itotal / Ntotal _id = "hello" + ProgressLogging.@logprogress name = progressname itotal / Ntotal _id = + progress.id end next_update = itotal + threshold end @@ -432,6 +449,18 @@ function mcmcsample( if progress channel = Channel{Bool}(length(interval)) end + # Generate nchains independent UUIDs for each progress bar + uuids = [uuid4() for _ in 1:nchains] + # Start the progress bars (but in reverse order, because + # ProgressLogging prints from the bottom up, and we want chain 1 to + # show up at the top) + # TODO: This has an unintended effect that the 'time' field in the + # progress bar shows the total time since the beginning of sampling, + # even if the specific chain doesn't start sampling until later on. + for (i, uuid) in enumerate(reverse(uuids)) + ProgressLogging.@logprogress name = "Chain $(nchains-i+1)/$nchains" nothing _id = + uuid + end Distributed.@sync begin if progress @@ -472,17 +501,21 @@ function mcmcsample( Random.seed!(_rng, seeds[chainidx]) # Sample a chain and save it to the vector. + child_progressname = "Chain $chainidx/$nchains" child_progress = if progress == false false else - nothing + ProgressLogging.Progress( + uuids[chainidx]; name=child_progressname + ) end - @ifwithprogresslogger progress chains[chainidx] = StatsBase.sample( + chains[chainidx] = StatsBase.sample( _rng, _model, _sampler, N; progress=child_progress, + progressname=child_progressname, initial_params=if initial_params === nothing nothing else @@ -496,6 +529,8 @@ function mcmcsample( kwargs..., ) + ProgressLogging.@logprogress name = child_progressname "done" _id = uuids[chainidx] + # Update the progress bar. progress && put!(channel, true) end From 367718be3a87f866e7c06cdd8e3b3a51ac807584 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 27 Jun 2025 12:46:22 +0100 Subject: [PATCH 03/25] destroy per-chain progress bars if an error occurs --- src/sample.jl | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/sample.jl b/src/sample.jl index d90a470a..719ea7de 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -454,9 +454,6 @@ function mcmcsample( # Start the progress bars (but in reverse order, because # ProgressLogging prints from the bottom up, and we want chain 1 to # show up at the top) - # TODO: This has an unintended effect that the 'time' field in the - # progress bar shows the total time since the beginning of sampling, - # even if the specific chain doesn't start sampling until later on. for (i, uuid) in enumerate(reverse(uuids)) ProgressLogging.@logprogress name = "Chain $(nchains-i+1)/$nchains" nothing _id = uuid @@ -536,8 +533,13 @@ function mcmcsample( end end finally - # Stop updating the progress bar. + # Stop updating the progress bars (either if sampling is done, or if + # an error occurs). progress && put!(channel, false) + for (i, uuid) in enumerate(reverse(uuids)) + ProgressLogging.@logprogress name = "Chain $(nchains-i+1)/$nchains" "done" _id = + uuid + end end end end From 60134f8c56701643f06fbbce59c3cea33ccd9e84 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 27 Jun 2025 12:46:54 +0100 Subject: [PATCH 04/25] add a todo --- src/sample.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/sample.jl b/src/sample.jl index 719ea7de..10da7e6f 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -153,6 +153,7 @@ function mcmcsample( # Ugly hacky code to reset the start timer if called from a multi-chain # sampling process + # TODO: How to make this better? if progress isa ProgressLogging.Progress try bartrees = Logging.current_logger().loggers[1].logger.bartrees From e0ae51351e26d0e16030b8bfa9de8e44d78efc00 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 30 Jun 2025 17:58:56 +0100 Subject: [PATCH 05/25] Fix implementation --- src/AbstractMCMC.jl | 1 + src/logging.jl | 40 ++++++++- src/sample.jl | 198 +++++++++++++++++++++++++------------------- 3 files changed, 151 insertions(+), 88 deletions(-) diff --git a/src/AbstractMCMC.jl b/src/AbstractMCMC.jl index 8c1d5610..0e046927 100644 --- a/src/AbstractMCMC.jl +++ b/src/AbstractMCMC.jl @@ -13,6 +13,7 @@ using FillArrays: FillArrays using Distributed: Distributed using Logging: Logging using Random: Random +using UUIDs: UUIDs # Reexport sample using StatsBase: sample diff --git a/src/logging.jl b/src/logging.jl index 7abeec76..8de1b735 100644 --- a/src/logging.jl +++ b/src/logging.jl @@ -1,10 +1,12 @@ # avoid creating a progress bar with @withprogress if progress logging is disabled # and add a custom progress logger if the current logger does not seem to be able to handle # progress logs -macro ifwithprogresslogger(progress, exprs...) +macro single_ifwithprogresslogger(progress, exprs...) return esc( quote if $progress == true + # If progress == true, then we want to create a new logger. Note that + # progress might not be a Bool. if $hasprogresslevel($Logging.current_logger()) $ProgressLogging.@withprogress $(exprs...) else @@ -13,12 +15,48 @@ macro ifwithprogresslogger(progress, exprs...) end end else + # otherwise, progress isa UUID, or a channel, or false, in + # which case we don't want to create a new logger. $(exprs[end]) end end, ) end +# TODO(penelopeysm): figure out how to not have so much code duplication +macro multi_ifwithprogresslogger(progress, exprs...) + return esc( + quote + if $progress != :none + if $hasprogresslevel($Logging.current_logger()) + $ProgressLogging.@withprogress $(exprs...) + else + $with_progresslogger($Base.@__MODULE__, $Logging.current_logger()) do + $ProgressLogging.@withprogress $(exprs...) + end + end + else + $(exprs[end]) + end + end, + ) +end + +macro log_progress_dispatch(progress, progressname, progress_frac) + return esc( + quote + if $progress == true + $ProgressLogging.@logprogress $progress_frac + elseif $progress isa $UUIDs.UUID + $ProgressLogging.@logprogress $progressname $progress_frac _id = $progress + else + # progress == false, or progress isa Channel, which is handled manually + nothing + end + end, + ) +end + # improved checks? function hasprogresslevel(logger) return Logging.min_enabled_level(logger) ≤ ProgressLogging.ProgressLevel diff --git a/src/sample.jl b/src/sample.jl index 10da7e6f..70d013f8 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -1,5 +1,3 @@ -using UUIDs: uuid4 - # Default implementations of `sample`. const PROGRESS = Ref(true) @@ -113,7 +111,7 @@ function mcmcsample( model::AbstractModel, sampler::AbstractSampler, N::Integer; - progress=PROGRESS[], + progress::Union{Bool,UUIDs.UUID,Channel{Bool}}=PROGRESS[], progressname="Sampling", callback=nothing, num_warmup::Int=0, @@ -143,21 +141,21 @@ function mcmcsample( start = time() local state - @ifwithprogresslogger progress name = progressname begin + @single_ifwithprogresslogger progress name = progressname begin # Determine threshold values for progress logging # (one update per 0.5% of progress) - if !(progress == false) - threshold = Ntotal ÷ 200 - next_update = threshold - end - - # Ugly hacky code to reset the start timer if called from a multi-chain - # sampling process - # TODO: How to make this better? - if progress isa ProgressLogging.Progress + threshold = Ntotal ÷ 200 + next_update = threshold + + # Slightly hacky code to reset the start timer if called from a + # multi-chain sampling process. We need this because the progress bar + # is constructed in the multi-chain method, i.e. if we don't do this + # the progress bar shows the time elapsed since _all_ sampling began, + # not since the current chain started. + if progress isa UUIDs.UUID try bartrees = Logging.current_logger().loggers[1].logger.bartrees - bar = TerminalLoggers.findbar(bartrees, progress.id).data + bar = TerminalLoggers.findbar(bartrees, progress).data bar.tfirst = time() catch end @@ -178,17 +176,13 @@ function mcmcsample( end end - # Update the progress bar. + # Start the progress bar. itotal = 1 - if !(progress == false) && itotal >= next_update - if progress == true - ProgressLogging.@logprogress itotal / Ntotal - else - ProgressLogging.@logprogress name = progressname itotal / Ntotal _id = - progress.id - end + if itotal >= next_update + @log_progress_dispatch progress progressname itotal / Ntotal next_update = itotal + threshold end + progress isa Channel{Bool} && put!(progress, true) # Discard initial samples. for j in 1:discard_initial @@ -200,13 +194,9 @@ function mcmcsample( end # Update the progress bar. - if !(progress == false) && (itotal += 1) >= next_update - if progress == true - ProgressLogging.@logprogress itotal / Ntotal - else - ProgressLogging.@logprogress name = progressname itotal / Ntotal _id = - progress.id - end + itotal += 1 + if itotal >= next_update + @log_progress_dispatch progress progressname itotal / Ntotal next_update = itotal + threshold end end @@ -230,13 +220,9 @@ function mcmcsample( end # Update progress bar. - if !(progress == false) && (itotal += 1) >= next_update - if progress == true - ProgressLogging.@logprogress itotal / Ntotal - else - ProgressLogging.@logprogress name = progressname itotal / Ntotal _id = - progress.id - end + itotal += 1 + if itotal >= next_update + @log_progress_dispatch progress progressname itotal / Ntotal next_update = itotal + threshold end end @@ -256,15 +242,12 @@ function mcmcsample( samples = save!!(samples, sample, i, model, sampler, N; kwargs...) # Update the progress bar. - if !(progress == false) && (itotal += 1) >= next_update - if progress == true - ProgressLogging.@logprogress itotal / Ntotal - else - ProgressLogging.@logprogress name = progressname itotal / Ntotal _id = - progress.id - end + itotal += 1 + if itotal >= next_update + @log_progress_dispatch progress progressname itotal / Ntotal next_update = itotal + threshold end + progress isa Channel{Bool} && put!(progress, true) end end @@ -316,7 +299,7 @@ function mcmcsample( start = time() local state - @ifwithprogresslogger progress name = progressname begin + @single_ifwithprogresslogger progress name = progressname begin # Obtain the initial sample and state. sample, state = if num_warmup > 0 if initial_state === nothing @@ -423,6 +406,14 @@ function mcmcsample( @warn "Number of chains ($nchains) is greater than number of samples per chain ($N)" end + # Determine default progress bar style. + if progress == true + progress = nchains > 10 ? :overall : :perchain + elseif progress == false + progress = :none + end + # By this point, `progress` should be a Symbol, one of `:overall`, `:perchain`, or `:none`. + # Copy the random number generator, model, and sample for each thread nchunks = min(nchains, Threads.nthreads()) interval = 1:nchunks @@ -445,36 +436,44 @@ function mcmcsample( # Set up a chains vector. chains = Vector{Any}(undef, nchains) - @ifwithprogresslogger progress name = progressname begin - # Create a channel for progress logging. - if progress - channel = Channel{Bool}(length(interval)) - end - # Generate nchains independent UUIDs for each progress bar - uuids = [uuid4() for _ in 1:nchains] - # Start the progress bars (but in reverse order, because - # ProgressLogging prints from the bottom up, and we want chain 1 to - # show up at the top) - for (i, uuid) in enumerate(reverse(uuids)) - ProgressLogging.@logprogress name = "Chain $(nchains-i+1)/$nchains" nothing _id = - uuid + @multi_ifwithprogresslogger progress name = progressname begin + if progress == :perchain + # This is the 'overall' progress bar. We create a channel for each + # chain to report back to when it finishes sampling. + progress_channel = Channel{Bool}() + # These are the per-chain progress bars. We generate `nchains` + # independent UUIDs for each progress bar + uuids = [UUIDs.uuid4() for _ in 1:nchains] + progress_names = ["Chain $i/$nchains" for i in 1:nchains] + # Start the per-chain progress bars (but in reverse order, because + # ProgressLogging prints from the bottom up, and we want chain 1 to + # show up at the top) + for (progress_name, uuid) in reverse(collect(zip(progress_names, uuids))) + ProgressLogging.@logprogress name = progress_name nothing _id = uuid + end + elseif progress == :overall + # Just a single progress bar for the entire sampling, but instead + # of tracking each chain as it comes in, we track each sample as it + # comes in. This allows us to have more granular progress updates. + progress_channel = Channel{Bool}() end Distributed.@sync begin - if progress - # Update the progress bar. + if progress != :none + # This task updates the progress bar Distributed.@async begin # Determine threshold values for progress logging # (one update per 0.5% of progress) - threshold = nchains ÷ 200 - nextprogresschains = threshold - - progresschains = 0 - while take!(channel) - progresschains += 1 - if progresschains >= nextprogresschains - ProgressLogging.@logprogress progresschains / nchains - nextprogresschains = progresschains + threshold + Ntotal = progress == :overall ? nchains * N : nchains + threshold = Ntotal ÷ 200 + next_update = threshold + + itotal = 0 + while take!(progress_channel) + itotal += 1 + if itotal >= next_update + ProgressLogging.@logprogress itotal / Ntotal + next_update = itotal + threshold end end end @@ -498,15 +497,23 @@ function mcmcsample( # Seed the chunk-specific random number generator with the pre-made seed. Random.seed!(_rng, seeds[chainidx]) - # Sample a chain and save it to the vector. - child_progressname = "Chain $chainidx/$nchains" - child_progress = if progress == false - false - else - ProgressLogging.Progress( - uuids[chainidx]; name=child_progressname - ) + # Determine how to monitor progress for the child chains. + child_progress, child_progressname = if progress == :none + # No need to create a progress bar + false, "" + elseif progress == :overall + # No need to create a new progress bar, but we need to + # pass the channel to the child so that it can log when + # it has finished obtaining each sample. + progress_channel, "" + elseif progress == :perchain + # We need to specify both the ID of the progress bar for + # the child to update, and we also specify the name to use + # for the progress bar. + uuids[chainidx], progress_names[chainidx] end + + # Sample a chain and save it to the vector. chains[chainidx] = StatsBase.sample( _rng, _model, @@ -527,19 +534,36 @@ function mcmcsample( kwargs..., ) - ProgressLogging.@logprogress name = child_progressname "done" _id = uuids[chainidx] - - # Update the progress bar. - progress && put!(channel, true) + # Update the progress bars. + if progress == :perchain + # Tell the 'main' progress bar that this chain is done. + put!(progress_channel, true) + # Conclude the per-chain progress bar. + ProgressLogging.@logprogress progress_names[chainidx] "done" _id = uuids[chainidx] + end + # Note that if progress == :overall, we don't need to do anything + # because progress on that bar is triggered by + # samples being obtained rather than chains being + # completed. end end finally - # Stop updating the progress bars (either if sampling is done, or if - # an error occurs). - progress && put!(channel, false) - for (i, uuid) in enumerate(reverse(uuids)) - ProgressLogging.@logprogress name = "Chain $(nchains-i+1)/$nchains" "done" _id = - uuid + if progress == :perchain + # Stop updating the main progress bar (either if sampling + # is done, or if an error occurs). + put!(progress_channel, false) + # Additionally stop the per-chain progress bars (but in + # reverse order, because ProgressLogging prints from + # the bottom up, and we want chain 1 to show up at the + # top) + for (progress_name, uuid) in + reverse(collect(zip(progress_names, uuids))) + ProgressLogging.@logprogress progress_name "done" _id = uuid + end + elseif progress == :overall + # Stop updating the main progress bar (either if sampling + # is done, or if an error occurs). + put!(progress_channel, false) end end end @@ -589,7 +613,7 @@ function mcmcsample( pool = Distributed.CachingPool(Distributed.workers()) local chains - @ifwithprogresslogger progress name = progressname begin + @single_ifwithprogresslogger progress name = progressname begin # Create a channel for progress logging. if progress channel = Distributed.RemoteChannel(() -> Channel{Bool}(Distributed.nworkers())) From 6b514e45d844061774abcb86542ad74d01a96549 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 30 Jun 2025 17:59:22 +0100 Subject: [PATCH 06/25] Bump minor version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index af292342..e716627e 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" keywords = ["markov chain monte carlo", "probabilistic programming"] license = "MIT" desc = "A lightweight interface for common MCMC methods." -version = "5.6.3" +version = "5.7.0" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" From bbda3c84d19a636af1d8e6c0536461afcc691f85 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 30 Jun 2025 18:19:20 +0100 Subject: [PATCH 07/25] Add `setmaxchainsprogress!` --- src/sample.jl | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/src/sample.jl b/src/sample.jl index 70d013f8..f5ba18b5 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -1,5 +1,6 @@ # Default implementations of `sample`. const PROGRESS = Ref(true) +const MAX_CHAINS_PROGRESS = Ref(10) _pluralise(n; singular="", plural="s") = n == 1 ? singular : plural @@ -17,6 +18,25 @@ function setprogress!(progress::Bool; silent::Bool=false) return progress end +""" + setmaxchainsprogress!(max_chains::Int, silent::Bool=false) + +Set the maximum number of chains to display progress bars for when sampling +multiple chains at once (if progress logging is enabled). Above this limit, no +progress bars are displayed for individual chains; instead, a single progress +bar is displayed for the entire sampling process. +""" +function setmaxchainsprogress!(max_chains::Int, silent::Bool=false) + if max_chains < 0 + throw(ArgumentError("maximum number of chains must be non-negative")) + end + if !silent + @info "AbstractMCMC: maximum number of per-chain progress bars set to $max_chains" + end + MAX_CHAINS_PROGRESS[] = max_chains + return nothing +end + function StatsBase.sample( model_or_logdensity, sampler::AbstractSampler, N_or_isdone; kwargs... ) @@ -408,7 +428,7 @@ function mcmcsample( # Determine default progress bar style. if progress == true - progress = nchains > 10 ? :overall : :perchain + progress = nchains > MAX_CHAINS_PROGRESS[] ? :overall : :perchain elseif progress == false progress = :none end From a9e530628c585a67a69c0cd208637edb6c13d356 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 30 Jun 2025 18:24:41 +0100 Subject: [PATCH 08/25] Don't duplicate macro --- src/logging.jl | 26 ++++---------------------- src/sample.jl | 11 +++++++---- 2 files changed, 11 insertions(+), 26 deletions(-) diff --git a/src/logging.jl b/src/logging.jl index 8de1b735..5339a98a 100644 --- a/src/logging.jl +++ b/src/logging.jl @@ -1,10 +1,10 @@ # avoid creating a progress bar with @withprogress if progress logging is disabled # and add a custom progress logger if the current logger does not seem to be able to handle # progress logs -macro single_ifwithprogresslogger(progress, exprs...) +macro ifwithprogresslogger(cond, exprs...) return esc( quote - if $progress == true + if $cond # If progress == true, then we want to create a new logger. Note that # progress might not be a Bool. if $hasprogresslevel($Logging.current_logger()) @@ -23,25 +23,6 @@ macro single_ifwithprogresslogger(progress, exprs...) ) end -# TODO(penelopeysm): figure out how to not have so much code duplication -macro multi_ifwithprogresslogger(progress, exprs...) - return esc( - quote - if $progress != :none - if $hasprogresslevel($Logging.current_logger()) - $ProgressLogging.@withprogress $(exprs...) - else - $with_progresslogger($Base.@__MODULE__, $Logging.current_logger()) do - $ProgressLogging.@withprogress $(exprs...) - end - end - else - $(exprs[end]) - end - end, - ) -end - macro log_progress_dispatch(progress, progressname, progress_frac) return esc( quote @@ -50,7 +31,8 @@ macro log_progress_dispatch(progress, progressname, progress_frac) elseif $progress isa $UUIDs.UUID $ProgressLogging.@logprogress $progressname $progress_frac _id = $progress else - # progress == false, or progress isa Channel, which is handled manually + # progress == false, or progress isa Channel, both of which are + # handled manually nothing end end, diff --git a/src/sample.jl b/src/sample.jl index f5ba18b5..e9538f5c 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -161,7 +161,10 @@ function mcmcsample( start = time() local state - @single_ifwithprogresslogger progress name = progressname begin + # Only create a new progress bar if progress is explicitly equal to true, i.e. + # it's not a UUID (the progress bar already exists), a channel (there's no need + # for a new progress bar), or false (no progress bar). + @ifwithprogresslogger (progress == true) name = progressname begin # Determine threshold values for progress logging # (one update per 0.5% of progress) threshold = Ntotal ÷ 200 @@ -319,7 +322,7 @@ function mcmcsample( start = time() local state - @single_ifwithprogresslogger progress name = progressname begin + @ifwithprogresslogger (progress == true) name = progressname begin # Obtain the initial sample and state. sample, state = if num_warmup > 0 if initial_state === nothing @@ -456,7 +459,7 @@ function mcmcsample( # Set up a chains vector. chains = Vector{Any}(undef, nchains) - @multi_ifwithprogresslogger progress name = progressname begin + @ifwithprogresslogger (progress != :none) name = progressname begin if progress == :perchain # This is the 'overall' progress bar. We create a channel for each # chain to report back to when it finishes sampling. @@ -633,7 +636,7 @@ function mcmcsample( pool = Distributed.CachingPool(Distributed.workers()) local chains - @single_ifwithprogresslogger progress name = progressname begin + @ifwithprogresslogger (progress == true) name = progressname begin # Create a channel for progress logging. if progress channel = Distributed.RemoteChannel(() -> Channel{Bool}(Distributed.nworkers())) From a03692d381eb40aef0cd0a12a20f0422fba85006 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 30 Jun 2025 18:54:32 +0100 Subject: [PATCH 09/25] :overall works with MCMCDistributed now --- src/sample.jl | 133 ++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 106 insertions(+), 27 deletions(-) diff --git a/src/sample.jl b/src/sample.jl index e9538f5c..da4cd8da 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -131,7 +131,7 @@ function mcmcsample( model::AbstractModel, sampler::AbstractSampler, N::Integer; - progress::Union{Bool,UUIDs.UUID,Channel{Bool}}=PROGRESS[], + progress::Union{Bool,UUIDs.UUID,Channel{Bool},Distributed.RemoteChannel{Channel{Bool}}}=PROGRESS[], progressname="Sampling", callback=nothing, num_warmup::Int=0, @@ -205,7 +205,10 @@ function mcmcsample( @log_progress_dispatch progress progressname itotal / Ntotal next_update = itotal + threshold end - progress isa Channel{Bool} && put!(progress, true) + if progress isa Channel{Bool} || + progress isa Distributed.RemoteChannel{Channel{Bool}} + put!(progress, true) + end # Discard initial samples. for j in 1:discard_initial @@ -270,7 +273,10 @@ function mcmcsample( @log_progress_dispatch progress progressname itotal / Ntotal next_update = itotal + threshold end - progress isa Channel{Bool} && put!(progress, true) + if progress isa Channel{Bool} || + progress isa Distributed.RemoteChannel{Channel{Bool}} + put!(progress, true) + end end end @@ -413,7 +419,7 @@ function mcmcsample( ::MCMCThreads, N::Integer, nchains::Integer; - progress=PROGRESS[], + progress::Union{Bool,Symbol}=PROGRESS[], progressname="Sampling ($(min(nchains, Threads.nthreads())) thread$(_pluralise(min(nchains, Threads.nthreads()))))", initial_params=nothing, initial_state=nothing, @@ -604,7 +610,7 @@ function mcmcsample( ::MCMCDistributed, N::Integer, nchains::Integer; - progress=PROGRESS[], + progress::Union{Bool,Symbol}=PROGRESS[], progressname="Sampling ($(Distributed.nworkers()) process$(_pluralise(Distributed.nworkers(); plural="es")))", initial_params=nothing, initial_state=nothing, @@ -620,6 +626,14 @@ function mcmcsample( @warn "Number of chains ($nchains) is greater than number of samples per chain ($N)" end + # Determine default progress bar style. + if progress == true + progress = nchains > MAX_CHAINS_PROGRESS[] ? :overall : :perchain + elseif progress == false + progress = :none + end + # By this point, `progress` should be a Symbol, one of `:overall`, `:perchain`, or `:none`. + # Ensure that initial parameters and states are `nothing` or of the correct length check_initial_params(initial_params, nchains) check_initial_state(initial_state, nchains) @@ -636,27 +650,55 @@ function mcmcsample( pool = Distributed.CachingPool(Distributed.workers()) local chains - @ifwithprogresslogger (progress == true) name = progressname begin - # Create a channel for progress logging. - if progress - channel = Distributed.RemoteChannel(() -> Channel{Bool}(Distributed.nworkers())) + @ifwithprogresslogger (progress != :none) name = progressname begin + # Set up progress logging. + if progress == :perchain + # This is the 'overall' progress bar. We create a channel for each + # chain to report back to when it finishes sampling. + progress_channel = Distributed.RemoteChannel( + () -> Channel{Bool}(Distributed.nworkers()) + ) + # These are the per-chain progress bars. We generate `nchains` + # independent UUIDs for each progress bar + uuids = [UUIDs.uuid4() for _ in 1:nchains] + progress_names = ["Chain $i/$nchains" for i in 1:nchains] + # Start the per-chain progress bars (but in reverse order, because + # ProgressLogging prints from the bottom up, and we want chain 1 to + # show up at the top) + for (progress_name, uuid) in reverse(collect(zip(progress_names, uuids))) + ProgressLogging.@logprogress name = progress_name nothing _id = uuid + end + child_progresses = uuids + child_progressnames = progress_names + elseif progress == :overall + # Just a single progress bar for the entire sampling, but instead + # of tracking each chain as it comes in, we track each sample as it + # comes in. This allows us to have more granular progress updates. + chan = Channel{Bool}(Distributed.nworkers()) + progress_channel = Distributed.RemoteChannel(() -> chan) + child_progresses = [progress_channel for _ in 1:nchains] + child_progressnames = ["" for _ in 1:nchains] + elseif progress == :none + child_progresses = [false for _ in 1:nchains] + child_progressnames = ["" for _ in 1:nchains] end Distributed.@sync begin - if progress - # Update the progress bar. + if progress != :none + # This task updates the progress bar Distributed.@async begin # Determine threshold values for progress logging # (one update per 0.5% of progress) - threshold = nchains ÷ 200 - nextprogresschains = threshold - - progresschains = 0 - while take!(channel) - progresschains += 1 - if progresschains >= nextprogresschains - ProgressLogging.@logprogress progresschains / nchains - nextprogresschains = progresschains + threshold + Ntotal = progress == :overall ? nchains * N : nchains + threshold = Ntotal ÷ 200 + next_update = threshold + + itotal = 0 + while take!(progress_channel) + itotal += 1 + if itotal >= next_update + ProgressLogging.@logprogress itotal / Ntotal + next_update = itotal + threshold end end end @@ -664,7 +706,13 @@ function mcmcsample( Distributed.@async begin try - function sample_chain(seed, initial_params, initial_state) + function sample_chain( + seed, + initial_params, + initial_state, + child_progress, + child_progressname, + ) # Seed a new random number generator with the pre-made seed. Random.seed!(rng, seed) @@ -674,24 +722,55 @@ function mcmcsample( model, sampler, N; - progress=false, + progress=child_progress, + progressname=child_progressname, initial_params=initial_params, initial_state=initial_state, kwargs..., ) - # Update the progress bar. - progress && put!(channel, true) + # Update the progress bars. Note that the case of + # progress = :overall doesn't need to be handled here + # (for similar reasons to the MCMCThreads method + # above). + if progress == :perchain + # Tell the 'main' progress bar that this chain is done. + put!(progress_channel, true) + # Conclude the per-chain progress bar. + ProgressLogging.@logprogress child_progressname "done" _id = + child_progress + end # Return the new chain. return chain end chains = Distributed.pmap( - sample_chain, pool, seeds, _initial_params, _initial_state + sample_chain, + pool, + seeds, + _initial_params, + _initial_state, + child_progresses, + child_progressnames, ) finally - # Stop updating the progress bar. - progress && put!(channel, false) + if progress == :perchain + # Stop updating the main progress bar (either if sampling + # is done, or if an error occurs). + put!(progress_channel, false) + # Additionally stop the per-chain progress bars (but in + # reverse order, because ProgressLogging prints from + # the bottom up, and we want chain 1 to show up at the + # top) + for (progress_name, uuid) in + reverse(collect(zip(progress_names, uuids))) + ProgressLogging.@logprogress progress_name "done" _id = uuid + end + elseif progress == :overall + # Stop updating the main progress bar (either if sampling + # is done, or if an error occurs). + put!(progress_channel, false) + end end end end From 838db60ef5e2e77305993b0741793cde2901d65f Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 30 Jun 2025 23:15:40 +0100 Subject: [PATCH 10/25] Give up on :perchain for MCMCDistributed --- src/sample.jl | 63 ++++++++++++++++----------------------------------- 1 file changed, 19 insertions(+), 44 deletions(-) diff --git a/src/sample.jl b/src/sample.jl index da4cd8da..de8f56ab 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -441,7 +441,11 @@ function mcmcsample( elseif progress == false progress = :none end - # By this point, `progress` should be a Symbol, one of `:overall`, `:perchain`, or `:none`. + progress in [:overall, :perchain, :none] || throw( + ArgumentError( + "`progress` for MCMCThreads must be `:overall`, `:perchain`, `:none`, or a boolean", + ), + ) # Copy the random number generator, model, and sample for each thread nchunks = min(nchains, Threads.nthreads()) @@ -581,12 +585,8 @@ function mcmcsample( # Stop updating the main progress bar (either if sampling # is done, or if an error occurs). put!(progress_channel, false) - # Additionally stop the per-chain progress bars (but in - # reverse order, because ProgressLogging prints from - # the bottom up, and we want chain 1 to show up at the - # top) - for (progress_name, uuid) in - reverse(collect(zip(progress_names, uuids))) + # Additionally stop the per-chain progress bars + for (progress_name, uuid) in zip(progress_names, uuids) ProgressLogging.@logprogress progress_name "done" _id = uuid end elseif progress == :overall @@ -626,13 +626,18 @@ function mcmcsample( @warn "Number of chains ($nchains) is greater than number of samples per chain ($N)" end - # Determine default progress bar style. + # Determine default progress bar style. Note that for MCMCDistributed(), + # :perchain isn't implemented. if progress == true - progress = nchains > MAX_CHAINS_PROGRESS[] ? :overall : :perchain + progress = :overall elseif progress == false progress = :none end - # By this point, `progress` should be a Symbol, one of `:overall`, `:perchain`, or `:none`. + progress in [:overall, :none] || throw( + ArgumentError( + "`progress` for MCMCDistributed must be `:overall`, `:none`, or a boolean" + ), + ) # Ensure that initial parameters and states are `nothing` or of the correct length check_initial_params(initial_params, nchains) @@ -652,25 +657,7 @@ function mcmcsample( local chains @ifwithprogresslogger (progress != :none) name = progressname begin # Set up progress logging. - if progress == :perchain - # This is the 'overall' progress bar. We create a channel for each - # chain to report back to when it finishes sampling. - progress_channel = Distributed.RemoteChannel( - () -> Channel{Bool}(Distributed.nworkers()) - ) - # These are the per-chain progress bars. We generate `nchains` - # independent UUIDs for each progress bar - uuids = [UUIDs.uuid4() for _ in 1:nchains] - progress_names = ["Chain $i/$nchains" for i in 1:nchains] - # Start the per-chain progress bars (but in reverse order, because - # ProgressLogging prints from the bottom up, and we want chain 1 to - # show up at the top) - for (progress_name, uuid) in reverse(collect(zip(progress_names, uuids))) - ProgressLogging.@logprogress name = progress_name nothing _id = uuid - end - child_progresses = uuids - child_progressnames = progress_names - elseif progress == :overall + if progress == :overall # Just a single progress bar for the entire sampling, but instead # of tracking each chain as it comes in, we track each sample as it # comes in. This allows us to have more granular progress updates. @@ -684,12 +671,12 @@ function mcmcsample( end Distributed.@sync begin - if progress != :none + if progress == :overall # This task updates the progress bar Distributed.@async begin # Determine threshold values for progress logging # (one update per 0.5% of progress) - Ntotal = progress == :overall ? nchains * N : nchains + Ntotal = nchains * N threshold = Ntotal ÷ 200 next_update = threshold @@ -754,19 +741,7 @@ function mcmcsample( child_progressnames, ) finally - if progress == :perchain - # Stop updating the main progress bar (either if sampling - # is done, or if an error occurs). - put!(progress_channel, false) - # Additionally stop the per-chain progress bars (but in - # reverse order, because ProgressLogging prints from - # the bottom up, and we want chain 1 to show up at the - # top) - for (progress_name, uuid) in - reverse(collect(zip(progress_names, uuids))) - ProgressLogging.@logprogress progress_name "done" _id = uuid - end - elseif progress == :overall + if progress == :overall # Stop updating the main progress bar (either if sampling # is done, or if an error occurs). put!(progress_channel, false) From 6b59b21fe0d74d31b6956c38cc162858ec443111 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 30 Jun 2025 23:26:38 +0100 Subject: [PATCH 11/25] Fix comments --- src/logging.jl | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/logging.jl b/src/logging.jl index 5339a98a..edb83254 100644 --- a/src/logging.jl +++ b/src/logging.jl @@ -5,8 +5,7 @@ macro ifwithprogresslogger(cond, exprs...) return esc( quote if $cond - # If progress == true, then we want to create a new logger. Note that - # progress might not be a Bool. + # Create a new logger if $hasprogresslevel($Logging.current_logger()) $ProgressLogging.@withprogress $(exprs...) else @@ -15,8 +14,9 @@ macro ifwithprogresslogger(cond, exprs...) end end else - # otherwise, progress isa UUID, or a channel, or false, in - # which case we don't want to create a new logger. + # Don't create a new logger, either because progress logging + # was disabled, or because it's otherwise being manually + # managed. $(exprs[end]) end end, @@ -27,8 +27,10 @@ macro log_progress_dispatch(progress, progressname, progress_frac) return esc( quote if $progress == true + # Use global logger $ProgressLogging.@logprogress $progress_frac elseif $progress isa $UUIDs.UUID + # Use the logger with this specific UUID $ProgressLogging.@logprogress $progressname $progress_frac _id = $progress else # progress == false, or progress isa Channel, both of which are From b340ebc619ab3dea8eb477819537d7373f3add67 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 30 Jun 2025 23:26:46 +0100 Subject: [PATCH 12/25] Remove dead code --- src/logging.jl | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/src/logging.jl b/src/logging.jl index edb83254..75db332b 100644 --- a/src/logging.jl +++ b/src/logging.jl @@ -57,14 +57,3 @@ function with_progresslogger(f, _module, logger) return Logging.with_logger(f, LoggingExtras.TeeLogger(logger1, logger2)) end - -function progresslogger() - # detect if code is running under IJulia since TerminalLogger does not work with IJulia - # https://github.com/JuliaLang/IJulia.jl#detecting-that-code-is-running-under-ijulia - if (Sys.iswindows() && VERSION < v"1.5.3") || - (isdefined(Main, :IJulia) && Main.IJulia.inited) - return ConsoleProgressMonitor.ProgressLogger() - else - return TerminalLoggers.TerminalLogger() - end -end From 1195503fc700bf77e9f5bb8e471744961d8592b3 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 30 Jun 2025 23:29:57 +0100 Subject: [PATCH 13/25] Undelete some not-actually-dead code --- src/logging.jl | 10 ++++++++++ src/sample.jl | 4 ++-- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/src/logging.jl b/src/logging.jl index 75db332b..ed2fc0e4 100644 --- a/src/logging.jl +++ b/src/logging.jl @@ -57,3 +57,13 @@ function with_progresslogger(f, _module, logger) return Logging.with_logger(f, LoggingExtras.TeeLogger(logger1, logger2)) end + +function progresslogger() + # detect if code is running under IJulia since TerminalLogger does not work with IJulia + # https://github.com/JuliaLang/IJulia.jl#detecting-that-code-is-running-under-ijulia + if (isdefined(Main, :IJulia) && Main.IJulia.inited) + return ConsoleProgressMonitor.ProgressLogger() + else + return TerminalLoggers.TerminalLogger() + end +end diff --git a/src/sample.jl b/src/sample.jl index de8f56ab..ac4b0508 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -304,7 +304,7 @@ function mcmcsample( sampler::AbstractSampler, isdone; chain_type::Type=Any, - progress=PROGRESS[], + progress::Bool=PROGRESS[], progressname="Convergence sampling", callback=nothing, num_warmup=0, @@ -328,7 +328,7 @@ function mcmcsample( start = time() local state - @ifwithprogresslogger (progress == true) name = progressname begin + @ifwithprogresslogger progress name = progressname begin # Obtain the initial sample and state. sample, state = if num_warmup > 0 if initial_state === nothing From 594483fe8e00a57ac988135844d9ec9f614ac433 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 1 Jul 2025 00:03:59 +0100 Subject: [PATCH 14/25] Broaden UUIDs compat so that it works on older Julia versions --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index e716627e..66ada2dd 100644 --- a/Project.toml +++ b/Project.toml @@ -30,7 +30,7 @@ ProgressLogging = "0.1" StatsBase = "0.32, 0.33, 0.34" TerminalLoggers = "0.1" Transducers = "0.4.30" -UUIDs = "1.11.0" +UUIDs = "<0.0.1, 1" julia = "1.6" [extras] From 7def4b429c12fe08ba69336e38671854ad361ed9 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 1 Jul 2025 00:19:15 +0100 Subject: [PATCH 15/25] Explain progress logging in docs --- docs/src/api.md | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index da92fd37..daa75eab 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -68,7 +68,7 @@ AbstractMCMC.MCMCSerial ## Common keyword arguments Common keyword arguments for regular and parallel sampling are: -- `progress` (default: `AbstractMCMC.PROGRESS[]` which is `true` initially): toggles progress logging +- `progress` (default: `AbstractMCMC.PROGRESS[]` which is `true` initially): toggles progress logging. See the section on [Progress logging](#progress-logging) below for more details. - `chain_type` (default: `Any`): determines the type of the returned chain - `callback` (default: `nothing`): if `callback !== nothing`, then `callback(rng, model, sampler, sample, iteration)` is called after every sampling step, @@ -90,12 +90,31 @@ However, multiple packages such as [EllipticalSliceSampling.jl](https://github.c To ensure that sampling multiple chains "just works" when sampling of a single chain is implemented, [we decided to support `initial_params` in the default implementations of the ensemble methods](https://github.com/TuringLang/AbstractMCMC.jl/pull/94): - `initial_params` (default: `nothing`): if `initial_params isa AbstractArray`, then the `i`th element of `initial_params` is used as initial parameters of the `i`th chain. If one wants to use the same initial parameters `x` for every chain, one can specify e.g. `initial_params = FillArrays.Fill(x, N)`. -Progress logging can be enabled and disabled globally with `AbstractMCMC.setprogress!(progress)`. +## Progress logging + +The default value for the `progress` keyword argument is `AbstractMCMC.PROGRESS[]`, which is always set to `true` unless modified with `AbstractMCMC.setprogress!`. +For example, `setprogress!(false)` will disable all progress logging. ```@docs AbstractMCMC.setprogress! ``` +For single-chain sampling (i.e., `sample([rng,] model, sampler, N)`), as well as multiple-chain sampling with `MCMCSerial`, the `progress` keyword argument should be a `Bool`. + +For multiple-chain sampling using `MCMCThreads`, there are several, more detailed, options: + +- `:perchain`: create one progress bar per chain being sampled +- `:overall`: create one progress bar for the overall sampling process, which tracks the percentage of samples that have been sampled across all chains +- `:none`: do not create any progress bar +- `true` (the default): use `perchain` for 10 or fewer chains, and `overall` for more than 10 chains +- `false`: same as `none`, i.e. no progress bar + +The threshold of 10 chains can be changed using `AbstractMCMC.setmaxchainsprogress!(N)`, which will cause `MCMCThreads` to use `:perchain` for `N` or fewer chains, and `:overall` for more than `N` chains. +Thus, for example, if you _always_ want to use `:overall`, you can call `AbstractMCMC.setmaxchainsprogress!(0)`. + +Multiple-chain sampling using `MCMCDistributed` behaves the same as `MCMCThreads`, except that `:perchain` is not (yet?) implemented. +So, `true` always corresponds to `:overall`, and `false` corresponds to `:none`. + ## Chains The `chain_type` keyword argument allows to set the type of the returned chain. A common From 022678e61052ad151b9f3f63f07be68555726def Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 1 Jul 2025 01:08:53 +0100 Subject: [PATCH 16/25] Remove dead code --- test/sample.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/sample.jl b/test/sample.jl index e036f558..5af2aa74 100644 --- a/test/sample.jl +++ b/test/sample.jl @@ -10,8 +10,7 @@ @test length(LOGGERS) == 1 logger = first(LOGGERS) @test logger isa TeeLogger - @test logger.loggers[1].logger isa - (Sys.iswindows() && VERSION < v"1.5.3" ? ProgressLogger : TerminalLogger) + @test logger.loggers[1].logger isa TerminalLogger @test logger.loggers[2].logger === CURRENT_LOGGER @test Logging.current_logger() === CURRENT_LOGGER From 5b2577fa4b73382b3880391b87e1c017bd73db16 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 1 Jul 2025 01:09:00 +0100 Subject: [PATCH 17/25] Fix channel buffering for MCMCThreads --- src/sample.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sample.jl b/src/sample.jl index ac4b0508..50753cee 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -488,7 +488,7 @@ function mcmcsample( # Just a single progress bar for the entire sampling, but instead # of tracking each chain as it comes in, we track each sample as it # comes in. This allows us to have more granular progress updates. - progress_channel = Channel{Bool}() + progress_channel = Channel{Bool}(nchains) end Distributed.@sync begin From cefafb0c7d5ff2a8e0ac9dce478b84d7d8dbe049 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 1 Jul 2025 14:03:02 +0100 Subject: [PATCH 18/25] Attempt to use proper types for logging --- src/logging.jl | 104 ++++++++++++++++++++++++++++++++++++++++--------- src/sample.jl | 44 +++++++++++---------- 2 files changed, 110 insertions(+), 38 deletions(-) diff --git a/src/logging.jl b/src/logging.jl index ed2fc0e4..29019da6 100644 --- a/src/logging.jl +++ b/src/logging.jl @@ -1,3 +1,89 @@ +""" + AbstractProgressKwarg + +Abstract type representing the values that the `progress` keyword argument can +internally take for single-chain sampling. +""" +abstract type AbstractProgressKwarg end + +""" + CreateNewProgressBar + +Create a new logger for progress logging. +""" +struct CreateNewProgressBar{S<:AbstractString} <: AbstractProgressKwarg + name::S + uuid::UUIDs.UUID + + function CreateNewProgressBar(name::AbstractString) + return new{typeof{name}}(name, UUIDs.uuid4()) + end +end +function init_progress(p::CreateNewProgressBar) + if hasprogresslevel(Logging.current_logger()) + ProgressLogging.@withprogress $(exprs...) + else + $with_progresslogger($Base.@__MODULE__, $Logging.current_logger()) do + $ProgressLogging.@withprogress $(exprs...) + end + end + ProgressLogging.@logprogress p.name nothing _id = p.uuid +end +function update_progress(p::CreateNewProgressBar, progress_frac, ::Bool) + ProgressLogging.@logprogress p.name progress_frac _id = p.uuid +end +finish_progress(::CreateNewProgressBar) = ProgressLogging.@logprogress "done" + +""" + NoLogging + +Do not log progress at all. +""" +struct NoLogging <: AbstractProgressKwarg end +init_progress(::NoLogging) = nothing +update_progress(::NoLogging, ::Any, ::Bool) = nothing +finish_progress(::NoLogging) = nothing + +""" + ExistingProgressBar + +Use an existing progress bar to log progress. This is used for tracking +progress in a progress bar that has been previously generated elsewhere, +specifically, when `sample(..., MCMCThreads(), ...; progress=:perchain)` is +called. In this case we can use `@logprogress name progress_frac _id = uuid` to +log progress. +""" +struct ExistingProgressBar{S<:AbstractString} <: AbstractProgressKwarg + name::S + uuid::UUIDs.UUID +end +init_progress(::ExistingProgressBar) = nothing +function update_progress(p::ExistingProgressBar, progress_frac, ::Bool) + ProgressLogging.@logprogress p.name progress_frac _id = p.uuid +end +function finish_progress(p::ExistingProgressBar) + ProgressLogging.@logprogress p.name "done" _id = p.uuid +end + +""" + ChannelProgress + +Use a `Channel` to log progress. This is used for 'reporting' progress back +to the main thread or worker when using `progress=:overall` with MCMCThreads or +MCMCDistributed. +""" +struct ChannelProgress{T<:Union{Channel{Bool},Distributed.RemoteChannel{Channel{Bool}}}} <: + AbstractProgressKwarg + channel::T +end +init_progress(::ChannelProgress) = nothing +function update_progress(p::ChannelProgress, ::Any, update_channel::Bool) + return update_channel && put!(p.channel, true) +end +# Note: We don't want to `put!(p.channel, false)`, because that would stop the +# channel from being used for further updates e.g. from other chains. +finish_progress(::ChannelProgress) = nothing + # avoid creating a progress bar with @withprogress if progress logging is disabled # and add a custom progress logger if the current logger does not seem to be able to handle # progress logs @@ -23,24 +109,6 @@ macro ifwithprogresslogger(cond, exprs...) ) end -macro log_progress_dispatch(progress, progressname, progress_frac) - return esc( - quote - if $progress == true - # Use global logger - $ProgressLogging.@logprogress $progress_frac - elseif $progress isa $UUIDs.UUID - # Use the logger with this specific UUID - $ProgressLogging.@logprogress $progressname $progress_frac _id = $progress - else - # progress == false, or progress isa Channel, both of which are - # handled manually - nothing - end - end, - ) -end - # improved checks? function hasprogresslevel(logger) return Logging.min_enabled_level(logger) ≤ ProgressLogging.ProgressLevel diff --git a/src/sample.jl b/src/sample.jl index 50753cee..bb2b4f4b 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -131,7 +131,7 @@ function mcmcsample( model::AbstractModel, sampler::AbstractSampler, N::Integer; - progress::Union{Bool,UUIDs.UUID,Channel{Bool},Distributed.RemoteChannel{Channel{Bool}}}=PROGRESS[], + progress::Union{Bool,<:AbstractProgressKwarg}=PROGRESS[], progressname="Sampling", callback=nothing, num_warmup::Int=0, @@ -152,6 +152,14 @@ function mcmcsample( ArgumentError("number of warm-up samples exceeds the total number of samples") ) + # Initialise progress bar + if progress === true + progress = CreateNewProgressBar(progressname) + elseif progress === false + progress = NoLogging() + end + init_progress(progress) + # Determine how many samples to drop from `num_warmup` and the # main sampling process before we start saving samples. discard_from_warmup = min(num_warmup, discard_initial) @@ -161,10 +169,7 @@ function mcmcsample( start = time() local state - # Only create a new progress bar if progress is explicitly equal to true, i.e. - # it's not a UUID (the progress bar already exists), a channel (there's no need - # for a new progress bar), or false (no progress bar). - @ifwithprogresslogger (progress == true) name = progressname begin + try # Determine threshold values for progress logging # (one update per 0.5% of progress) threshold = Ntotal ÷ 200 @@ -175,10 +180,10 @@ function mcmcsample( # is constructed in the multi-chain method, i.e. if we don't do this # the progress bar shows the time elapsed since _all_ sampling began, # not since the current chain started. - if progress isa UUIDs.UUID + if progress isa ExistingProgressBar try bartrees = Logging.current_logger().loggers[1].logger.bartrees - bar = TerminalLoggers.findbar(bartrees, progress).data + bar = TerminalLoggers.findbar(bartrees, progress.uuid).data bar.tfirst = time() catch end @@ -202,13 +207,9 @@ function mcmcsample( # Start the progress bar. itotal = 1 if itotal >= next_update - @log_progress_dispatch progress progressname itotal / Ntotal + update_progress(progress, itotal / Ntotal, true) next_update = itotal + threshold end - if progress isa Channel{Bool} || - progress isa Distributed.RemoteChannel{Channel{Bool}} - put!(progress, true) - end # Discard initial samples. for j in 1:discard_initial @@ -222,7 +223,7 @@ function mcmcsample( # Update the progress bar. itotal += 1 if itotal >= next_update - @log_progress_dispatch progress progressname itotal / Ntotal + update_progress(progress, itotal / Ntotal, false) next_update = itotal + threshold end end @@ -248,7 +249,7 @@ function mcmcsample( # Update progress bar. itotal += 1 if itotal >= next_update - @log_progress_dispatch progress progressname itotal / Ntotal + update_progress(progress, itotal / Ntotal, false) next_update = itotal + threshold end end @@ -270,14 +271,17 @@ function mcmcsample( # Update the progress bar. itotal += 1 if itotal >= next_update - @log_progress_dispatch progress progressname itotal / Ntotal + update_progress(progress, itotal / Ntotal, true) next_update = itotal + threshold end - if progress isa Channel{Bool} || - progress isa Distributed.RemoteChannel{Channel{Bool}} - put!(progress, true) - end end + catch e + # If an error occurs, we still want to finish the progress bar. + finish_progress(progress) + rethrow(e) + finally + # Finish the progress bar. + finish_progress(progress) end # Get the sample stop time. @@ -473,7 +477,7 @@ function mcmcsample( if progress == :perchain # This is the 'overall' progress bar. We create a channel for each # chain to report back to when it finishes sampling. - progress_channel = Channel{Bool}() + progress_channel = Channel{Bool}(nchunks) # These are the per-chain progress bars. We generate `nchains` # independent UUIDs for each progress bar uuids = [UUIDs.uuid4() for _ in 1:nchains] From c6f9e785c63722ecdd61480f11bbe97caf6a2d85 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 1 Jul 2025 14:55:26 +0100 Subject: [PATCH 19/25] Refactor logging, throttle per-chain updates --- src/logging.jl | 66 ++++++++++++------------- src/sample.jl | 132 ++++++++++++++++++++++--------------------------- 2 files changed, 91 insertions(+), 107 deletions(-) diff --git a/src/logging.jl b/src/logging.jl index 29019da6..2dfea2fe 100644 --- a/src/logging.jl +++ b/src/logging.jl @@ -14,25 +14,19 @@ Create a new logger for progress logging. struct CreateNewProgressBar{S<:AbstractString} <: AbstractProgressKwarg name::S uuid::UUIDs.UUID - function CreateNewProgressBar(name::AbstractString) - return new{typeof{name}}(name, UUIDs.uuid4()) + return new{typeof(name)}(name, UUIDs.uuid4()) end end function init_progress(p::CreateNewProgressBar) - if hasprogresslevel(Logging.current_logger()) - ProgressLogging.@withprogress $(exprs...) - else - $with_progresslogger($Base.@__MODULE__, $Logging.current_logger()) do - $ProgressLogging.@withprogress $(exprs...) - end - end ProgressLogging.@logprogress p.name nothing _id = p.uuid end -function update_progress(p::CreateNewProgressBar, progress_frac, ::Bool) +function update_progress(p::CreateNewProgressBar, progress_frac) ProgressLogging.@logprogress p.name progress_frac _id = p.uuid end -finish_progress(::CreateNewProgressBar) = ProgressLogging.@logprogress "done" +function finish_progress(p::CreateNewProgressBar) + ProgressLogging.@logprogress p.name "done" _id = p.uuid +end """ NoLogging @@ -41,7 +35,7 @@ Do not log progress at all. """ struct NoLogging <: AbstractProgressKwarg end init_progress(::NoLogging) = nothing -update_progress(::NoLogging, ::Any, ::Bool) = nothing +update_progress(::NoLogging, ::Any) = nothing finish_progress(::NoLogging) = nothing """ @@ -57,8 +51,21 @@ struct ExistingProgressBar{S<:AbstractString} <: AbstractProgressKwarg name::S uuid::UUIDs.UUID end -init_progress(::ExistingProgressBar) = nothing -function update_progress(p::ExistingProgressBar, progress_frac, ::Bool) +function init_progress(p::ExistingProgressBar) + # Hacky code to reset the start timer if called from a multi-chain sampling + # process. We need this because the progress bar is constructed in the + # multi-chain method, i.e. if we don't do this the progress bar shows the + # time elapsed since _all_ sampling began, not since the current chain + # started. + try + bartrees = Logging.current_logger().loggers[1].logger.bartrees + bar = TerminalLoggers.findbar(bartrees, p.uuid).data + bar.tfirst = time() + catch + end + ProgressLogging.@logprogress p.name nothing _id = p.uuid +end +function update_progress(p::ExistingProgressBar, progress_frac) ProgressLogging.@logprogress p.name progress_frac _id = p.uuid end function finish_progress(p::ExistingProgressBar) @@ -71,39 +78,32 @@ end Use a `Channel` to log progress. This is used for 'reporting' progress back to the main thread or worker when using `progress=:overall` with MCMCThreads or MCMCDistributed. + +n_updates is the number of updates that each child thread is expected to report +back to the main thread. """ struct ChannelProgress{T<:Union{Channel{Bool},Distributed.RemoteChannel{Channel{Bool}}}} <: AbstractProgressKwarg channel::T + n_updates::Int end init_progress(::ChannelProgress) = nothing -function update_progress(p::ChannelProgress, ::Any, update_channel::Bool) - return update_channel && put!(p.channel, true) -end +update_progress(p::ChannelProgress, ::Any) = put!(p.channel, true) # Note: We don't want to `put!(p.channel, false)`, because that would stop the # channel from being used for further updates e.g. from other chains. finish_progress(::ChannelProgress) = nothing -# avoid creating a progress bar with @withprogress if progress logging is disabled -# and add a custom progress logger if the current logger does not seem to be able to handle -# progress logs -macro ifwithprogresslogger(cond, exprs...) +# Add a custom progress logger if the current logger does not seem to be able to handle +# progress logs. +macro withprogresslogger(expr) return esc( quote - if $cond - # Create a new logger - if $hasprogresslevel($Logging.current_logger()) - $ProgressLogging.@withprogress $(exprs...) - else - $with_progresslogger($Base.@__MODULE__, $Logging.current_logger()) do - $ProgressLogging.@withprogress $(exprs...) - end + if !($hasprogresslevel($Logging.current_logger())) + $with_progresslogger($Base.@__MODULE__, $Logging.current_logger()) do + $(expr) end else - # Don't create a new logger, either because progress logging - # was disabled, or because it's otherwise being manually - # managed. - $(exprs[end]) + $(expr) end end, ) diff --git a/src/sample.jl b/src/sample.jl index bb2b4f4b..26f51169 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -169,26 +169,13 @@ function mcmcsample( start = time() local state - try + @withprogresslogger begin # Determine threshold values for progress logging # (one update per 0.5% of progress) - threshold = Ntotal ÷ 200 + n_updates = progress isa ChannelProgress ? progress.n_updates : 200 + threshold = Ntotal ÷ n_updates next_update = threshold - # Slightly hacky code to reset the start timer if called from a - # multi-chain sampling process. We need this because the progress bar - # is constructed in the multi-chain method, i.e. if we don't do this - # the progress bar shows the time elapsed since _all_ sampling began, - # not since the current chain started. - if progress isa ExistingProgressBar - try - bartrees = Logging.current_logger().loggers[1].logger.bartrees - bar = TerminalLoggers.findbar(bartrees, progress.uuid).data - bar.tfirst = time() - catch - end - end - # Obtain the initial sample and state. sample, state = if num_warmup > 0 if initial_state === nothing @@ -207,7 +194,7 @@ function mcmcsample( # Start the progress bar. itotal = 1 if itotal >= next_update - update_progress(progress, itotal / Ntotal, true) + update_progress(progress, itotal / Ntotal) next_update = itotal + threshold end @@ -223,7 +210,7 @@ function mcmcsample( # Update the progress bar. itotal += 1 if itotal >= next_update - update_progress(progress, itotal / Ntotal, false) + update_progress(progress, itotal / Ntotal) next_update = itotal + threshold end end @@ -249,7 +236,7 @@ function mcmcsample( # Update progress bar. itotal += 1 if itotal >= next_update - update_progress(progress, itotal / Ntotal, false) + update_progress(progress, itotal / Ntotal) next_update = itotal + threshold end end @@ -271,16 +258,10 @@ function mcmcsample( # Update the progress bar. itotal += 1 if itotal >= next_update - update_progress(progress, itotal / Ntotal, true) + update_progress(progress, itotal / Ntotal) next_update = itotal + threshold end end - catch e - # If an error occurs, we still want to finish the progress bar. - finish_progress(progress) - rethrow(e) - finally - # Finish the progress bar. finish_progress(progress) end @@ -323,6 +304,14 @@ function mcmcsample( num_warmup >= 0 || throw(ArgumentError("number of warm-up samples must be non-negative")) + # Initialise progress bar + if progress === true + progress = CreateNewProgressBar(progressname) + elseif progress === false + progress = NoLogging() + end + init_progress(progress) + # Determine how many samples to drop from `num_warmup` and the # main sampling process before we start saving samples. discard_from_warmup = min(num_warmup, discard_initial) @@ -332,7 +321,7 @@ function mcmcsample( start = time() local state - @ifwithprogresslogger progress name = progressname begin + @withprogresslogger begin # Obtain the initial sample and state. sample, state = if num_warmup > 0 if initial_state === nothing @@ -398,6 +387,7 @@ function mcmcsample( end # Get the sample stop time. + finish_progress(progress) stop = time() duration = stop - start stats = SamplingStats(start, stop, duration) @@ -473,35 +463,49 @@ function mcmcsample( # Set up a chains vector. chains = Vector{Any}(undef, nchains) - @ifwithprogresslogger (progress != :none) name = progressname begin + @withprogresslogger begin if progress == :perchain # This is the 'overall' progress bar. We create a channel for each # chain to report back to when it finishes sampling. progress_channel = Channel{Bool}(nchunks) + overall_progress_bar = CreateNewProgressBar(progressname) + init_progress(overall_progress_bar) # These are the per-chain progress bars. We generate `nchains` # independent UUIDs for each progress bar - uuids = [UUIDs.uuid4() for _ in 1:nchains] - progress_names = ["Chain $i/$nchains" for i in 1:nchains] + child_progresses = [ + ExistingProgressBar("Chain $i/$nchains", UUIDs.uuid4()) for i in 1:nchains + ] # Start the per-chain progress bars (but in reverse order, because # ProgressLogging prints from the bottom up, and we want chain 1 to # show up at the top) - for (progress_name, uuid) in reverse(collect(zip(progress_names, uuids))) - ProgressLogging.@logprogress name = progress_name nothing _id = uuid + for child_progress in reverse(child_progresses) + init_progress(child_progress) end + updates_per_chain = nothing elseif progress == :overall # Just a single progress bar for the entire sampling, but instead # of tracking each chain as it comes in, we track each sample as it # comes in. This allows us to have more granular progress updates. progress_channel = Channel{Bool}(nchains) + overall_progress_bar = CreateNewProgressBar(progressname) + # If we have many chains and many samples, we don't want to force + # each chain to report back to the main thread for each sample, as + # this would cause serious performance issues due to lock conflicts. + # In the overall progress bar we only expect 200 updates (i.e., one + # update per 0.5%). To avoid possible throttling issues we ask for + # twice the amount needed per chain, which doesn't cause a real + # performance hit. + updates_per_chain = max(1, 400 ÷ nchains) end Distributed.@sync begin if progress != :none # This task updates the progress bar Distributed.@async begin + # Total number of updates (across all chains) + Ntotal = progress == :overall ? nchains * updates_per_chain : nchains # Determine threshold values for progress logging # (one update per 0.5% of progress) - Ntotal = progress == :overall ? nchains * N : nchains threshold = Ntotal ÷ 200 next_update = threshold @@ -509,10 +513,11 @@ function mcmcsample( while take!(progress_channel) itotal += 1 if itotal >= next_update - ProgressLogging.@logprogress itotal / Ntotal + update_progress(overall_progress_bar, itotal / Ntotal) next_update = itotal + threshold end end + finish_progress(overall_progress_bar) end end @@ -535,19 +540,12 @@ function mcmcsample( Random.seed!(_rng, seeds[chainidx]) # Determine how to monitor progress for the child chains. - child_progress, child_progressname = if progress == :none - # No need to create a progress bar - false, "" + child_progress = if progress == :none + false elseif progress == :overall - # No need to create a new progress bar, but we need to - # pass the channel to the child so that it can log when - # it has finished obtaining each sample. - progress_channel, "" + ChannelProgress(progress_channel, updates_per_chain) elseif progress == :perchain - # We need to specify both the ID of the progress bar for - # the child to update, and we also specify the name to use - # for the progress bar. - uuids[chainidx], progress_names[chainidx] + child_progresses[chainidx] # <- isa ExistingProgressBar end # Sample a chain and save it to the vector. @@ -557,7 +555,6 @@ function mcmcsample( _sampler, N; progress=child_progress, - progressname=child_progressname, initial_params=if initial_params === nothing nothing else @@ -576,7 +573,7 @@ function mcmcsample( # Tell the 'main' progress bar that this chain is done. put!(progress_channel, true) # Conclude the per-chain progress bar. - ProgressLogging.@logprogress progress_names[chainidx] "done" _id = uuids[chainidx] + finish_progress(child_progresses[chainidx]) end # Note that if progress == :overall, we don't need to do anything # because progress on that bar is triggered by @@ -590,8 +587,8 @@ function mcmcsample( # is done, or if an error occurs). put!(progress_channel, false) # Additionally stop the per-chain progress bars - for (progress_name, uuid) in zip(progress_names, uuids) - ProgressLogging.@logprogress progress_name "done" _id = uuid + for child_progress in child_progresses + finish_progress(child_progress) end elseif progress == :overall # Stop updating the main progress bar (either if sampling @@ -659,7 +656,7 @@ function mcmcsample( pool = Distributed.CachingPool(Distributed.workers()) local chains - @ifwithprogresslogger (progress != :none) name = progressname begin + @withprogresslogger begin # Set up progress logging. if progress == :overall # Just a single progress bar for the entire sampling, but instead @@ -667,11 +664,15 @@ function mcmcsample( # comes in. This allows us to have more granular progress updates. chan = Channel{Bool}(Distributed.nworkers()) progress_channel = Distributed.RemoteChannel(() -> chan) - child_progresses = [progress_channel for _ in 1:nchains] - child_progressnames = ["" for _ in 1:nchains] + overall_progress_bar = CreateNewProgressBar(progressname) + init_progress(overall_progress_bar) + # See MCMCThreads method for the rationale behind updates_per_chain. + updates_per_chain = max(1, 400 ÷ nchains) + child_progresses = [ + ChannelProgress(progress_channel, updates_per_chain) for _ in 1:nchains + ] elseif progress == :none child_progresses = [false for _ in 1:nchains] - child_progressnames = ["" for _ in 1:nchains] end Distributed.@sync begin @@ -680,7 +681,7 @@ function mcmcsample( Distributed.@async begin # Determine threshold values for progress logging # (one update per 0.5% of progress) - Ntotal = nchains * N + Ntotal = nchains * updates_per_chain threshold = Ntotal ÷ 200 next_update = threshold @@ -688,21 +689,18 @@ function mcmcsample( while take!(progress_channel) itotal += 1 if itotal >= next_update - ProgressLogging.@logprogress itotal / Ntotal + update_progress(overall_progress_bar, itotal / Ntotal) next_update = itotal + threshold end end + finish_progress(overall_progress_bar) end end Distributed.@async begin try function sample_chain( - seed, - initial_params, - initial_state, - child_progress, - child_progressname, + seed, initial_params, initial_state, child_progress ) # Seed a new random number generator with the pre-made seed. Random.seed!(rng, seed) @@ -714,24 +712,11 @@ function mcmcsample( sampler, N; progress=child_progress, - progressname=child_progressname, initial_params=initial_params, initial_state=initial_state, kwargs..., ) - # Update the progress bars. Note that the case of - # progress = :overall doesn't need to be handled here - # (for similar reasons to the MCMCThreads method - # above). - if progress == :perchain - # Tell the 'main' progress bar that this chain is done. - put!(progress_channel, true) - # Conclude the per-chain progress bar. - ProgressLogging.@logprogress child_progressname "done" _id = - child_progress - end - # Return the new chain. return chain end @@ -742,7 +727,6 @@ function mcmcsample( _initial_params, _initial_state, child_progresses, - child_progressnames, ) finally if progress == :overall From d9c2e867acd3f65b367280d376230287b656fe4a Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 1 Jul 2025 15:07:28 +0100 Subject: [PATCH 20/25] Improve comment --- src/sample.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/sample.jl b/src/sample.jl index 26f51169..1b638c8b 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -465,9 +465,13 @@ function mcmcsample( @withprogresslogger begin if progress == :perchain - # This is the 'overall' progress bar. We create a channel for each - # chain to report back to when it finishes sampling. + # Create a channel for each chain to report back to when it + # finishes sampling. progress_channel = Channel{Bool}(nchunks) + # This is the 'overall' progress bar which tracks the number of + # chains that have completed. Note that this progress bar is backed + # by a channel, but it is not itself a ChannelProgress (because + # ChannelProgress doesn't come with a progress bar). overall_progress_bar = CreateNewProgressBar(progressname) init_progress(overall_progress_bar) # These are the per-chain progress bars. We generate `nchains` From f8a8b642b560e7f41582f995cef81e64afa335e5 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 1 Jul 2025 15:13:10 +0100 Subject: [PATCH 21/25] add warning --- docs/src/api.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/src/api.md b/docs/src/api.md index daa75eab..8978e03a 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -115,6 +115,9 @@ Thus, for example, if you _always_ want to use `:overall`, you can call `Abstrac Multiple-chain sampling using `MCMCDistributed` behaves the same as `MCMCThreads`, except that `:perchain` is not (yet?) implemented. So, `true` always corresponds to `:overall`, and `false` corresponds to `:none`. +!!! warning "Do not override the `progress` keyword argument" + If you are implementing your own methods for `sample(...)`, you should make sure to not override the `progress` keyword argument if you want progress logging in multi-chain sampling to work correctly, as the multi-chain `sample()` call makes sure to specifically pass custom values of `progress` to the single-chain calls. + ## Chains The `chain_type` keyword argument allows to set the type of the returned chain. A common From 64b0bfbaee2d2f568d2a79c5b478b75af644c804 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 1 Jul 2025 15:21:34 +0100 Subject: [PATCH 22/25] fix convergence sampling --- src/sample.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/sample.jl b/src/sample.jl index 1b638c8b..35061765 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -158,7 +158,6 @@ function mcmcsample( elseif progress === false progress = NoLogging() end - init_progress(progress) # Determine how many samples to drop from `num_warmup` and the # main sampling process before we start saving samples. @@ -170,6 +169,7 @@ function mcmcsample( local state @withprogresslogger begin + init_progress(progress) # Determine threshold values for progress logging # (one update per 0.5% of progress) n_updates = progress isa ChannelProgress ? progress.n_updates : 200 @@ -310,7 +310,6 @@ function mcmcsample( elseif progress === false progress = NoLogging() end - init_progress(progress) # Determine how many samples to drop from `num_warmup` and the # main sampling process before we start saving samples. @@ -322,6 +321,7 @@ function mcmcsample( local state @withprogresslogger begin + init_progress(progress) # Obtain the initial sample and state. sample, state = if num_warmup > 0 if initial_state === nothing @@ -384,10 +384,10 @@ function mcmcsample( # Increment iteration counter. i += 1 end + finish_progress(progress) end # Get the sample stop time. - finish_progress(progress) stop = time() duration = stop - start stats = SamplingStats(start, stop, duration) From 27569b31a533c69f4c5c96dc11c3b90ef225a5fe Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 1 Jul 2025 16:04:33 +0100 Subject: [PATCH 23/25] Don't use integer division --- src/logging.jl | 1 + src/sample.jl | 23 ++++++++++++----------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/src/logging.jl b/src/logging.jl index 2dfea2fe..b8e25bd0 100644 --- a/src/logging.jl +++ b/src/logging.jl @@ -23,6 +23,7 @@ function init_progress(p::CreateNewProgressBar) end function update_progress(p::CreateNewProgressBar, progress_frac) ProgressLogging.@logprogress p.name progress_frac _id = p.uuid + @show progress_frac end function finish_progress(p::CreateNewProgressBar) ProgressLogging.@logprogress p.name "done" _id = p.uuid diff --git a/src/sample.jl b/src/sample.jl index 35061765..02afcd21 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -170,10 +170,11 @@ function mcmcsample( @withprogresslogger begin init_progress(progress) - # Determine threshold values for progress logging - # (one update per 0.5% of progress) + # Determine threshold values for progress logging (by default, one + # update per 0.5% of progress, unless this has been passed in + # explicitly) n_updates = progress isa ChannelProgress ? progress.n_updates : 200 - threshold = Ntotal ÷ n_updates + threshold = Ntotal / n_updates next_update = threshold # Obtain the initial sample and state. @@ -195,7 +196,7 @@ function mcmcsample( itotal = 1 if itotal >= next_update update_progress(progress, itotal / Ntotal) - next_update = itotal + threshold + next_update += threshold end # Discard initial samples. @@ -211,7 +212,7 @@ function mcmcsample( itotal += 1 if itotal >= next_update update_progress(progress, itotal / Ntotal) - next_update = itotal + threshold + next_update += threshold end end @@ -237,7 +238,7 @@ function mcmcsample( itotal += 1 if itotal >= next_update update_progress(progress, itotal / Ntotal) - next_update = itotal + threshold + next_update += threshold end end @@ -259,7 +260,7 @@ function mcmcsample( itotal += 1 if itotal >= next_update update_progress(progress, itotal / Ntotal) - next_update = itotal + threshold + next_update += threshold end end finish_progress(progress) @@ -510,7 +511,7 @@ function mcmcsample( Ntotal = progress == :overall ? nchains * updates_per_chain : nchains # Determine threshold values for progress logging # (one update per 0.5% of progress) - threshold = Ntotal ÷ 200 + threshold = Ntotal / 200 next_update = threshold itotal = 0 @@ -518,7 +519,7 @@ function mcmcsample( itotal += 1 if itotal >= next_update update_progress(overall_progress_bar, itotal / Ntotal) - next_update = itotal + threshold + next_update += threshold end end finish_progress(overall_progress_bar) @@ -686,7 +687,7 @@ function mcmcsample( # Determine threshold values for progress logging # (one update per 0.5% of progress) Ntotal = nchains * updates_per_chain - threshold = Ntotal ÷ 200 + threshold = Ntotal / 200 next_update = threshold itotal = 0 @@ -694,7 +695,7 @@ function mcmcsample( itotal += 1 if itotal >= next_update update_progress(overall_progress_bar, itotal / Ntotal) - next_update = itotal + threshold + next_update += threshold end end finish_progress(overall_progress_bar) From 4cd647afe42ec9759613763d5285cac0e62efe5e Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 1 Jul 2025 16:06:37 +0100 Subject: [PATCH 24/25] remove extra show --- src/logging.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/logging.jl b/src/logging.jl index b8e25bd0..2dfea2fe 100644 --- a/src/logging.jl +++ b/src/logging.jl @@ -23,7 +23,6 @@ function init_progress(p::CreateNewProgressBar) end function update_progress(p::CreateNewProgressBar, progress_frac) ProgressLogging.@logprogress p.name progress_frac _id = p.uuid - @show progress_frac end function finish_progress(p::CreateNewProgressBar) ProgressLogging.@logprogress p.name "done" _id = p.uuid From 9f8970d454904963a2cc2ec477616fe87b276d17 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 1 Jul 2025 16:15:12 +0100 Subject: [PATCH 25/25] Rename withprogresslogger macro --- src/logging.jl | 2 +- src/sample.jl | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/logging.jl b/src/logging.jl index 2dfea2fe..9eebc0c0 100644 --- a/src/logging.jl +++ b/src/logging.jl @@ -95,7 +95,7 @@ finish_progress(::ChannelProgress) = nothing # Add a custom progress logger if the current logger does not seem to be able to handle # progress logs. -macro withprogresslogger(expr) +macro maybewithricherlogger(expr) return esc( quote if !($hasprogresslevel($Logging.current_logger())) diff --git a/src/sample.jl b/src/sample.jl index 02afcd21..220cf38e 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -168,7 +168,7 @@ function mcmcsample( start = time() local state - @withprogresslogger begin + @maybewithricherlogger begin init_progress(progress) # Determine threshold values for progress logging (by default, one # update per 0.5% of progress, unless this has been passed in @@ -321,7 +321,7 @@ function mcmcsample( start = time() local state - @withprogresslogger begin + @maybewithricherlogger begin init_progress(progress) # Obtain the initial sample and state. sample, state = if num_warmup > 0 @@ -464,7 +464,7 @@ function mcmcsample( # Set up a chains vector. chains = Vector{Any}(undef, nchains) - @withprogresslogger begin + @maybewithricherlogger begin if progress == :perchain # Create a channel for each chain to report back to when it # finishes sampling. @@ -661,7 +661,7 @@ function mcmcsample( pool = Distributed.CachingPool(Distributed.workers()) local chains - @withprogresslogger begin + @maybewithricherlogger begin # Set up progress logging. if progress == :overall # Just a single progress bar for the entire sampling, but instead