Skip to content

Commit 82f02f1

Browse files
Fix indexing for chains in different threads (#154)
* Fix indexing for chains in different threads * Calculate chainidxs inside the loop * Simplify expression for `chainidx_hi` Co-authored-by: David Widmann <devmotion@users.noreply.github.com> --------- Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
1 parent 5a3b155 commit 82f02f1

File tree

2 files changed

+14
-6
lines changed

2 files changed

+14
-6
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
33
keywords = ["markov chain monte carlo", "probabilistic programming"]
44
license = "MIT"
55
desc = "A lightweight interface for common MCMC methods."
6-
version = "5.6.0"
6+
version = "5.6.1"
77

88
[deps]
99
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"

src/sample.jl

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -391,13 +391,16 @@ function mcmcsample(
391391

392392
# Copy the random number generator, model, and sample for each thread
393393
nchunks = min(nchains, Threads.nthreads())
394-
chunksize = cld(nchains, nchunks)
395394
interval = 1:nchunks
396395
# `copy` instead of `deepcopy` for RNGs: https://github.com/JuliaLang/julia/issues/42899
397396
rngs = [copy(rng) for _ in interval]
398397
models = [deepcopy(model) for _ in interval]
399398
samplers = [deepcopy(sampler) for _ in interval]
400399

400+
# If nchains/nchunks = m with remainder n, then the first n chunks will
401+
# have m + 1 chains, and the rest will have m chains.
402+
m, n = divrem(nchains, nchunks)
403+
401404
# Create a seed for each chain using the provided random number generator.
402405
seeds = rand(rng, UInt, nchains)
403406

@@ -437,12 +440,17 @@ function mcmcsample(
437440
Distributed.@async begin
438441
try
439442
Distributed.@sync for (i, _rng, _model, _sampler) in
440-
zip(1:nchunks, rngs, models, samplers)
441-
chainidxs = if i == nchunks
442-
((i - 1) * chunksize + 1):nchains
443+
zip(interval, rngs, models, samplers)
444+
if i <= n
445+
chainidx_hi = i * (m + 1)
446+
nchains_chunk = m + 1
443447
else
444-
((i - 1) * chunksize + 1):(i * chunksize)
448+
chainidx_hi = i * m + n # n * (m + 1) + (i - n) * m
449+
nchains_chunk = m
445450
end
451+
chainidx_lo = chainidx_hi - nchains_chunk + 1
452+
chainidxs = chainidx_lo:chainidx_hi
453+
446454
Threads.@spawn for chainidx in chainidxs
447455
# Seed the chunk-specific random number generator with the pre-made seed.
448456
Random.seed!(_rng, seeds[chainidx])

0 commit comments

Comments
 (0)