@@ -104,7 +104,7 @@ mean(chain)
104
104
```
105
105
106
106
"""
107
- struct MH{P} <: InferenceAlgorithm
107
+ struct MH{P} <: AbstractSampler
108
108
proposals:: P
109
109
110
110
function MH (proposals... )
@@ -139,18 +139,26 @@ struct MH{P} <: InferenceAlgorithm
139
139
end
140
140
end
141
141
142
- # Some of the proposals require working in unconstrained space.
143
- transform_maybe (proposal:: AMH.Proposal ) = proposal
144
- function transform_maybe (proposal:: AMH.RandomWalkProposal )
145
- return AMH. RandomWalkProposal (Bijectors. transformed (proposal. proposal))
146
- end
147
-
148
- function MH (model:: Model ; proposal_type= AMH. StaticProposal)
149
- priors = DynamicPPL. extract_priors (model)
150
- props = Tuple ([proposal_type (prop) for prop in values (priors)])
151
- vars = Tuple (map (Symbol, collect (keys (priors))))
152
- priors = map (transform_maybe, NamedTuple {vars} (props))
153
- return AMH. MetropolisHastings (priors)
142
+ # Turing sampler interface
143
+ DynamicPPL. initialsampler (:: MH ) = DynamicPPL. SampleFromPrior ()
144
+ get_adtype (:: MH ) = nothing
145
+ update_sample_kwargs (:: MH , :: Integer , kwargs) = kwargs
146
+ requires_unconstrained_space (:: MH ) = false
147
+ requires_unconstrained_space (:: MH{<:AdvancedMH.RandomWalkProposal} ) = true
148
+ # `NamedTuple` of proposals
149
+ @generated function requires_unconstrained_space (
150
+ :: MH{<:NamedTuple{names,props}}
151
+ ) where {names,props}
152
+ # If we have a `NamedTuple` with proposals, we need to check whether any of
153
+ # them are `AdvancedMH.RandomWalkProposal`. If so, we need to link.
154
+ for prop in props. parameters
155
+ if prop <: AdvancedMH.RandomWalkProposal
156
+ return :(true )
157
+ end
158
+ end
159
+ # If we don't have any `AdvancedMH.RandomWalkProposal` (or if we have an
160
+ # empty `NamedTuple`), we don't need to link.
161
+ return :(false )
154
162
end
155
163
156
164
# ####################
@@ -188,7 +196,7 @@ A log density function for the MH sampler.
188
196
189
197
This variant uses the `set_namedtuple!` function to update the `VarInfo`.
190
198
"""
191
- const MHLogDensityFunction{M<: Model ,S<: Sampler{<:MH} ,V<: AbstractVarInfo } =
199
+ const MHLogDensityFunction{M<: Model ,S<: MH ,V<: AbstractVarInfo } =
192
200
DynamicPPL. LogDensityFunction{M,V,<: DynamicPPL.SamplingContext{<:S} ,AD} where {AD}
193
201
194
202
function LogDensityProblems. logdensity (f:: MHLogDensityFunction , x:: NamedTuple )
@@ -219,16 +227,16 @@ function reconstruct(dist::AbstractVector{<:MultivariateDistribution}, val::Abst
219
227
end
220
228
221
229
"""
222
- dist_val_tuple(spl::Sampler{<:MH} , vi::VarInfo)
230
+ dist_val_tuple(spl::MH , vi::VarInfo)
223
231
224
232
Return two `NamedTuples`.
225
233
226
234
The first `NamedTuple` has symbols as keys and distributions as values.
227
235
The second `NamedTuple` has model symbols as keys and their stored values as values.
228
236
"""
229
- function dist_val_tuple (spl:: Sampler{<:MH} , vi:: DynamicPPL.VarInfoOrThreadSafeVarInfo )
237
+ function dist_val_tuple (spl:: MH , vi:: DynamicPPL.VarInfoOrThreadSafeVarInfo )
230
238
vns = all_varnames_grouped_by_symbol (vi)
231
- dt = _dist_tuple (spl. alg . proposals, vi, vns)
239
+ dt = _dist_tuple (spl. proposals, vi, vns)
232
240
vt = _val_tuple (vi, vns)
233
241
return dt, vt
234
242
end
@@ -270,34 +278,25 @@ _val_tuple(::VarInfo, ::Tuple{}) = ()
270
278
end
271
279
_dist_tuple (:: @NamedTuple {}, :: VarInfo , :: Tuple{} ) = ()
272
280
273
- # Utility functions to link
274
- should_link (varinfo, sampler, proposal) = false
275
- function should_link (varinfo, sampler, proposal:: NamedTuple{(),Tuple{}} )
281
+ should_link (varinfo, proposals) = false
282
+ function should_link (varinfo, proposals:: NamedTuple{(),Tuple{}} )
276
283
# If it's an empty `NamedTuple`, we're using the priors as proposals
277
284
# in which case we shouldn't link.
278
285
return false
279
286
end
280
- function should_link (varinfo, sampler, proposal :: AdvancedMH.RandomWalkProposal )
287
+ function should_link (varinfo, proposals :: AdvancedMH.RandomWalkProposal )
281
288
return true
282
289
end
283
290
# FIXME : This won't be hit unless `vals` are all the exactly same concrete type of `AdvancedMH.RandomWalkProposal`!
284
291
function should_link (
285
- varinfo, sampler, proposal :: NamedTuple{names,vals}
292
+ varinfo, proposals :: NamedTuple{names,vals}
286
293
) where {names,vals<: NTuple{<:Any,<:AdvancedMH.RandomWalkProposal} }
287
294
return true
288
295
end
289
296
290
- function maybe_link!! (varinfo, sampler, proposal, model)
291
- return if should_link (varinfo, sampler, proposal)
292
- DynamicPPL. link!! (varinfo, model)
293
- else
294
- varinfo
295
- end
296
- end
297
-
298
297
# Make a proposal if we don't have a covariance proposal matrix (the default).
299
298
function propose!! (
300
- rng:: AbstractRNG , vi:: AbstractVarInfo , model :: Model , spl:: Sampler{<:MH} , proposal
299
+ rng:: AbstractRNG , vi:: AbstractVarInfo , ldf :: LogDensityFunction , spl:: MH , proposal
301
300
)
302
301
# Retrieve distribution and value NamedTuples.
303
302
dt, vt = dist_val_tuple (spl, vi)
@@ -307,16 +306,7 @@ function propose!!(
307
306
prev_trans = AMH. Transition (vt, getlogp (vi), false )
308
307
309
308
# Make a new transition.
310
- densitymodel = AMH. DensityModel (
311
- Base. Fix1 (
312
- LogDensityProblems. logdensity,
313
- DynamicPPL. LogDensityFunction (
314
- model,
315
- vi,
316
- DynamicPPL. SamplingContext (rng, spl, DynamicPPL. leafcontext (model. context)),
317
- ),
318
- ),
319
- )
309
+ densitymodel = AMH. DensityModel (Base. Fix1 (LogDensityProblems. logdensity, ldf))
320
310
trans, _ = AbstractMCMC. step (rng, densitymodel, mh_sampler, prev_trans)
321
311
322
312
# TODO : Make this compatible with immutable `VarInfo`.
@@ -329,70 +319,47 @@ end
329
319
function propose!! (
330
320
rng:: AbstractRNG ,
331
321
vi:: AbstractVarInfo ,
332
- model :: Model ,
333
- spl:: Sampler{<:MH} ,
322
+ ldf :: LogDensityFunction ,
323
+ spl:: MH ,
334
324
proposal:: AdvancedMH.RandomWalkProposal ,
335
325
)
336
326
# If this is the case, we can just draw directly from the proposal
337
327
# matrix.
338
328
vals = vi[:]
339
329
340
330
# Create a sampler and the previous transition.
341
- mh_sampler = AMH. MetropolisHastings (spl. alg . proposals)
331
+ mh_sampler = AMH. MetropolisHastings (spl. proposals)
342
332
prev_trans = AMH. Transition (vals, getlogp (vi), false )
343
333
344
334
# Make a new transition.
345
- densitymodel = AMH. DensityModel (
346
- Base. Fix1 (
347
- LogDensityProblems. logdensity,
348
- DynamicPPL. LogDensityFunction (
349
- model,
350
- vi,
351
- DynamicPPL. SamplingContext (rng, spl, DynamicPPL. leafcontext (model. context)),
352
- ),
353
- ),
354
- )
335
+ densitymodel = AMH. DensityModel (Base. Fix1 (LogDensityProblems. logdensity, ldf))
355
336
trans, _ = AbstractMCMC. step (rng, densitymodel, mh_sampler, prev_trans)
356
337
357
338
return setlogp!! (DynamicPPL. unflatten (vi, trans. params), trans. lp)
358
339
end
359
340
360
- function DynamicPPL. initialstep (
361
- rng:: AbstractRNG ,
362
- model:: AbstractModel ,
363
- spl:: Sampler{<:MH} ,
364
- vi:: AbstractVarInfo ;
365
- kwargs... ,
366
- )
367
- # If we're doing random walk with a covariance matrix,
368
- # just link everything before sampling.
369
- vi = maybe_link!! (vi, spl, spl. alg. proposals, model)
370
-
371
- return Transition (model, vi), vi
341
+ function AbstractMCMC. step (rng:: AbstractRNG , ldf:: LogDensityFunction , spl:: MH ; kwargs... )
342
+ vi = ldf. varinfo
343
+ return Transition (ldf. model, vi), vi
372
344
end
373
345
374
346
function AbstractMCMC. step (
375
- rng:: AbstractRNG , model :: Model , spl:: Sampler{<:MH} , vi:: AbstractVarInfo ; kwargs...
347
+ rng:: AbstractRNG , ldf :: LogDensityFunction , spl:: MH , vi:: AbstractVarInfo ; kwargs...
376
348
)
377
- # Cases:
378
- # 1. A covariance proposal matrix
379
- # 2. A bunch of NamedTuples that specify the proposal space
380
- vi = propose!! (rng, vi, model, spl, spl. alg. proposals)
381
-
382
- return Transition (model, vi), vi
349
+ vi = propose!! (rng, vi, ldf, spl, spl. proposals)
350
+ return Transition (ldf. model, vi), vi
383
351
end
384
352
385
353
# ###
386
354
# ### Compiler interface, i.e. tilde operators.
387
355
# ###
388
356
function DynamicPPL. assume (
389
- rng:: Random.AbstractRNG , spl :: Sampler{<:MH} , dist:: Distribution , vn:: VarName , vi
357
+ rng:: Random.AbstractRNG , :: MH , dist:: Distribution , vn:: VarName , vi
390
358
)
391
359
# Just defer to `SampleFromPrior`.
392
- retval = DynamicPPL. assume (rng, SampleFromPrior (), dist, vn, vi)
393
- return retval
360
+ return DynamicPPL. assume (rng, SampleFromPrior (), dist, vn, vi)
394
361
end
395
362
396
- function DynamicPPL. observe (spl :: Sampler{<:MH} , d:: Distribution , value, vi)
363
+ function DynamicPPL. observe (:: MH , d:: Distribution , value, vi)
397
364
return DynamicPPL. observe (SampleFromPrior (), d, value, vi)
398
365
end
0 commit comments