Skip to content

Commit e08f548

Browse files
committed
Fix deepcopying
1 parent e4cb590 commit e08f548

File tree

1 file changed

+22
-12
lines changed

1 file changed

+22
-12
lines changed

src/mcmc/abstractmcmc.jl

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -77,18 +77,23 @@ function AbstractMCMC.sample(
7777
)
7878
# LDF needs to be set with SamplingContext, or else samplers cannot
7979
# overload the tilde-pipeline.
80-
if !(ldf.context isa SamplingContext)
81-
ldf = LogDensityFunction(
82-
ldf.model, ldf.varinfo, SamplingContext(rng, spl); adtype=ldf.adtype
83-
)
80+
ctx = if ldf.context isa SamplingContext
81+
ldf.context
82+
else
83+
SamplingContext(rng, spl)
8484
end
8585
# Note that, in particular, sampling can mutate the variables in the LDF's
8686
# varinfo (because it ultimately ends up calling `evaluate!!(ldf.model,
8787
# ldf.varinfo)`. Furthermore, the first call to `AbstractMCMC.step` assumes
8888
# that the parameters in the LDF are the initial parameters. So, we need to
89-
# deepcopy the LDF here to ensure that sample(rng, ldf, ...) is
89+
# deepcopy the varinfo here to ensure that sample(rng, ldf, ...) is
9090
# reproducible.
91-
ldf = deepcopy(ldf)
91+
vi = deepcopy(ldf.varinfo)
92+
# TODO(penelopeysm): Unsure if model needes to be deepcopied as well.
93+
# Note that deepcopying the entire LDF is risky as it may include e.g.
94+
# Mooncake or Enzyme types that don't deepcopy well. I ran into an issue
95+
# where Mooncake errored when deepcopying an LDF.
96+
ldf = LogDensityFunction(ldf.model, vi, ctx; adtype=ldf.adtype)
9297
# TODO: Right now, only generic checks are run. We could in principle
9398
# specialise this to check for e.g. discrete variables with HMC
9499
check_model && DynamicPPL.check_model(ldf.model; error_on_failure=true)
@@ -156,18 +161,23 @@ function AbstractMCMC.sample(
156161
)
157162
# LDF needs to be set with SamplingContext, or else samplers cannot
158163
# overload the tilde-pipeline.
159-
if !(ldf.context isa SamplingContext)
160-
ldf = LogDensityFunction(
161-
ldf.model, ldf.varinfo, SamplingContext(rng, spl); adtype=ldf.adtype
162-
)
164+
ctx = if ldf.context isa SamplingContext
165+
ldf.context
166+
else
167+
SamplingContext(rng, spl)
163168
end
164169
# Note that, in particular, sampling can mutate the variables in the LDF's
165170
# varinfo (because it ultimately ends up calling `evaluate!!(ldf.model,
166171
# ldf.varinfo)`. Furthermore, the first call to `AbstractMCMC.step` assumes
167172
# that the parameters in the LDF are the initial parameters. So, we need to
168-
# deepcopy the LDF here to ensure that sample(rng, ldf, ...) is
173+
# deepcopy the varinfo here to ensure that sample(rng, ldf, ...) is
169174
# reproducible.
170-
ldf = deepcopy(ldf)
175+
vi = deepcopy(ldf.varinfo)
176+
# TODO(penelopeysm): Unsure if model needes to be deepcopied as well.
177+
# Note that deepcopying the entire LDF is risky as it may include e.g.
178+
# Mooncake or Enzyme types that don't deepcopy well. I ran into an issue
179+
# where Mooncake errored when deepcopying an LDF.
180+
ldf = LogDensityFunction(ldf.model, vi, ctx; adtype=ldf.adtype)
171181
# TODO: Right now, only generic checks are run. We could in principle
172182
# specialise this to check for e.g. discrete variables with HMC
173183
check_model && DynamicPPL.check_model(ldf.model; error_on_failure=true)

0 commit comments

Comments
 (0)