From ec621b7873e0110d23aeba988d164e38e5b52d6a Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 8 Jan 2025 15:04:17 +0000 Subject: [PATCH 1/3] Fix indexing for chains in different threads --- Project.toml | 2 +- src/sample.jl | 22 ++++++++++++++-------- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/Project.toml b/Project.toml index 7d7e129e..9bb2749e 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" keywords = ["markov chain monte carlo", "probabilistic programming"] license = "MIT" desc = "A lightweight interface for common MCMC methods." -version = "5.6.0" +version = "5.6.1" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" diff --git a/src/sample.jl b/src/sample.jl index 23246049..5b8e0524 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -391,13 +391,24 @@ function mcmcsample( # Copy the random number generator, model, and sample for each thread nchunks = min(nchains, Threads.nthreads()) - chunksize = cld(nchains, nchunks) interval = 1:nchunks # `copy` instead of `deepcopy` for RNGs: https://github.com/JuliaLang/julia/issues/42899 rngs = [copy(rng) for _ in interval] models = [deepcopy(model) for _ in interval] samplers = [deepcopy(sampler) for _ in interval] + # Distribute chains amongst the chunks. If nchains/nchunks = m with + # remainder n, then the first n chunks will have m + 1 chains, and the rest + # will have m chains. + m, n = divrem(nchains, nchunks) + chain_index_groups = UnitRange{Int}[] + current_index = 1 + for i in interval + nchains_this_chunk = i <= n ? m + 1 : m + push!(chain_index_groups, current_index:(current_index + nchains_this_chunk - 1)) + current_index += nchains_this_chunk + end + # Create a seed for each chain using the provided random number generator. seeds = rand(rng, UInt, nchains) @@ -436,13 +447,8 @@ function mcmcsample( Distributed.@async begin try - Distributed.@sync for (i, _rng, _model, _sampler) in - zip(1:nchunks, rngs, models, samplers) - chainidxs = if i == nchunks - ((i - 1) * chunksize + 1):nchains - else - ((i - 1) * chunksize + 1):(i * chunksize) - end + Distributed.@sync for (chainidxs, _rng, _model, _sampler) in + zip(chain_index_groups, rngs, models, samplers) Threads.@spawn for chainidx in chainidxs # Seed the chunk-specific random number generator with the pre-made seed. Random.seed!(_rng, seeds[chainidx]) From 81e0fe7d28905fd00feedfa3a26f383a44203873 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 8 Jan 2025 15:47:06 +0000 Subject: [PATCH 2/3] Calculate chainidxs inside the loop --- src/sample.jl | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/src/sample.jl b/src/sample.jl index 5b8e0524..c2f40b11 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -397,17 +397,9 @@ function mcmcsample( models = [deepcopy(model) for _ in interval] samplers = [deepcopy(sampler) for _ in interval] - # Distribute chains amongst the chunks. If nchains/nchunks = m with - # remainder n, then the first n chunks will have m + 1 chains, and the rest - # will have m chains. + # If nchains/nchunks = m with remainder n, then the first n chunks will + # have m + 1 chains, and the rest will have m chains. m, n = divrem(nchains, nchunks) - chain_index_groups = UnitRange{Int}[] - current_index = 1 - for i in interval - nchains_this_chunk = i <= n ? m + 1 : m - push!(chain_index_groups, current_index:(current_index + nchains_this_chunk - 1)) - current_index += nchains_this_chunk - end # Create a seed for each chain using the provided random number generator. seeds = rand(rng, UInt, nchains) @@ -447,8 +439,18 @@ function mcmcsample( Distributed.@async begin try - Distributed.@sync for (chainidxs, _rng, _model, _sampler) in - zip(chain_index_groups, rngs, models, samplers) + Distributed.@sync for (i, _rng, _model, _sampler) in + zip(interval, rngs, models, samplers) + if i <= n + chainidx_hi = i * (m + 1) + nchains_chunk = m + 1 + else + chainidx_hi = n * (m + 1) + (i - n) * m + nchains_chunk = m + end + chainidx_lo = chainidx_hi - nchains_chunk + 1 + chainidxs = chainidx_lo:chainidx_hi + Threads.@spawn for chainidx in chainidxs # Seed the chunk-specific random number generator with the pre-made seed. Random.seed!(_rng, seeds[chainidx]) From 66fed7b082d2a338f0dd877ec81ae9143f47b14a Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 8 Jan 2025 18:22:08 +0000 Subject: [PATCH 3/3] Simplify expression for `chainidx_hi` Co-authored-by: David Widmann --- src/sample.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sample.jl b/src/sample.jl index c2f40b11..32aca7d6 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -445,7 +445,7 @@ function mcmcsample( chainidx_hi = i * (m + 1) nchains_chunk = m + 1 else - chainidx_hi = n * (m + 1) + (i - n) * m + chainidx_hi = i * m + n # n * (m + 1) + (i - n) * m nchains_chunk = m end chainidx_lo = chainidx_hi - nchains_chunk + 1