@@ -77,18 +77,23 @@ function AbstractMCMC.sample(
77
77
)
78
78
# LDF needs to be set with SamplingContext, or else samplers cannot
79
79
# overload the tilde-pipeline.
80
- if ! ( ldf. context isa SamplingContext)
81
- ldf = LogDensityFunction (
82
- ldf . model, ldf . varinfo, SamplingContext (rng, spl); adtype = ldf . adtype
83
- )
80
+ ctx = if ldf. context isa SamplingContext
81
+ ldf. context
82
+ else
83
+ SamplingContext (rng, spl )
84
84
end
85
85
# Note that, in particular, sampling can mutate the variables in the LDF's
86
86
# varinfo (because it ultimately ends up calling `evaluate!!(ldf.model,
87
87
# ldf.varinfo)`. Furthermore, the first call to `AbstractMCMC.step` assumes
88
88
# that the parameters in the LDF are the initial parameters. So, we need to
89
- # deepcopy the LDF here to ensure that sample(rng, ldf, ...) is
89
+ # deepcopy the varinfo here to ensure that sample(rng, ldf, ...) is
90
90
# reproducible.
91
- ldf = deepcopy (ldf)
91
+ vi = deepcopy (ldf. varinfo)
92
+ # TODO (penelopeysm): Unsure if model needes to be deepcopied as well.
93
+ # Note that deepcopying the entire LDF is risky as it may include e.g.
94
+ # Mooncake or Enzyme types that don't deepcopy well. I ran into an issue
95
+ # where Mooncake errored when deepcopying an LDF.
96
+ ldf = LogDensityFunction (ldf. model, vi, ctx; adtype= ldf. adtype)
92
97
# TODO : Right now, only generic checks are run. We could in principle
93
98
# specialise this to check for e.g. discrete variables with HMC
94
99
check_model && DynamicPPL. check_model (ldf. model; error_on_failure= true )
@@ -156,18 +161,23 @@ function AbstractMCMC.sample(
156
161
)
157
162
# LDF needs to be set with SamplingContext, or else samplers cannot
158
163
# overload the tilde-pipeline.
159
- if ! ( ldf. context isa SamplingContext)
160
- ldf = LogDensityFunction (
161
- ldf . model, ldf . varinfo, SamplingContext (rng, spl); adtype = ldf . adtype
162
- )
164
+ ctx = if ldf. context isa SamplingContext
165
+ ldf. context
166
+ else
167
+ SamplingContext (rng, spl )
163
168
end
164
169
# Note that, in particular, sampling can mutate the variables in the LDF's
165
170
# varinfo (because it ultimately ends up calling `evaluate!!(ldf.model,
166
171
# ldf.varinfo)`. Furthermore, the first call to `AbstractMCMC.step` assumes
167
172
# that the parameters in the LDF are the initial parameters. So, we need to
168
- # deepcopy the LDF here to ensure that sample(rng, ldf, ...) is
173
+ # deepcopy the varinfo here to ensure that sample(rng, ldf, ...) is
169
174
# reproducible.
170
- ldf = deepcopy (ldf)
175
+ vi = deepcopy (ldf. varinfo)
176
+ # TODO (penelopeysm): Unsure if model needes to be deepcopied as well.
177
+ # Note that deepcopying the entire LDF is risky as it may include e.g.
178
+ # Mooncake or Enzyme types that don't deepcopy well. I ran into an issue
179
+ # where Mooncake errored when deepcopying an LDF.
180
+ ldf = LogDensityFunction (ldf. model, vi, ctx; adtype= ldf. adtype)
171
181
# TODO : Right now, only generic checks are run. We could in principle
172
182
# specialise this to check for e.g. discrete variables with HMC
173
183
check_model && DynamicPPL. check_model (ldf. model; error_on_failure= true )
0 commit comments