Skip to content

Commit 27fdd96

Browse files
committed
Add new interface for AbstractMCMC.sample with LDFCompatibleAlgorithm
1 parent 55cdaee commit 27fdd96

File tree

2 files changed

+388
-5
lines changed

2 files changed

+388
-5
lines changed

src/mcmc/Inference.jl

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ getlogevidence(transitions, sampler, state) = missing
244244
# This is type piracy (at least for SampleFromPrior).
245245
function AbstractMCMC.bundle_samples(
246246
ts::Vector{<:Union{AbstractTransition,AbstractVarInfo}},
247-
model::AbstractModel,
247+
model_or_ldf::Union{DynamicPPL.Model,DynamicPPL.LogDensityFunction},
248248
spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior,RepeatSampler},
249249
state,
250250
chain_type::Type{MCMCChains.Chains};
@@ -256,6 +256,11 @@ function AbstractMCMC.bundle_samples(
256256
thinning=1,
257257
kwargs...,
258258
)
259+
model = if model_or_ldf isa DynamicPPL.LogDensityFunction
260+
model_or_ldf.model
261+
else
262+
model_or_ldf
263+
end
259264
# Convert transitions to array format.
260265
# Also retrieve the variable names.
261266
varnames, vals = _params_to_array(model, ts)
@@ -307,12 +312,17 @@ end
307312
# This is type piracy (for SampleFromPrior).
308313
function AbstractMCMC.bundle_samples(
309314
ts::Vector{<:Union{AbstractTransition,AbstractVarInfo}},
310-
model::AbstractModel,
315+
model_or_ldf::Union{DynamicPPL.Model,DynamicPPL.LogDensityFunction},
311316
spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior,RepeatSampler},
312317
state,
313318
chain_type::Type{Vector{NamedTuple}};
314319
kwargs...,
315320
)
321+
model = if model_or_ldf isa DynamicPPL.LogDensityFunction
322+
model_or_ldf.model
323+
else
324+
model_or_ldf
325+
end
316326
return map(ts) do t
317327
# Construct a dictionary of pairs `vn => value`.
318328
params = OrderedDict(getparams(model, t))

0 commit comments

Comments
 (0)