Skip to content

Commit b076b44

Browse files
committed
Add MH as well
1 parent 47e3c2f commit b076b44

File tree

3 files changed

+88
-78
lines changed

3 files changed

+88
-78
lines changed

src/mcmc/abstractmcmc.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
# Because this is a pain to implement all at once, we do it for one sampler at a time.
2323
# This type tells us which samplers have been 'updated' to the new interface.
24-
const LDFCompatibleSampler = Union{Hamiltonian,ESS}
24+
const LDFCompatibleSampler = Union{Hamiltonian,ESS,MH}
2525

2626
"""
2727
sample(

src/mcmc/mh.jl

Lines changed: 44 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ mean(chain)
104104
```
105105
106106
"""
107-
struct MH{P} <: InferenceAlgorithm
107+
struct MH{P} <: AbstractSampler
108108
proposals::P
109109

110110
function MH(proposals...)
@@ -139,18 +139,26 @@ struct MH{P} <: InferenceAlgorithm
139139
end
140140
end
141141

142-
# Some of the proposals require working in unconstrained space.
143-
transform_maybe(proposal::AMH.Proposal) = proposal
144-
function transform_maybe(proposal::AMH.RandomWalkProposal)
145-
return AMH.RandomWalkProposal(Bijectors.transformed(proposal.proposal))
146-
end
147-
148-
function MH(model::Model; proposal_type=AMH.StaticProposal)
149-
priors = DynamicPPL.extract_priors(model)
150-
props = Tuple([proposal_type(prop) for prop in values(priors)])
151-
vars = Tuple(map(Symbol, collect(keys(priors))))
152-
priors = map(transform_maybe, NamedTuple{vars}(props))
153-
return AMH.MetropolisHastings(priors)
142+
# Turing sampler interface
143+
DynamicPPL.initialsampler(::MH) = DynamicPPL.SampleFromPrior()
144+
get_adtype(::MH) = nothing
145+
update_sample_kwargs(::MH, ::Integer, kwargs) = kwargs
146+
requires_unconstrained_space(::MH) = false
147+
requires_unconstrained_space(::MH{<:AdvancedMH.RandomWalkProposal}) = true
148+
# `NamedTuple` of proposals
149+
@generated function requires_unconstrained_space(
150+
::MH{<:NamedTuple{names,props}}
151+
) where {names,props}
152+
# If we have a `NamedTuple` with proposals, we need to check whether any of
153+
# them are `AdvancedMH.RandomWalkProposal`. If so, we need to link.
154+
for prop in props.parameters
155+
if prop <: AdvancedMH.RandomWalkProposal
156+
return :(true)
157+
end
158+
end
159+
# If we don't have any `AdvancedMH.RandomWalkProposal` (or if we have an
160+
# empty `NamedTuple`), we don't need to link.
161+
return :(false)
154162
end
155163

