@@ -79,6 +79,20 @@ function AbstractMCMC.sample(
79
79
initial_state= DynamicPPL. loadstate (resume_from),
80
80
kwargs... ,
81
81
)
82
+ # LDF needs to be set with SamplingContext, or else samplers cannot
83
+ # overload the tilde-pipeline.
84
+ if ! (ldf. context isa SamplingContext)
85
+ ldf = LogDensityFunction (
86
+ ldf. model, ldf. varinfo, SamplingContext (rng, spl); adtype= ldf. adtype
87
+ )
88
+ end
89
+ # Note that, in particular, sampling can mutate the variables in the LDF's
90
+ # varinfo (because it ultimately ends up calling `evaluate!!(ldf.model,
91
+ # ldf.varinfo)`. Furthermore, the first call to `AbstractMCMC.step` assumes
92
+ # that the parameters in the LDF are the initial parameters. So, we need to
93
+ # deepcopy the LDF here to ensure that sample(rng, ldf, ...) is
94
+ # reproducible.
95
+ ldf = deepcopy (ldf)
82
96
# TODO : Right now, only generic checks are run. We could in principle
83
97
# specialise this to check for e.g. discrete variables with HMC
84
98
check_model && DynamicPPL. check_model (ldf. model; error_on_failure= true )
@@ -144,6 +158,20 @@ function AbstractMCMC.sample(
144
158
initial_state= DynamicPPL. loadstate (resume_from),
145
159
kwargs... ,
146
160
)
161
+ # LDF needs to be set with SamplingContext, or else samplers cannot
162
+ # overload the tilde-pipeline.
163
+ if ! (ldf. context isa SamplingContext)
164
+ ldf = LogDensityFunction (
165
+ ldf. model, ldf. varinfo, SamplingContext (rng, spl); adtype= ldf. adtype
166
+ )
167
+ end
168
+ # Note that, in particular, sampling can mutate the variables in the LDF's
169
+ # varinfo (because it ultimately ends up calling `evaluate!!(ldf.model,
170
+ # ldf.varinfo)`. Furthermore, the first call to `AbstractMCMC.step` assumes
171
+ # that the parameters in the LDF are the initial parameters. So, we need to
172
+ # deepcopy the LDF here to ensure that sample(rng, ldf, ...) is
173
+ # reproducible.
174
+ ldf = deepcopy (ldf)
147
175
# TODO : Right now, only generic checks are run. We could in principle
148
176
# specialise this to check for e.g. discrete variables with HMC
149
177
check_model && DynamicPPL. check_model (ldf. model; error_on_failure= true )
0 commit comments