@@ -244,7 +244,7 @@ getlogevidence(transitions, sampler, state) = missing
244
244
# This is type piracy (at least for SampleFromPrior).
245
245
function AbstractMCMC. bundle_samples (
246
246
ts:: Vector{<:Union{AbstractTransition,AbstractVarInfo}} ,
247
- model :: AbstractModel ,
247
+ model_or_ldf :: Union{DynamicPPL.Model,DynamicPPL.LogDensityFunction} ,
248
248
spl:: Union{Sampler{<:InferenceAlgorithm},SampleFromPrior,RepeatSampler} ,
249
249
state,
250
250
chain_type:: Type{MCMCChains.Chains} ;
@@ -256,6 +256,11 @@ function AbstractMCMC.bundle_samples(
256
256
thinning= 1 ,
257
257
kwargs... ,
258
258
)
259
+ model = if model_or_ldf isa DynamicPPL. LogDensityFunction
260
+ model_or_ldf. model
261
+ else
262
+ model_or_ldf
263
+ end
259
264
# Convert transitions to array format.
260
265
# Also retrieve the variable names.
261
266
varnames, vals = _params_to_array (model, ts)
@@ -307,12 +312,17 @@ end
307
312
# This is type piracy (for SampleFromPrior).
308
313
function AbstractMCMC. bundle_samples (
309
314
ts:: Vector{<:Union{AbstractTransition,AbstractVarInfo}} ,
310
- model :: AbstractModel ,
315
+ model_or_ldf :: Union{DynamicPPL.Model,DynamicPPL.LogDensityFunction} ,
311
316
spl:: Union{Sampler{<:InferenceAlgorithm},SampleFromPrior,RepeatSampler} ,
312
317
state,
313
318
chain_type:: Type{Vector{NamedTuple}} ;
314
319
kwargs... ,
315
320
)
321
+ model = if model_or_ldf isa DynamicPPL. LogDensityFunction
322
+ model_or_ldf. model
323
+ else
324
+ model_or_ldf
325
+ end
316
326
return map (ts) do t
317
327
# Construct a dictionary of pairs `vn => value`.
318
328
params = OrderedDict (getparams (model, t))
0 commit comments