Progress bars when sampling multiple chains #168

Open: wants to merge 25 commits into base: main

Commits (25):
- b3434ac [wip] fix parallel sampling (penelopeysm, Jun 26, 2025)
- 15250c7 Parallel sampling with ProgressLogging (penelopeysm, Jun 27, 2025)
- 367718b destroy per-chain progress bars if an error occurs (penelopeysm, Jun 27, 2025)
- 60134f8 add a todo (penelopeysm, Jun 27, 2025)
- e0ae513 Fix implementation (penelopeysm, Jun 30, 2025)
- 6b514e4 Bump minor version (penelopeysm, Jun 30, 2025)
- bbda3c8 Add `setmaxchainsprogress!` (penelopeysm, Jun 30, 2025)
- a9e5306 Don't duplicate macro (penelopeysm, Jun 30, 2025)
- a03692d :overall works with MCMCDistributed now (penelopeysm, Jun 30, 2025)
- 838db60 Give up on :perchain for MCMCDistributed (penelopeysm, Jun 30, 2025)
- 6b59b21 Fix comments (penelopeysm, Jun 30, 2025)
- b340ebc Remove dead code (penelopeysm, Jun 30, 2025)
- 1195503 Undelete some not-actually-dead code (penelopeysm, Jun 30, 2025)
- 594483f Broaden UUIDs compat so that it works on older Julia versions (penelopeysm, Jun 30, 2025)
- 7def4b4 Explain progress logging in docs (penelopeysm, Jun 30, 2025)
- 022678e Remove dead code (penelopeysm, Jul 1, 2025)
- 5b2577f Fix channel buffering for MCMCThreads (penelopeysm, Jul 1, 2025)
- cefafb0 Attempt to use proper types for logging (penelopeysm, Jul 1, 2025)
- c6f9e78 Refactor logging, throttle per-chain updates (penelopeysm, Jul 1, 2025)
- d9c2e86 Improve comment (penelopeysm, Jul 1, 2025)
- f8a8b64 add warning (penelopeysm, Jul 1, 2025)
- 64b0bfb fix convergence sampling (penelopeysm, Jul 1, 2025)
- 27569b3 Don't use integer division (penelopeysm, Jul 1, 2025)
- 4cd647a remove extra show (penelopeysm, Jul 1, 2025)
- 9f8970d Rename withprogresslogger macro (penelopeysm, Jul 1, 2025)
4 changes: 3 additions & 1 deletion Project.toml
@@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
 keywords = ["markov chain monte carlo", "probabilistic programming"]
 license = "MIT"
 desc = "A lightweight interface for common MCMC methods."
-version = "5.6.3"
+version = "5.7.0"
 
 [deps]
 BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
@@ -18,6 +18,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed"
 Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999"
+UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
 
 [compat]
 BangBang = "0.3.19, 0.4"
@@ -29,6 +30,7 @@ ProgressLogging = "0.1"
 StatsBase = "0.32, 0.33, 0.34"
 TerminalLoggers = "0.1"
 Transducers = "0.4.30"
+UUIDs = "<0.0.1, 1"
 julia = "1.6"
 
 [extras]
26 changes: 24 additions & 2 deletions docs/src/api.md
@@ -68,7 +68,7 @@ AbstractMCMC.MCMCSerial
 ## Common keyword arguments
 
 Common keyword arguments for regular and parallel sampling are:
-- `progress` (default: `AbstractMCMC.PROGRESS[]` which is `true` initially): toggles progress logging
+- `progress` (default: `AbstractMCMC.PROGRESS[]` which is `true` initially): toggles progress logging. See the section on [Progress logging](#progress-logging) below for more details.
 - `chain_type` (default: `Any`): determines the type of the returned chain
 - `callback` (default: `nothing`): if `callback !== nothing`, then
   `callback(rng, model, sampler, sample, iteration)` is called after every sampling step,
@@ -90,12 +90,34 @@ However, multiple packages such as [EllipticalSliceSampling.jl](https://github.c
 To ensure that sampling multiple chains "just works" when sampling of a single chain is implemented, [we decided to support `initial_params` in the default implementations of the ensemble methods](https://github.com/TuringLang/AbstractMCMC.jl/pull/94):
 - `initial_params` (default: `nothing`): if `initial_params isa AbstractArray`, then the `i`th element of `initial_params` is used as initial parameters of the `i`th chain. If one wants to use the same initial parameters `x` for every chain, one can specify e.g. `initial_params = FillArrays.Fill(x, N)`.
 
-Progress logging can be enabled and disabled globally with `AbstractMCMC.setprogress!(progress)`.
+## Progress logging
+
+The default value for the `progress` keyword argument is `AbstractMCMC.PROGRESS[]`, which is always set to `true` unless modified with `AbstractMCMC.setprogress!`.
+For example, `setprogress!(false)` will disable all progress logging.
 
 ```@docs
 AbstractMCMC.setprogress!
 ```
 
+For single-chain sampling (i.e., `sample([rng,] model, sampler, N)`), as well as multiple-chain sampling with `MCMCSerial`, the `progress` keyword argument should be a `Bool`.
+
+For multiple-chain sampling using `MCMCThreads`, there are several, more detailed, options:
+
+- `:perchain`: create one progress bar per chain being sampled
+- `:overall`: create one progress bar for the overall sampling process, which tracks the percentage of samples that have been sampled across all chains
+- `:none`: do not create any progress bar
+- `true` (the default): use `perchain` for 10 or fewer chains, and `overall` for more than 10 chains
+- `false`: same as `none`, i.e. no progress bar
+
+The threshold of 10 chains can be changed using `AbstractMCMC.setmaxchainsprogress!(N)`, which will cause `MCMCThreads` to use `:perchain` for `N` or fewer chains, and `:overall` for more than `N` chains.
+Thus, for example, if you _always_ want to use `:overall`, you can call `AbstractMCMC.setmaxchainsprogress!(0)`.
+
+Multiple-chain sampling using `MCMCDistributed` behaves the same as `MCMCThreads`, except that `:perchain` is not (yet?) implemented.
+So, `true` always corresponds to `:overall`, and `false` corresponds to `:none`.
+
+!!! warning "Do not override the `progress` keyword argument"
+    If you are implementing your own methods for `sample(...)`, you should make sure to not override the `progress` keyword argument if you want progress logging in multi-chain sampling to work correctly, as the multi-chain `sample()` call makes sure to specifically pass custom values of `progress` to the single-chain calls.
+
 ## Chains
 
 The `chain_type` keyword argument allows to set the type of the returned chain. A common
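The selection rules that the `docs/src/api.md` change documents for `MCMCThreads` can be sketched as a small pure function. This is a minimal illustration under stated assumptions, not the PR's implementation: `resolve_progress` and its `max_chains` keyword are hypothetical names, with `max_chains` standing in for the threshold that `AbstractMCMC.setmaxchainsprogress!` is documented to change.

```julia
# Hypothetical helper (not part of AbstractMCMC): maps a user-supplied
# `progress` value to one of the three documented modes.
function resolve_progress(
    progress::Union{Bool,Symbol}, nchains::Integer; max_chains::Integer=10
)
    if progress === true
        # Documented default: per-chain bars up to the threshold, one overall bar beyond it.
        return nchains <= max_chains ? :perchain : :overall
    elseif progress === false
        # Documented as equivalent to `:none`.
        return :none
    elseif progress in (:perchain, :overall, :none)
        # Explicit symbols are passed through unchanged.
        return progress
    else
        throw(ArgumentError("unsupported `progress` value: $(repr(progress))"))
    end
end

resolve_progress(true, 4)                # :perchain
resolve_progress(true, 25)               # :overall
resolve_progress(true, 4; max_chains=0)  # :overall
```

With `max_chains=0` the sketch always picks `:overall`, which mirrors the documented effect of `setmaxchainsprogress!(0)`. It does not model `MCMCDistributed`, where the docs state that `:perchain` is not implemented.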
1 change: 1 addition & 0 deletions src/AbstractMCMC.jl
@@ -13,6 +13,7 @@ using FillArrays: FillArrays
 using Distributed: Distributed
 using Logging: Logging
 using Random: Random
+using UUIDs: UUIDs
 
 # Reexport sample
 using StatsBase: sample
117 changes: 103 additions & 14 deletions src/logging.jl
@@ -1,19 +1,109 @@
-# avoid creating a progress bar with @withprogress if progress logging is disabled
-# and add a custom progress logger if the current logger does not seem to be able to handle
-# progress logs
-macro ifwithprogresslogger(progress, exprs...)
+"""
+    AbstractProgressKwarg
+
+Abstract type representing the values that the `progress` keyword argument can
+internally take for single-chain sampling.
+"""
+abstract type AbstractProgressKwarg end
+
+"""
+    CreateNewProgressBar
+
+Create a new logger for progress logging.
+"""
+struct CreateNewProgressBar{S<:AbstractString} <: AbstractProgressKwarg
+    name::S
+    uuid::UUIDs.UUID
+    function CreateNewProgressBar(name::AbstractString)
+        return new{typeof(name)}(name, UUIDs.uuid4())
+    end
+end
+function init_progress(p::CreateNewProgressBar)
+    ProgressLogging.@logprogress p.name nothing _id = p.uuid
+end
+function update_progress(p::CreateNewProgressBar, progress_frac)
+    ProgressLogging.@logprogress p.name progress_frac _id = p.uuid
+end
+function finish_progress(p::CreateNewProgressBar)
+    ProgressLogging.@logprogress p.name "done" _id = p.uuid
+end
+
+"""
+    NoLogging
+
+Do not log progress at all.
+"""
+struct NoLogging <: AbstractProgressKwarg end
+init_progress(::NoLogging) = nothing
+update_progress(::NoLogging, ::Any) = nothing
+finish_progress(::NoLogging) = nothing
+
+"""
+    ExistingProgressBar
+
+Use an existing progress bar to log progress. This is used for tracking
+progress in a progress bar that has been previously generated elsewhere,
+specifically, when `sample(..., MCMCThreads(), ...; progress=:perchain)` is
+called. In this case we can use `@logprogress name progress_frac _id = uuid` to
+log progress.
+"""
+struct ExistingProgressBar{S<:AbstractString} <: AbstractProgressKwarg
+    name::S
+    uuid::UUIDs.UUID
+end
+function init_progress(p::ExistingProgressBar)
+    # Hacky code to reset the start timer if called from a multi-chain sampling
+    # process. We need this because the progress bar is constructed in the
+    # multi-chain method, i.e. if we don't do this the progress bar shows the
+    # time elapsed since _all_ sampling began, not since the current chain
+    # started.
+    try
+        bartrees = Logging.current_logger().loggers[1].logger.bartrees
+        bar = TerminalLoggers.findbar(bartrees, p.uuid).data
+        bar.tfirst = time()
+    catch
+    end
+    ProgressLogging.@logprogress p.name nothing _id = p.uuid
+end
+function update_progress(p::ExistingProgressBar, progress_frac)
+    ProgressLogging.@logprogress p.name progress_frac _id = p.uuid
+end
+function finish_progress(p::ExistingProgressBar)
+    ProgressLogging.@logprogress p.name "done" _id = p.uuid
+end
+
+"""
+    ChannelProgress
+
+Use a `Channel` to log progress. This is used for 'reporting' progress back
+to the main thread or worker when using `progress=:overall` with MCMCThreads or
+MCMCDistributed.
+
+n_updates is the number of updates that each child thread is expected to report
+back to the main thread.
+"""
+struct ChannelProgress{T<:Union{Channel{Bool},Distributed.RemoteChannel{Channel{Bool}}}} <:
+       AbstractProgressKwarg
+    channel::T
+    n_updates::Int
+end
+init_progress(::ChannelProgress) = nothing
+update_progress(p::ChannelProgress, ::Any) = put!(p.channel, true)
+# Note: We don't want to `put!(p.channel, false)`, because that would stop the
+# channel from being used for further updates e.g. from other chains.
+finish_progress(::ChannelProgress) = nothing
+
+# Add a custom progress logger if the current logger does not seem to be able to handle
+# progress logs.
+macro maybewithricherlogger(expr)
     return esc(
         quote
-            if $progress
-                if $hasprogresslevel($Logging.current_logger())
-                    $ProgressLogging.@withprogress $(exprs...)
-                else
-                    $with_progresslogger($Base.@__MODULE__, $Logging.current_logger()) do
-                        $ProgressLogging.@withprogress $(exprs...)
-                    end
+            if !($hasprogresslevel($Logging.current_logger()))
+                $with_progresslogger($Base.@__MODULE__, $Logging.current_logger()) do
+                    $(expr)
                 end
             else
-                $(exprs[end])
+                $(expr)
             end
         end,
     )
@@ -39,8 +129,7 @@ end
 function progresslogger()
     # detect if code is running under IJulia since TerminalLogger does not work with IJulia
     # https://github.com/JuliaLang/IJulia.jl#detecting-that-code-is-running-under-ijulia
-    if (Sys.iswindows() && VERSION < v"1.5.3") ||
-        (isdefined(Main, :IJulia) && Main.IJulia.inited)
+    if (isdefined(Main, :IJulia) && Main.IJulia.inited)
         return ConsoleProgressMonitor.ProgressLogger()
     else
         return TerminalLoggers.TerminalLogger()
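The `ChannelProgress` type in the `src/logging.jl` diff above reports progress by calling `put!(p.channel, true)`, and its docstring says each child is expected to send `n_updates` updates. The pattern can be sketched with Base Julia alone. Everything named here is an assumption for illustration: `overall_progress_demo`, the throttling arithmetic, and the consumer loop are hypothetical, and the use of `false` as a stop signal is inferred from the comment on `finish_progress` rather than copied from the PR.

```julia
# Sketch only: hypothetical names, not AbstractMCMC's implementation.
# Each "chain" reports exactly `n_updates` times through a shared channel;
# a single consumer turns those reports into an overall progress fraction.
function overall_progress_demo(nchains::Int, nsteps::Int; n_updates::Int=10)
    total_updates = nchains * n_updates
    channel = Channel{Bool}(total_updates)
    fractions = Float64[]

    # Consumer: every `true` is one progress update; `false` ends the loop.
    consumer = @async begin
        received = 0
        while take!(channel)
            received += 1
            push!(fractions, received / total_updates)
        end
    end

    # Producers: one task per chain, throttled to `n_updates` reports in total.
    @sync for _ in 1:nchains
        Threads.@spawn begin
            reported = 0
            for step in 1:nsteps
                # (a real sampler would take one MCMC step here)
                due = floor(Int, step * n_updates / nsteps)
                while reported < due
                    put!(channel, true)
                    reported += 1
                end
            end
        end
    end

    # Only the coordinating task sends `false`, after all chains have finished.
    put!(channel, false)
    wait(consumer)
    return fractions
end

fractions = overall_progress_demo(3, 50; n_updates=5)
```

Because each chain sends a fixed number of updates, channel traffic depends on `nchains * n_updates` rather than on the number of sampling steps. In this sketch the producers never send `false` themselves, so one chain finishing cannot stop updates from the others.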