Skip to content

Commit 85b1997

Browse files
committed
Fix reproducibility of sample(ldf, ...) with deepcopy
1 parent 71a8cf2 commit 85b1997

File tree

1 file changed

+28
-0
lines changed

1 file changed

+28
-0
lines changed

src/mcmc/abstractmcmc.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,20 @@ function AbstractMCMC.sample(
7979
initial_state=DynamicPPL.loadstate(resume_from),
8080
kwargs...,
8181
)
82+
# LDF needs to be set with SamplingContext, or else samplers cannot
83+
# overload the tilde-pipeline.
84+
if !(ldf.context isa SamplingContext)
85+
ldf = LogDensityFunction(
86+
ldf.model, ldf.varinfo, SamplingContext(rng, spl); adtype=ldf.adtype
87+
)
88+
end
89+
# Note that, in particular, sampling can mutate the variables in the LDF's
90+
# varinfo (because it ultimately ends up calling `evaluate!!(ldf.model,
91+
# ldf.varinfo)`. Furthermore, the first call to `AbstractMCMC.step` assumes
92+
# that the parameters in the LDF are the initial parameters. So, we need to
93+
# deepcopy the LDF here to ensure that sample(rng, ldf, ...) is
94+
# reproducible.
95+
ldf = deepcopy(ldf)
8296
# TODO: Right now, only generic checks are run. We could in principle
8397
# specialise this to check for e.g. discrete variables with HMC
8498
check_model && DynamicPPL.check_model(ldf.model; error_on_failure=true)
@@ -144,6 +158,20 @@ function AbstractMCMC.sample(
144158
initial_state=DynamicPPL.loadstate(resume_from),
145159
kwargs...,
146160
)
161+
# LDF needs to be set with SamplingContext, or else samplers cannot
162+
# overload the tilde-pipeline.
163+
if !(ldf.context isa SamplingContext)
164+
ldf = LogDensityFunction(
165+
ldf.model, ldf.varinfo, SamplingContext(rng, spl); adtype=ldf.adtype
166+
)
167+
end
168+
# Note that, in particular, sampling can mutate the variables in the LDF's
169+
# varinfo (because it ultimately ends up calling `evaluate!!(ldf.model,
170+
# ldf.varinfo)`. Furthermore, the first call to `AbstractMCMC.step` assumes
171+
# that the parameters in the LDF are the initial parameters. So, we need to
172+
# deepcopy the LDF here to ensure that sample(rng, ldf, ...) is
173+
# reproducible.
174+
ldf = deepcopy(ldf)
147175
# TODO: Right now, only generic checks are run. We could in principle
148176
# specialise this to check for e.g. discrete variables with HMC
149177
check_model && DynamicPPL.check_model(ldf.model; error_on_failure=true)

0 commit comments

Comments
 (0)