156164
#####################
@@ -188,7 +196,7 @@ A log density function for the MH sampler.
188196
189197
This variant uses the `set_namedtuple!` function to update the `VarInfo`.
190198
"""
191-
const MHLogDensityFunction{M<:Model,S<:Sampler{<:MH},V<:AbstractVarInfo} =
199+
const MHLogDensityFunction{M<:Model,S<:MH,V<:AbstractVarInfo} =
192200
DynamicPPL.LogDensityFunction{M,V,<:DynamicPPL.SamplingContext{<:S},AD} where {AD}
193201

194202
function LogDensityProblems.logdensity(f::MHLogDensityFunction, x::NamedTuple)
@@ -219,16 +227,16 @@ function reconstruct(dist::AbstractVector{<:MultivariateDistribution}, val::Abst
219227
end
220228

221229
"""
222-
dist_val_tuple(spl::Sampler{<:MH}, vi::VarInfo)
230+
dist_val_tuple(spl::MH, vi::VarInfo)
223231
224232
Return two `NamedTuples`.
225233
226234
The first `NamedTuple` has symbols as keys and distributions as values.
227235
The second `NamedTuple` has model symbols as keys and their stored values as values.
228236
"""
229-
function dist_val_tuple(spl::Sampler{<:MH}, vi::DynamicPPL.VarInfoOrThreadSafeVarInfo)
237+
function dist_val_tuple(spl::MH, vi::DynamicPPL.VarInfoOrThreadSafeVarInfo)
230238
vns = all_varnames_grouped_by_symbol(vi)
231-
dt = _dist_tuple(spl.alg.proposals, vi, vns)
239+
dt = _dist_tuple(spl.proposals, vi, vns)
232240
vt = _val_tuple(vi, vns)
233241
return dt, vt
234242
end
@@ -270,34 +278,25 @@ _val_tuple(::VarInfo, ::Tuple{}) = ()
270278
end
271279
_dist_tuple(::@NamedTuple{}, ::VarInfo, ::Tuple{}) = ()
272280

273-
# Utility functions to link
274-
should_link(varinfo, sampler, proposal) = false
275-
function should_link(varinfo, sampler, proposal::NamedTuple{(),Tuple{}})
281+
should_link(varinfo, proposals) = false
282+
function should_link(varinfo, proposals::NamedTuple{(),Tuple{}})
276283
# If it's an empty `NamedTuple`, we're using the priors as proposals
277284
# in which case we shouldn't link.
278285
return false
279286
end
280-
function should_link(varinfo, sampler, proposal::AdvancedMH.RandomWalkProposal)
287+
function should_link(varinfo, proposals::AdvancedMH.RandomWalkProposal)
281288
return true
282289
end
283290
# FIXME: This won't be hit unless `vals` are all the exactly same concrete type of `AdvancedMH.RandomWalkProposal`!
284291
function should_link(
285-
varinfo, sampler, proposal::NamedTuple{names,vals}
292+
varinfo, proposals::NamedTuple{names,vals}
286293
) where {names,vals<:NTuple{<:Any,<:AdvancedMH.RandomWalkProposal}}
287294
return true
288295
end
289296

290-
function maybe_link!!(varinfo, sampler, proposal, model)
291-
return if should_link(varinfo, sampler, proposal)
292-
DynamicPPL.link!!(varinfo, model)
293-
else
294-
varinfo
295-
end
296-
end
297-
298297
# Make a proposal if we don't have a covariance proposal matrix (the default).
299298
function propose!!(
300-
rng::AbstractRNG, vi::AbstractVarInfo, model::Model, spl::Sampler{<:MH}, proposal
299+
rng::AbstractRNG, vi::AbstractVarInfo, ldf::LogDensityFunction, spl::MH, proposal
301300
)
302301
# Retrieve distribution and value NamedTuples.
303302
dt, vt = dist_val_tuple(spl, vi)
@@ -307,16 +306,7 @@ function propose!!(
307306
prev_trans = AMH.Transition(vt, getlogp(vi), false)
308307

309308
# Make a new transition.
310-
densitymodel = AMH.DensityModel(
311-
Base.Fix1(
312-
LogDensityProblems.logdensity,
313-
DynamicPPL.LogDensityFunction(
314-
model,
315-
vi,
316-
DynamicPPL.SamplingContext(rng, spl, DynamicPPL.leafcontext(model.context)),
317-
),
318-
),
319-
)
309+
densitymodel = AMH.DensityModel(Base.Fix1(LogDensityProblems.logdensity, ldf))
320310
trans, _ = AbstractMCMC.step(rng, densitymodel, mh_sampler, prev_trans)
321311

322312
# TODO: Make this compatible with immutable `VarInfo`.
@@ -329,70 +319,47 @@ end
329319
function propose!!(
330320
rng::AbstractRNG,
331321
vi::AbstractVarInfo,
332-
model::Model,
333-
spl::Sampler{<:MH},
322+
ldf::LogDensityFunction,
323+
spl::MH,
334324
proposal::AdvancedMH.RandomWalkProposal,
335325
)
336326
# If this is the case, we can just draw directly from the proposal
337327
# matrix.
338328
vals = vi[:]
339329

340330
# Create a sampler and the previous transition.
341-
mh_sampler = AMH.MetropolisHastings(spl.alg.proposals)
331+
mh_sampler = AMH.MetropolisHastings(spl.proposals)
342332
prev_trans = AMH.Transition(vals, getlogp(vi), false)
343333

344334
# Make a new transition.
345-
densitymodel = AMH.DensityModel(
346-
Base.Fix1(
347-
LogDensityProblems.logdensity,
348-
DynamicPPL.LogDensityFunction(
349-
model,
350-
vi,
351-
DynamicPPL.SamplingContext(rng, spl, DynamicPPL.leafcontext(model.context)),
352-
),
353-
),
354-
)
335+
densitymodel = AMH.DensityModel(Base.Fix1(LogDensityProblems.logdensity, ldf))
355336
trans, _ = AbstractMCMC.step(rng, densitymodel, mh_sampler, prev_trans)
356337

357338
return setlogp!!(DynamicPPL.unflatten(vi, trans.params), trans.lp)
358339
end
359340

360-
function DynamicPPL.initialstep(
361-
rng::AbstractRNG,
362-
model::AbstractModel,
363-
spl::Sampler{<:MH},
364-
vi::AbstractVarInfo;
365-
kwargs...,
366-
)
367-
# If we're doing random walk with a covariance matrix,
368-
# just link everything before sampling.
369-
vi = maybe_link!!(vi, spl, spl.alg.proposals, model)
370-
371-
return Transition(model, vi), vi
341+
function AbstractMCMC.step(rng::AbstractRNG, ldf::LogDensityFunction, spl::MH; kwargs...)
342+
vi = ldf.varinfo
343+
return Transition(ldf.model, vi), vi
372344
end
373345

374346
function AbstractMCMC.step(
375-
rng::AbstractRNG, model::Model, spl::Sampler{<:MH}, vi::AbstractVarInfo; kwargs...
347+
rng::AbstractRNG, ldf::LogDensityFunction, spl::MH, vi::AbstractVarInfo; kwargs...
376348
)
377-
# Cases:
378-
# 1. A covariance proposal matrix
379-
# 2. A bunch of NamedTuples that specify the proposal space
380-
vi = propose!!(rng, vi, model, spl, spl.alg.proposals)
381-
382-
return Transition(model, vi), vi
349+
vi = propose!!(rng, vi, ldf, spl, spl.proposals)
350+
return Transition(ldf.model, vi), vi
383351
end
384352

385353
####
386354
#### Compiler interface, i.e. tilde operators.
387355
####
388356
function DynamicPPL.assume(
389-
rng::Random.AbstractRNG, spl::Sampler{<:MH}, dist::Distribution, vn::VarName, vi
357+
rng::Random.AbstractRNG, ::MH, dist::Distribution, vn::VarName, vi
390358
)
391359
# Just defer to `SampleFromPrior`.
392-
retval = DynamicPPL.assume(rng, SampleFromPrior(), dist, vn, vi)
393-
return retval
360+
return DynamicPPL.assume(rng, SampleFromPrior(), dist, vn, vi)
394361
end
395362

396-
function DynamicPPL.observe(spl::Sampler{<:MH}, d::Distribution, value, vi)
363+
function DynamicPPL.observe(::MH, d::Distribution, value, vi)
397364
return DynamicPPL.observe(SampleFromPrior(), d, value, vi)
398365
end

test/mcmc/mh.jl

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,49 @@ GKernel(var) = (x) -> Normal(x, sqrt.(var))
2121
@info "Starting MH tests"
2222
seed = 23
2323

24+
@testset "InferenceAlgorithm interface" begin
25+
algs_and_unconstrained = [
26+
(MH(), false), # Sample from priors, no need to link
27+
(MH(:a => Normal()), false), # static proposal
28+
(MH(:a => a -> Normal(a, 1)), false), # static proposal
29+
(MH([0.25 0.05; 0.05 0.50]), true), # RWMH with covariance matrix
30+
(MH(:a => AdvancedMH.RandomWalkProposal(Normal())), true), # explicit RWMH
31+
]
32+
@testset "$alg" for (alg, unconstrained) in algs_and_unconstrained
33+
@test Turing.Inference.get_adtype(alg) === nothing
34+
@test Turing.Inference.requires_unconstrained_space(alg) == unconstrained
35+
kwargs = (; _foo="bar")
36+
@test Turing.Inference.update_sample_kwargs(alg, 1000, kwargs) == kwargs
37+
end
38+
end
39+
40+
@testset "sample() interface" begin
41+
@model function demo_normal(x)
42+
a ~ Normal()
43+
return x ~ Normal(a)
44+
end
45+
model = demo_normal(2.0)
46+
ldf = LogDensityFunction(model)
47+
sampling_objects = Dict("DynamicPPL.Model" => model, "LogDensityFunction" => ldf)
48+
seed = 468
49+
50+
@testset "sampling with $name" for (name, model_or_ldf) in sampling_objects
51+
spl = MH()
52+
# check sampling works without rng
53+
@test sample(model_or_ldf, spl, 5) isa Chains
54+
# check reproducibility with rng
55+
chn1 = sample(Random.Xoshiro(seed), model_or_ldf, spl, 5)
56+
chn2 = sample(Random.Xoshiro(seed), model_or_ldf, spl, 5)
57+
@test mean(chn1[:a]) == mean(chn2[:a])
58+
end
59+
60+
@testset "check that initial_params are respected" begin
61+
a0 = 5.0
62+
chn = sample(model, MH(), 5; initial_params=[a0])
63+
@test chn[:a][1] == a0
64+
end
65+
end
66+
2467
@testset "mh constructor" begin
2568
N = 10
2669
s1 = MH((:s, InverseGamma(2, 3)), (:m, GKernel(3.0)))

0 commit comments

Comments
 (0)