From b5ea8028df97e619987dfa0179fbbabae670e9cd Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 26 Jun 2025 23:34:09 +0100 Subject: [PATCH 1/5] [wip] fix parallel sampling --- src/logging.jl | 2 +- src/sample.jl | 43 ++++++++++++++++++++++++++++++++----------- 2 files changed, 33 insertions(+), 12 deletions(-) diff --git a/src/logging.jl b/src/logging.jl index 04c41187..7abeec76 100644 --- a/src/logging.jl +++ b/src/logging.jl @@ -4,7 +4,7 @@ macro ifwithprogresslogger(progress, exprs...) return esc( quote - if $progress + if $progress == true if $hasprogresslevel($Logging.current_logger()) $ProgressLogging.@withprogress $(exprs...) else diff --git a/src/sample.jl b/src/sample.jl index 01a2006a..74ac71f8 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -144,7 +144,7 @@ function mcmcsample( @ifwithprogresslogger progress name = progressname begin # Determine threshold values for progress logging # (one update per 0.5% of progress) - if progress + if (progress == true || progress === nothing) threshold = Ntotal ÷ 200 next_update = threshold end @@ -166,8 +166,12 @@ function mcmcsample( # Update the progress bar. itotal = 1 - if progress && itotal >= next_update - ProgressLogging.@logprogress itotal / Ntotal + if !(progress == false) && itotal >= next_update + if progress == true + ProgressLogging.@logprogress itotal / Ntotal + else + ProgressLogging.@logprogress itotal / Ntotal _id = "hello" + end next_update = itotal + threshold end @@ -181,8 +185,12 @@ function mcmcsample( end # Update the progress bar. - if progress && (itotal += 1) >= next_update - ProgressLogging.@logprogress itotal / Ntotal + if !(progress == false) && (itotal += 1) >= next_update + if progress == true + ProgressLogging.@logprogress itotal / Ntotal + else + ProgressLogging.@logprogress itotal / Ntotal _id = "hello" + end next_update = itotal + threshold end end @@ -206,8 +214,12 @@ function mcmcsample( end # Update progress bar. - if progress && (itotal += 1) >= next_update - ProgressLogging.@logprogress itotal / Ntotal + if !(progress == false) && (itotal += 1) >= next_update + if progress == true + ProgressLogging.@logprogress itotal / Ntotal + else + ProgressLogging.@logprogress itotal / Ntotal _id = "hello" + end next_update = itotal + threshold end end @@ -227,8 +239,12 @@ function mcmcsample( samples = save!!(samples, sample, i, model, sampler, N; kwargs...) # Update the progress bar. - if progress && (itotal += 1) >= next_update - ProgressLogging.@logprogress itotal / Ntotal + if !(progress == false) && (itotal += 1) >= next_update + if progress == true + ProgressLogging.@logprogress itotal / Ntotal + else + ProgressLogging.@logprogress itotal / Ntotal _id = "hello" + end next_update = itotal + threshold end end @@ -456,12 +472,17 @@ function mcmcsample( Random.seed!(_rng, seeds[chainidx]) # Sample a chain and save it to the vector. - chains[chainidx] = StatsBase.sample( + child_progress = if progress == false + false + else + nothing + end + @ifwithprogresslogger progress chains[chainidx] = StatsBase.sample( _rng, _model, _sampler, N; - progress=false, + progress=child_progress, initial_params=if initial_params === nothing nothing else From f5a6535f8622d2d4cef7f4c1cfe88f79169b4ae1 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 27 Jun 2025 12:40:55 +0100 Subject: [PATCH 2/5] Parallel sampling with ProgressLogging --- Project.toml | 2 ++ src/sample.jl | 49 ++++++++++++++++++++++++++++++++++++++++++------- 2 files changed, 44 insertions(+), 7 deletions(-) diff --git a/Project.toml b/Project.toml index 77de8700..213f92a1 100644 --- a/Project.toml +++ b/Project.toml @@ -18,6 +18,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed" Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999" +UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" [compat] BangBang = "0.3.19, 0.4" @@ -29,6 +30,7 @@ ProgressLogging = "0.1" StatsBase = "0.32, 0.33, 0.34" TerminalLoggers = "0.1" Transducers = "0.4.30" +UUIDs = "1.11.0" julia = "1.6" [extras] diff --git a/src/sample.jl b/src/sample.jl index 74ac71f8..d90a470a 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -1,3 +1,5 @@ +using UUIDs: uuid4 + # Default implementations of `sample`. const PROGRESS = Ref(true) @@ -144,11 +146,22 @@ function mcmcsample( @ifwithprogresslogger progress name = progressname begin # Determine threshold values for progress logging # (one update per 0.5% of progress) - if (progress == true || progress === nothing) + if !(progress == false) threshold = Ntotal ÷ 200 next_update = threshold end + # Ugly hacky code to reset the start timer if called from a multi-chain + # sampling process + if progress isa ProgressLogging.Progress + try + bartrees = Logging.current_logger().loggers[1].logger.bartrees + bar = TerminalLoggers.findbar(bartrees, progress.id).data + bar.tfirst = time() + catch + end + end + # Obtain the initial sample and state. sample, state = if num_warmup > 0 if initial_state === nothing @@ -170,7 +183,8 @@ function mcmcsample( if progress == true ProgressLogging.@logprogress itotal / Ntotal else - ProgressLogging.@logprogress itotal / Ntotal _id = "hello" + ProgressLogging.@logprogress name = progressname itotal / Ntotal _id = + progress.id end next_update = itotal + threshold end @@ -189,7 +203,8 @@ function mcmcsample( if progress == true ProgressLogging.@logprogress itotal / Ntotal else - ProgressLogging.@logprogress itotal / Ntotal _id = "hello" + ProgressLogging.@logprogress name = progressname itotal / Ntotal _id = + progress.id end next_update = itotal + threshold end @@ -218,7 +233,8 @@ function mcmcsample( if progress == true ProgressLogging.@logprogress itotal / Ntotal else - ProgressLogging.@logprogress itotal / Ntotal _id = "hello" + ProgressLogging.@logprogress name = progressname itotal / Ntotal _id = + progress.id end next_update = itotal + threshold end @@ -243,7 +259,8 @@ function mcmcsample( if progress == true ProgressLogging.@logprogress itotal / Ntotal else - ProgressLogging.@logprogress itotal / Ntotal _id = "hello" + ProgressLogging.@logprogress name = progressname itotal / Ntotal _id = + progress.id end next_update = itotal + threshold end @@ -432,6 +449,18 @@ function mcmcsample( if progress channel = Channel{Bool}(length(interval)) end + # Generate nchains independent UUIDs for each progress bar + uuids = [uuid4() for _ in 1:nchains] + # Start the progress bars (but in reverse order, because + # ProgressLogging prints from the bottom up, and we want chain 1 to + # show up at the top) + # TODO: This has an unintended effect that the 'time' field in the + # progress bar shows the total time since the beginning of sampling, + # even if the specific chain doesn't start sampling until later on. + for (i, uuid) in enumerate(reverse(uuids)) + ProgressLogging.@logprogress name = "Chain $(nchains-i+1)/$nchains" nothing _id = + uuid + end Distributed.@sync begin if progress @@ -472,17 +501,21 @@ function mcmcsample( Random.seed!(_rng, seeds[chainidx]) # Sample a chain and save it to the vector. + child_progressname = "Chain $chainidx/$nchains" child_progress = if progress == false false else - nothing + ProgressLogging.Progress( + uuids[chainidx]; name=child_progressname + ) end - @ifwithprogresslogger progress chains[chainidx] = StatsBase.sample( + chains[chainidx] = StatsBase.sample( _rng, _model, _sampler, N; progress=child_progress, + progressname=child_progressname, initial_params=if initial_params === nothing nothing else @@ -496,6 +529,8 @@ function mcmcsample( kwargs..., ) + ProgressLogging.@logprogress name = child_progressname "done" _id = uuids[chainidx] + # Update the progress bar. progress && put!(channel, true) end From 3fdcbcc4f6b2e7c306728b61c9314f0df6c6f1f1 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 27 Jun 2025 12:46:22 +0100 Subject: [PATCH 3/5] destroy per-chain progress bars if an error occurs --- src/sample.jl | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/sample.jl b/src/sample.jl index d90a470a..719ea7de 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -454,9 +454,6 @@ function mcmcsample( # Start the progress bars (but in reverse order, because # ProgressLogging prints from the bottom up, and we want chain 1 to # show up at the top) - # TODO: This has an unintended effect that the 'time' field in the - # progress bar shows the total time since the beginning of sampling, - # even if the specific chain doesn't start sampling until later on. for (i, uuid) in enumerate(reverse(uuids)) ProgressLogging.@logprogress name = "Chain $(nchains-i+1)/$nchains" nothing _id = uuid @@ -536,8 +533,13 @@ function mcmcsample( end end finally - # Stop updating the progress bar. + # Stop updating the progress bars (either if sampling is done, or if + # an error occurs). progress && put!(channel, false) + for (i, uuid) in enumerate(reverse(uuids)) + ProgressLogging.@logprogress name = "Chain $(nchains-i+1)/$nchains" "done" _id = + uuid + end end end end From ab3cf264c093e4e0120f2d27e5f2b52dfce75926 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 27 Jun 2025 12:46:54 +0100 Subject: [PATCH 4/5] add a todo --- src/sample.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/sample.jl b/src/sample.jl index 719ea7de..10da7e6f 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -153,6 +153,7 @@ function mcmcsample( # Ugly hacky code to reset the start timer if called from a multi-chain # sampling process + # TODO: How to make this better? if progress isa ProgressLogging.Progress try bartrees = Logging.current_logger().loggers[1].logger.bartrees From dc43291eb46b1ad3de8b5312a4346c5d2e03ef86 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 27 Jun 2025 14:58:18 +0100 Subject: [PATCH 5/5] report proportion of total samples instead --- src/sample.jl | 74 +++++++++++++++++++++++---------------------------- 1 file changed, 34 insertions(+), 40 deletions(-) diff --git a/src/sample.jl b/src/sample.jl index 10da7e6f..259f9df4 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -121,6 +121,7 @@ function mcmcsample( thinning=1, chain_type::Type=Any, initial_state=nothing, + _progress_channel=nothing, kwargs..., ) # Check the number of requested samples. @@ -146,10 +147,8 @@ function mcmcsample( @ifwithprogresslogger progress name = progressname begin # Determine threshold values for progress logging # (one update per 0.5% of progress) - if !(progress == false) - threshold = Ntotal ÷ 200 - next_update = threshold - end + threshold = Ntotal ÷ 200 + next_update = threshold # Ugly hacky code to reset the start timer if called from a multi-chain # sampling process @@ -180,10 +179,10 @@ function mcmcsample( # Update the progress bar. itotal = 1 - if !(progress == false) && itotal >= next_update + if itotal >= next_update if progress == true ProgressLogging.@logprogress itotal / Ntotal - else + elseif progress isa ProgressLogging.Progress ProgressLogging.@logprogress name = progressname itotal / Ntotal _id = progress.id end @@ -200,10 +199,11 @@ function mcmcsample( end # Update the progress bar. - if !(progress == false) && (itotal += 1) >= next_update + _progress_channel !== nothing && put!(_progress_channel, true) + if (itotal += 1) >= next_update if progress == true ProgressLogging.@logprogress itotal / Ntotal - else + elseif progress isa ProgressLogging.Progress ProgressLogging.@logprogress name = progressname itotal / Ntotal _id = progress.id end @@ -230,10 +230,10 @@ function mcmcsample( end # Update progress bar. - if !(progress == false) && (itotal += 1) >= next_update + if (itotal += 1) >= next_update if progress == true ProgressLogging.@logprogress itotal / Ntotal - else + elseif progress isa ProgressLogging.Progress ProgressLogging.@logprogress name = progressname itotal / Ntotal _id = progress.id end @@ -256,10 +256,11 @@ function mcmcsample( samples = save!!(samples, sample, i, model, sampler, N; kwargs...) # Update the progress bar. - if !(progress == false) && (itotal += 1) >= next_update + _progress_channel !== nothing && put!(_progress_channel, true) + if (itotal += 1) >= next_update if progress == true ProgressLogging.@logprogress itotal / Ntotal - else + elseif progress isa ProgressLogging.Progress ProgressLogging.@logprogress name = progressname itotal / Ntotal _id = progress.id end @@ -451,14 +452,14 @@ function mcmcsample( channel = Channel{Bool}(length(interval)) end # Generate nchains independent UUIDs for each progress bar - uuids = [uuid4() for _ in 1:nchains] + # uuids = [uuid4() for _ in 1:nchains] # Start the progress bars (but in reverse order, because # ProgressLogging prints from the bottom up, and we want chain 1 to # show up at the top) - for (i, uuid) in enumerate(reverse(uuids)) - ProgressLogging.@logprogress name = "Chain $(nchains-i+1)/$nchains" nothing _id = - uuid - end + # for (i, uuid) in enumerate(reverse(uuids)) + # ProgressLogging.@logprogress name = "Chain $(nchains-i+1)/$nchains" nothing _id = + # uuid + # end Distributed.@sync begin if progress @@ -466,15 +467,16 @@ function mcmcsample( Distributed.@async begin # Determine threshold values for progress logging # (one update per 0.5% of progress) - threshold = nchains ÷ 200 - nextprogresschains = threshold + nprogupdates = nchains * N + threshold = nprogupdates ÷ 200 + counter = 0 + next_update = threshold - progresschains = 0 while take!(channel) - progresschains += 1 - if progresschains >= nextprogresschains - ProgressLogging.@logprogress progresschains / nchains - nextprogresschains = progresschains + threshold + counter += 1 + if counter >= next_update + ProgressLogging.@logprogress counter / nprogupdates + next_update = next_update + threshold end end end @@ -499,21 +501,12 @@ function mcmcsample( Random.seed!(_rng, seeds[chainidx]) # Sample a chain and save it to the vector. - child_progressname = "Chain $chainidx/$nchains" - child_progress = if progress == false - false - else - ProgressLogging.Progress( - uuids[chainidx]; name=child_progressname - ) - end chains[chainidx] = StatsBase.sample( _rng, _model, _sampler, N; - progress=child_progress, - progressname=child_progressname, + progress=false, initial_params=if initial_params === nothing nothing else @@ -524,11 +517,12 @@ function mcmcsample( else initial_state[chainidx] end, + _progress_channel=channel, kwargs..., ) - ProgressLogging.@logprogress name = child_progressname "done" _id = uuids[chainidx] - + # ProgressLogging.@logprogress name = child_progressname "done" _id = uuids[chainidx] + # # Update the progress bar. progress && put!(channel, true) end @@ -537,10 +531,10 @@ function mcmcsample( # Stop updating the progress bars (either if sampling is done, or if # an error occurs). progress && put!(channel, false) - for (i, uuid) in enumerate(reverse(uuids)) - ProgressLogging.@logprogress name = "Chain $(nchains-i+1)/$nchains" "done" _id = - uuid - end + # for (i, uuid) in enumerate(reverse(uuids)) + # ProgressLogging.@logprogress name = "Chain $(nchains-i+1)/$nchains" "done" _id = + # uuid + # end end end end