diff --git a/Project.toml b/Project.toml index eef2aaf5..66ada2dd 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" keywords = ["markov chain monte carlo", "probabilistic programming"] license = "MIT" desc = "A lightweight interface for common MCMC methods." -version = "5.6.3" +version = "5.7.0" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" @@ -18,6 +18,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed" Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999" +UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" [compat] BangBang = "0.3.19, 0.4" @@ -29,6 +30,7 @@ ProgressLogging = "0.1" StatsBase = "0.32, 0.33, 0.34" TerminalLoggers = "0.1" Transducers = "0.4.30" +UUIDs = "<0.0.1, 1" julia = "1.6" [extras] diff --git a/docs/src/api.md b/docs/src/api.md index da92fd37..8978e03a 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -68,7 +68,7 @@ AbstractMCMC.MCMCSerial ## Common keyword arguments Common keyword arguments for regular and parallel sampling are: -- `progress` (default: `AbstractMCMC.PROGRESS[]` which is `true` initially): toggles progress logging +- `progress` (default: `AbstractMCMC.PROGRESS[]` which is `true` initially): toggles progress logging. See the section on [Progress logging](#progress-logging) below for more details. 
- `chain_type` (default: `Any`): determines the type of the returned chain - `callback` (default: `nothing`): if `callback !== nothing`, then `callback(rng, model, sampler, sample, iteration)` is called after every sampling step, @@ -90,12 +90,34 @@ However, multiple packages such as [EllipticalSliceSampling.jl](https://github.c To ensure that sampling multiple chains "just works" when sampling of a single chain is implemented, [we decided to support `initial_params` in the default implementations of the ensemble methods](https://github.com/TuringLang/AbstractMCMC.jl/pull/94): - `initial_params` (default: `nothing`): if `initial_params isa AbstractArray`, then the `i`th element of `initial_params` is used as initial parameters of the `i`th chain. If one wants to use the same initial parameters `x` for every chain, one can specify e.g. `initial_params = FillArrays.Fill(x, N)`. -Progress logging can be enabled and disabled globally with `AbstractMCMC.setprogress!(progress)`. +## Progress logging + +The default value for the `progress` keyword argument is `AbstractMCMC.PROGRESS[]`, which is `true` unless it has been changed with `AbstractMCMC.setprogress!`. +For example, `AbstractMCMC.setprogress!(false)` will disable all progress logging. ```@docs AbstractMCMC.setprogress! ``` +For single-chain sampling (i.e., `sample([rng,] model, sampler, N)`), as well as multiple-chain sampling with `MCMCSerial`, the `progress` keyword argument should be a `Bool`. + +For multiple-chain sampling using `MCMCThreads`, there are several more fine-grained options: + +- `:perchain`: create one progress bar per chain being sampled +- `:overall`: create one progress bar for the overall sampling process, which tracks the percentage of samples that have been drawn across all chains +- `:none`: do not create any progress bar +- `true` (the default): use `:perchain` for 10 or fewer chains, and `:overall` for more than 10 chains +- `false`: same as `:none`, i.e. 
no progress bar + +The threshold of 10 chains can be changed using `AbstractMCMC.setmaxchainsprogress!(N)`, which will cause `MCMCThreads` to use `:perchain` for `N` or fewer chains, and `:overall` for more than `N` chains. +Thus, for example, if you _always_ want to use `:overall`, you can call `AbstractMCMC.setmaxchainsprogress!(0)`. + +Multiple-chain sampling using `MCMCDistributed` behaves the same as `MCMCThreads`, except that `:perchain` is not yet implemented. +So, `true` always corresponds to `:overall`, and `false` corresponds to `:none`. + +!!! warning "Do not override the `progress` keyword argument" + If you are implementing your own methods for `sample(...)`, you should make sure not to override the `progress` keyword argument if you want progress logging in multi-chain sampling to work correctly, as the multi-chain `sample()` call makes sure to specifically pass custom values of `progress` to the single-chain calls. + ## Chains The `chain_type` keyword argument allows to set the type of the returned chain. A common diff --git a/src/AbstractMCMC.jl b/src/AbstractMCMC.jl index 8c1d5610..0e046927 100644 --- a/src/AbstractMCMC.jl +++ b/src/AbstractMCMC.jl @@ -13,6 +13,7 @@ using FillArrays: FillArrays using Distributed: Distributed using Logging: Logging using Random: Random +using UUIDs: UUIDs # Reexport sample using StatsBase: sample diff --git a/src/logging.jl b/src/logging.jl index 04c41187..9eebc0c0 100644 --- a/src/logging.jl +++ b/src/logging.jl @@ -1,19 +1,109 @@ -# avoid creating a progress bar with @withprogress if progress logging is disabled -# and add a custom progress logger if the current logger does not seem to be able to handle -# progress logs -macro ifwithprogresslogger(progress, exprs...) +""" + AbstractProgressKwarg + +Abstract type representing the values that the `progress` keyword argument can +internally take for single-chain sampling. 
+""" +abstract type AbstractProgressKwarg end + +""" + CreateNewProgressBar + +Create a new logger for progress logging. +""" +struct CreateNewProgressBar{S<:AbstractString} <: AbstractProgressKwarg + name::S + uuid::UUIDs.UUID + function CreateNewProgressBar(name::AbstractString) + return new{typeof(name)}(name, UUIDs.uuid4()) + end +end +function init_progress(p::CreateNewProgressBar) + ProgressLogging.@logprogress p.name nothing _id = p.uuid +end +function update_progress(p::CreateNewProgressBar, progress_frac) + ProgressLogging.@logprogress p.name progress_frac _id = p.uuid +end +function finish_progress(p::CreateNewProgressBar) + ProgressLogging.@logprogress p.name "done" _id = p.uuid +end + +""" + NoLogging + +Do not log progress at all. +""" +struct NoLogging <: AbstractProgressKwarg end +init_progress(::NoLogging) = nothing +update_progress(::NoLogging, ::Any) = nothing +finish_progress(::NoLogging) = nothing + +""" + ExistingProgressBar + +Use an existing progress bar to log progress. This is used for tracking +progress in a progress bar that has been previously generated elsewhere, +specifically, when `sample(..., MCMCThreads(), ...; progress=:perchain)` is +called. In this case we can use `@logprogress name progress_frac _id = uuid` to +log progress. +""" +struct ExistingProgressBar{S<:AbstractString} <: AbstractProgressKwarg + name::S + uuid::UUIDs.UUID +end +function init_progress(p::ExistingProgressBar) + # Hacky code to reset the start timer if called from a multi-chain sampling + # process. We need this because the progress bar is constructed in the + # multi-chain method, i.e. if we don't do this the progress bar shows the + # time elapsed since _all_ sampling began, not since the current chain + # started. 
+ try + bartrees = Logging.current_logger().loggers[1].logger.bartrees + bar = TerminalLoggers.findbar(bartrees, p.uuid).data + bar.tfirst = time() + catch + end + ProgressLogging.@logprogress p.name nothing _id = p.uuid +end +function update_progress(p::ExistingProgressBar, progress_frac) + ProgressLogging.@logprogress p.name progress_frac _id = p.uuid +end +function finish_progress(p::ExistingProgressBar) + ProgressLogging.@logprogress p.name "done" _id = p.uuid +end + +""" + ChannelProgress + +Use a `Channel` to log progress. This is used for 'reporting' progress back +to the main thread or worker when using `progress=:overall` with MCMCThreads or +MCMCDistributed. + +n_updates is the number of updates that each child thread is expected to report +back to the main thread. +""" +struct ChannelProgress{T<:Union{Channel{Bool},Distributed.RemoteChannel{Channel{Bool}}}} <: + AbstractProgressKwarg + channel::T + n_updates::Int +end +init_progress(::ChannelProgress) = nothing +update_progress(p::ChannelProgress, ::Any) = put!(p.channel, true) +# Note: We don't want to `put!(p.channel, false)`, because that would stop the +# channel from being used for further updates e.g. from other chains. +finish_progress(::ChannelProgress) = nothing + +# Add a custom progress logger if the current logger does not seem to be able to handle +# progress logs. +macro maybewithricherlogger(expr) return esc( quote - if $progress - if $hasprogresslevel($Logging.current_logger()) - $ProgressLogging.@withprogress $(exprs...) - else - $with_progresslogger($Base.@__MODULE__, $Logging.current_logger()) do - $ProgressLogging.@withprogress $(exprs...) 
- end + if !($hasprogresslevel($Logging.current_logger())) + $with_progresslogger($Base.@__MODULE__, $Logging.current_logger()) do + $(expr) end else - $(exprs[end]) + $(expr) end end, ) @@ -39,8 +129,7 @@ end function progresslogger() # detect if code is running under IJulia since TerminalLogger does not work with IJulia # https://github.com/JuliaLang/IJulia.jl#detecting-that-code-is-running-under-ijulia - if (Sys.iswindows() && VERSION < v"1.5.3") || - (isdefined(Main, :IJulia) && Main.IJulia.inited) + if (isdefined(Main, :IJulia) && Main.IJulia.inited) return ConsoleProgressMonitor.ProgressLogger() else return TerminalLoggers.TerminalLogger() diff --git a/src/sample.jl b/src/sample.jl index 01a2006a..220cf38e 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -1,5 +1,6 @@ # Default implementations of `sample`. const PROGRESS = Ref(true) +const MAX_CHAINS_PROGRESS = Ref(10) _pluralise(n; singular="", plural="s") = n == 1 ? singular : plural @@ -17,6 +18,25 @@ function setprogress!(progress::Bool; silent::Bool=false) return progress end +""" + setmaxchainsprogress!(max_chains::Int, silent::Bool=false) + +Set the maximum number of chains to display progress bars for when sampling +multiple chains at once (if progress logging is enabled). Above this limit, no +progress bars are displayed for individual chains; instead, a single progress +bar is displayed for the entire sampling process. +""" +function setmaxchainsprogress!(max_chains::Int, silent::Bool=false) + if max_chains < 0 + throw(ArgumentError("maximum number of chains must be non-negative")) + end + if !silent + @info "AbstractMCMC: maximum number of per-chain progress bars set to $max_chains" + end + MAX_CHAINS_PROGRESS[] = max_chains + return nothing +end + function StatsBase.sample( model_or_logdensity, sampler::AbstractSampler, N_or_isdone; kwargs... 
) @@ -111,7 +131,7 @@ function mcmcsample( model::AbstractModel, sampler::AbstractSampler, N::Integer; - progress=PROGRESS[], + progress::Union{Bool,<:AbstractProgressKwarg}=PROGRESS[], progressname="Sampling", callback=nothing, num_warmup::Int=0, @@ -132,6 +152,13 @@ function mcmcsample( ArgumentError("number of warm-up samples exceeds the total number of samples") ) + # Initialise progress bar + if progress === true + progress = CreateNewProgressBar(progressname) + elseif progress === false + progress = NoLogging() + end + # Determine how many samples to drop from `num_warmup` and the # main sampling process before we start saving samples. discard_from_warmup = min(num_warmup, discard_initial) @@ -141,13 +168,14 @@ function mcmcsample( start = time() local state - @ifwithprogresslogger progress name = progressname begin - # Determine threshold values for progress logging - # (one update per 0.5% of progress) - if progress - threshold = Ntotal ÷ 200 - next_update = threshold - end + @maybewithricherlogger begin + init_progress(progress) + # Determine threshold values for progress logging (by default, one + # update per 0.5% of progress, unless this has been passed in + # explicitly) + n_updates = progress isa ChannelProgress ? progress.n_updates : 200 + threshold = Ntotal / n_updates + next_update = threshold # Obtain the initial sample and state. sample, state = if num_warmup > 0 @@ -164,11 +192,11 @@ function mcmcsample( end end - # Update the progress bar. + # Start the progress bar. itotal = 1 - if progress && itotal >= next_update - ProgressLogging.@logprogress itotal / Ntotal - next_update = itotal + threshold + if itotal >= next_update + update_progress(progress, itotal / Ntotal) + next_update += threshold end # Discard initial samples. @@ -181,9 +209,10 @@ function mcmcsample( end # Update the progress bar. 
- if progress && (itotal += 1) >= next_update - ProgressLogging.@logprogress itotal / Ntotal - next_update = itotal + threshold + itotal += 1 + if itotal >= next_update + update_progress(progress, itotal / Ntotal) + next_update += threshold end end @@ -206,9 +235,10 @@ function mcmcsample( end # Update progress bar. - if progress && (itotal += 1) >= next_update - ProgressLogging.@logprogress itotal / Ntotal - next_update = itotal + threshold + itotal += 1 + if itotal >= next_update + update_progress(progress, itotal / Ntotal) + next_update += threshold end end @@ -227,11 +257,13 @@ function mcmcsample( samples = save!!(samples, sample, i, model, sampler, N; kwargs...) # Update the progress bar. - if progress && (itotal += 1) >= next_update - ProgressLogging.@logprogress itotal / Ntotal - next_update = itotal + threshold + itotal += 1 + if itotal >= next_update + update_progress(progress, itotal / Ntotal) + next_update += threshold end end + finish_progress(progress) end # Get the sample stop time. @@ -258,7 +290,7 @@ function mcmcsample( sampler::AbstractSampler, isdone; chain_type::Type=Any, - progress=PROGRESS[], + progress::Bool=PROGRESS[], progressname="Convergence sampling", callback=nothing, num_warmup=0, @@ -273,6 +305,13 @@ function mcmcsample( num_warmup >= 0 || throw(ArgumentError("number of warm-up samples must be non-negative")) + # Initialise progress bar + if progress === true + progress = CreateNewProgressBar(progressname) + elseif progress === false + progress = NoLogging() + end + # Determine how many samples to drop from `num_warmup` and the # main sampling process before we start saving samples. discard_from_warmup = min(num_warmup, discard_initial) @@ -282,7 +321,8 @@ function mcmcsample( start = time() local state - @ifwithprogresslogger progress name = progressname begin + @maybewithricherlogger begin + init_progress(progress) # Obtain the initial sample and state. 
sample, state = if num_warmup > 0 if initial_state === nothing @@ -345,6 +385,7 @@ function mcmcsample( # Increment iteration counter. i += 1 end + finish_progress(progress) end # Get the sample stop time. @@ -373,7 +414,7 @@ function mcmcsample( ::MCMCThreads, N::Integer, nchains::Integer; - progress=PROGRESS[], + progress::Union{Bool,Symbol}=PROGRESS[], progressname="Sampling ($(min(nchains, Threads.nthreads())) thread$(_pluralise(min(nchains, Threads.nthreads()))))", initial_params=nothing, initial_state=nothing, @@ -389,6 +430,18 @@ function mcmcsample( @warn "Number of chains ($nchains) is greater than number of samples per chain ($N)" end + # Determine default progress bar style. + if progress == true + progress = nchains > MAX_CHAINS_PROGRESS[] ? :overall : :perchain + elseif progress == false + progress = :none + end + progress in [:overall, :perchain, :none] || throw( + ArgumentError( + "`progress` for MCMCThreads must be `:overall`, `:perchain`, `:none`, or a boolean", + ), + ) + # Copy the random number generator, model, and sample for each thread nchunks = min(nchains, Threads.nthreads()) interval = 1:nchunks @@ -411,29 +464,65 @@ function mcmcsample( # Set up a chains vector. chains = Vector{Any}(undef, nchains) - @ifwithprogresslogger progress name = progressname begin - # Create a channel for progress logging. - if progress - channel = Channel{Bool}(length(interval)) + @maybewithricherlogger begin + if progress == :perchain + # Create a channel for each chain to report back to when it + # finishes sampling. + progress_channel = Channel{Bool}(nchunks) + # This is the 'overall' progress bar which tracks the number of + # chains that have completed. Note that this progress bar is backed + # by a channel, but it is not itself a ChannelProgress (because + # ChannelProgress doesn't come with a progress bar). + overall_progress_bar = CreateNewProgressBar(progressname) + init_progress(overall_progress_bar) + # These are the per-chain progress bars. 
We generate `nchains` + # independent UUIDs for each progress bar + child_progresses = [ + ExistingProgressBar("Chain $i/$nchains", UUIDs.uuid4()) for i in 1:nchains + ] + # Start the per-chain progress bars (but in reverse order, because + # ProgressLogging prints from the bottom up, and we want chain 1 to + # show up at the top) + for child_progress in reverse(child_progresses) + init_progress(child_progress) + end + updates_per_chain = nothing + elseif progress == :overall + # Just a single progress bar for the entire sampling, but instead + # of tracking each chain as it comes in, we track each sample as it + # comes in. This allows us to have more granular progress updates. + progress_channel = Channel{Bool}(nchains) + overall_progress_bar = CreateNewProgressBar(progressname) + # If we have many chains and many samples, we don't want to force + # each chain to report back to the main thread for each sample, as + # this would cause serious performance issues due to lock conflicts. + # In the overall progress bar we only expect 200 updates (i.e., one + # update per 0.5%). To avoid possible throttling issues we ask for + # twice the amount needed per chain, which doesn't cause a real + # performance hit. + updates_per_chain = max(1, 400 ÷ nchains) end Distributed.@sync begin - if progress - # Update the progress bar. + if progress != :none + # This task updates the progress bar Distributed.@async begin + # Total number of updates (across all chains) + Ntotal = progress == :overall ? 
nchains * updates_per_chain : nchains # Determine threshold values for progress logging # (one update per 0.5% of progress) - threshold = nchains ÷ 200 - nextprogresschains = threshold - - progresschains = 0 - while take!(channel) - progresschains += 1 - if progresschains >= nextprogresschains - ProgressLogging.@logprogress progresschains / nchains - nextprogresschains = progresschains + threshold + threshold = Ntotal / 200 + next_update = threshold + + itotal = 0 + while take!(progress_channel) + itotal += 1 + if itotal >= next_update + update_progress(overall_progress_bar, itotal / Ntotal) + next_update += threshold end end + finish_progress(overall_progress_bar) end end @@ -455,13 +544,22 @@ function mcmcsample( # Seed the chunk-specific random number generator with the pre-made seed. Random.seed!(_rng, seeds[chainidx]) + # Determine how to monitor progress for the child chains. + child_progress = if progress == :none + false + elseif progress == :overall + ChannelProgress(progress_channel, updates_per_chain) + elseif progress == :perchain + child_progresses[chainidx] # <- isa ExistingProgressBar + end + # Sample a chain and save it to the vector. chains[chainidx] = StatsBase.sample( _rng, _model, _sampler, N; - progress=false, + progress=child_progress, initial_params=if initial_params === nothing nothing else @@ -475,13 +573,33 @@ function mcmcsample( kwargs..., ) - # Update the progress bar. - progress && put!(channel, true) + # Update the progress bars. + if progress == :perchain + # Tell the 'main' progress bar that this chain is done. + put!(progress_channel, true) + # Conclude the per-chain progress bar. + finish_progress(child_progresses[chainidx]) + end + # Note that if progress == :overall, we don't need to do anything + # because progress on that bar is triggered by + # samples being obtained rather than chains being + # completed. end end finally - # Stop updating the progress bar. 
- progress && put!(channel, false) + if progress == :perchain + # Stop updating the main progress bar (either if sampling + # is done, or if an error occurs). + put!(progress_channel, false) + # Additionally stop the per-chain progress bars + for child_progress in child_progresses + finish_progress(child_progress) + end + elseif progress == :overall + # Stop updating the main progress bar (either if sampling + # is done, or if an error occurs). + put!(progress_channel, false) + end end end end @@ -498,7 +616,7 @@ function mcmcsample( ::MCMCDistributed, N::Integer, nchains::Integer; - progress=PROGRESS[], + progress::Union{Bool,Symbol}=PROGRESS[], progressname="Sampling ($(Distributed.nworkers()) process$(_pluralise(Distributed.nworkers(); plural="es")))", initial_params=nothing, initial_state=nothing, @@ -514,6 +632,19 @@ function mcmcsample( @warn "Number of chains ($nchains) is greater than number of samples per chain ($N)" end + # Determine default progress bar style. Note that for MCMCDistributed(), + # :perchain isn't implemented. + if progress == true + progress = :overall + elseif progress == false + progress = :none + end + progress in [:overall, :none] || throw( + ArgumentError( + "`progress` for MCMCDistributed must be `:overall`, `:none`, or a boolean" + ), + ) + # Ensure that initial parameters and states are `nothing` or of the correct length check_initial_params(initial_params, nchains) check_initial_state(initial_state, nchains) @@ -530,35 +661,52 @@ function mcmcsample( pool = Distributed.CachingPool(Distributed.workers()) local chains - @ifwithprogresslogger progress name = progressname begin - # Create a channel for progress logging. - if progress - channel = Distributed.RemoteChannel(() -> Channel{Bool}(Distributed.nworkers())) + @maybewithricherlogger begin + # Set up progress logging. 
+ if progress == :overall + # Just a single progress bar for the entire sampling, but instead + # of tracking each chain as it comes in, we track each sample as it + # comes in. This allows us to have more granular progress updates. + chan = Channel{Bool}(Distributed.nworkers()) + progress_channel = Distributed.RemoteChannel(() -> chan) + overall_progress_bar = CreateNewProgressBar(progressname) + init_progress(overall_progress_bar) + # See MCMCThreads method for the rationale behind updates_per_chain. + updates_per_chain = max(1, 400 ÷ nchains) + child_progresses = [ + ChannelProgress(progress_channel, updates_per_chain) for _ in 1:nchains + ] + elseif progress == :none + child_progresses = [false for _ in 1:nchains] end Distributed.@sync begin - if progress - # Update the progress bar. + if progress == :overall + # This task updates the progress bar Distributed.@async begin # Determine threshold values for progress logging # (one update per 0.5% of progress) - threshold = nchains ÷ 200 - nextprogresschains = threshold - - progresschains = 0 - while take!(channel) - progresschains += 1 - if progresschains >= nextprogresschains - ProgressLogging.@logprogress progresschains / nchains - nextprogresschains = progresschains + threshold + Ntotal = nchains * updates_per_chain + threshold = Ntotal / 200 + next_update = threshold + + itotal = 0 + while take!(progress_channel) + itotal += 1 + if itotal >= next_update + update_progress(overall_progress_bar, itotal / Ntotal) + next_update += threshold end end + finish_progress(overall_progress_bar) end end Distributed.@async begin try - function sample_chain(seed, initial_params, initial_state) + function sample_chain( + seed, initial_params, initial_state, child_progress + ) # Seed a new random number generator with the pre-made seed. 
Random.seed!(rng, seed) @@ -568,24 +716,29 @@ function mcmcsample( model, sampler, N; - progress=false, + progress=child_progress, initial_params=initial_params, initial_state=initial_state, kwargs..., ) - # Update the progress bar. - progress && put!(channel, true) - # Return the new chain. return chain end chains = Distributed.pmap( - sample_chain, pool, seeds, _initial_params, _initial_state + sample_chain, + pool, + seeds, + _initial_params, + _initial_state, + child_progresses, ) finally - # Stop updating the progress bar. - progress && put!(channel, false) + if progress == :overall + # Stop updating the main progress bar (either if sampling + # is done, or if an error occurs). + put!(progress_channel, false) + end end end end diff --git a/test/sample.jl b/test/sample.jl index e036f558..5af2aa74 100644 --- a/test/sample.jl +++ b/test/sample.jl @@ -10,8 +10,7 @@ @test length(LOGGERS) == 1 logger = first(LOGGERS) @test logger isa TeeLogger - @test logger.loggers[1].logger isa - (Sys.iswindows() && VERSION < v"1.5.3" ? ProgressLogger : TerminalLogger) + @test logger.loggers[1].logger isa TerminalLogger @test logger.loggers[2].logger === CURRENT_LOGGER @test Logging.current_logger() === CURRENT_LOGGER