Skip to content

Commit 23afb54

Browse files
committed
Add new AbstractMCMC interface
1 parent 9d4dbf3 commit 23afb54

File tree

1 file changed

+318
-3
lines changed

1 file changed

+318
-3
lines changed

src/mcmc/abstractmcmc.jl

Lines changed: 318 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,321 @@
1-
#########################################
2-
# Default definitions for the interface #
3-
#########################################
1+
# This file contains the basic methods for `AbstractMCMC.sample`.
2+
# The overall aim is that users can call
3+
#
4+
# sample(::Model, ::InferenceAlgorithm, N)
5+
#
6+
# and have it be (eventually) forwarded to
7+
#
8+
# sample(::LogDensityFunction, ::Sampler{InferenceAlgorithm}, N)
9+
#
10+
# The former method is more convenient for most users, and has been the 'default'
11+
# API in Turing. The latter method is what really needs to be used under the hood,
12+
# because a Model on its own does not fully specify how the log-density should be
13+
# evaluated (only a LogDensityFunction has that information). The methods defined
14+
# in this file provide the 'bridge' between these two, and also provide hooks to
15+
# allow for some special behaviour, e.g. setting the default chain type to
16+
# MCMCChains.Chains, and also checking the model with DynamicPPL.check_model.
17+
#
18+
# Advanced users who want to customise the way their model is executed (e.g. by
19+
# using different types of VarInfo) can construct their own LogDensityFunction
20+
# and call `sample(ldf, spl, N)` themselves.
21+
22+
# Because this is a pain to implement all at once, we do it for one sampler at a time.
23+
# This type tells us which samplers have been 'updated' to the new interface.
24+
25+
# TODO: Eventually, we want to broaden this to InferenceAlgorithm
26+
const LDFCompatibleAlgorithm = Union{Hamiltonian}
27+
# TODO: Eventually, we want to broaden this to
28+
# Union{Sampler{<:InferenceAlgorithm},RepeatSampler}.
29+
const LDFCompatibleSampler = Union{Sampler{<:LDFCompatibleAlgorithm}}
30+
31+
# The main method: without ensemble sampling
32+
# NOTE: When updating this method, please make sure to also update the
33+
# corresponding one with ensemble sampling, right below it.
34+
function AbstractMCMC.sample(
35+
rng::Random.AbstractRNG,
36+
ldf::LogDensityFunction,
37+
spl::LDFCompatibleSampler,
38+
N::Integer;
39+
check_model::Bool=true,
40+
chain_type=MCMCChains.Chains,
41+
progress=PROGRESS[],
42+
resume_from=nothing,
43+
initial_state=DynamicPPL.loadstate(resume_from),
44+
kwargs...,
45+
)
46+
# TODO: Right now, only generic checks are run. We could in principle
47+
# specialise this to check for e.g. discrete variables with HMC
48+
check_model && DynamicPPL.check_model(ldf.model; error_on_failure=true)
49+
# Some samplers need to update the kwargs with additional information,
50+
# e.g. HMC.
51+
new_kwargs = update_sample_kwargs(spl, N, kwargs)
52+
# Forward to the main sampling function
53+
return AbstractMCMC.mcmcsample(
54+
rng,
55+
ldf,
56+
spl,
57+
N;
58+
initial_state=initial_state,
59+
chain_type=chain_type,
60+
progress=progress,
61+
new_kwargs...,
62+
)
63+
end
64+
65+
# The main method: with ensemble sampling
66+
# NOTE: When updating this method, please make sure to also update the
67+
# corresponding one without ensemble sampling, right above it.
68+
function AbstractMCMC.sample(
69+
rng::Random.AbstractRNG,
70+
ldf::LogDensityFunction,
71+
spl::LDFCompatibleSampler,
72+
ensemble::AbstractMCMC.AbstractMCMCEnsemble,
73+
N::Integer,
74+
n_chains::Integer;
75+
check_model::Bool=true,
76+
chain_type=MCMCChains.Chains,
77+
progress=PROGRESS[],
78+
resume_from=nothing,
79+
initial_state=DynamicPPL.loadstate(resume_from),
80+
kwargs...,
81+
)
82+
# TODO: Right now, only generic checks are run. We could in principle
83+
# specialise this to check for e.g. discrete variables with HMC
84+
check_model && DynamicPPL.check_model(ldf.model; error_on_failure=true)
85+
# Some samplers need to update the kwargs with additional information,
86+
# e.g. HMC.
87+
new_kwargs = update_sample_kwargs(spl, N, kwargs)
88+
# Forward to the main sampling function
89+
return AbstractMCMC.mcmcsample(
90+
rng,
91+
ldf,
92+
spl,
93+
ensemble,
94+
N,
95+
n_chains;
96+
initial_state=initial_state,
97+
chain_type=chain_type,
98+
progress=progress,
99+
new_kwargs...,
100+
)
101+
end
102+
103+
# This method should be in DynamicPPL. We will move it there when all the
104+
# Turing samplers have been updated.
105+
"""
106+
initialise_varinfo(rng, model, sampler, initial_params=nothing, link=false)
107+
108+
Return a suitable initial varinfo object, which will be used when sampling
109+
`model` with `sampler`. If given, the initial parameter values will be set in
110+
the varinfo object. Also performs linking if requested.
111+
112+
# Arguments
113+
- `rng::Random.AbstractRNG`: Random number generator.
114+
- `model::Model`: Model for which we want to create a varinfo object.
115+
- `sampler::AbstractSampler`: Sampler which will make use of the varinfo object.
116+
- `initial_params::Union{AbstractVector,Nothing}`: Initial parameter values to
117+
be set in the varinfo object. Note that these should be given in unconstrained
118+
space.
119+
- `link::Bool`: Whether to link the varinfo.
120+
121+
# Returns
122+
- `AbstractVarInfo`: Default varinfo object for the given `model` and `sampler`.
123+
"""
124+
function initialise_varinfo(
125+
rng::Random.AbstractRNG,
126+
model::Model,
127+
sampler::LDFCompatibleSampler,
128+
initial_params::Union{AbstractVector,Nothing}=nothing,
129+
# We could set `link=requires_unconstrained_space(sampler)`, but that would
130+
# preclude moving `initialise_varinfo` to DynamicPPL, since
131+
# `requires_unconstrained_space` is defined in Turing (unless that function
132+
# is also moved to DynamicPPL, or AbstractMCMC)
133+
link::Bool=false,
134+
)
135+
init_sampler = DynamicPPL.initialsampler(sampler)
136+
vi = DynamicPPL.typed_varinfo(rng, model, init_sampler)
137+
138+
# Update the parameters if provided.
139+
if initial_params !== nothing
140+
# Note that initialize_parameters!! expects parameters in to be
141+
# specified in unconstrained space. TODO: Make this more generic.
142+
vi = DynamicPPL.initialize_parameters!!(vi, initial_params, model)
143+
# Update joint log probability.
144+
# This is a quick fix for https://github.com/TuringLang/Turing.jl/issues/1588
145+
# and https://github.com/TuringLang/Turing.jl/issues/1563
146+
# to avoid that existing variables are resampled
147+
vi = last(DynamicPPL.evaluate!!(model, vi, DynamicPPL.DefaultContext()))
148+
end
149+
150+
return if link
151+
DynamicPPL.link(vi, model)
152+
else
153+
vi
154+
end
155+
end
156+
157+
##########################################################################
158+
### Everything below this is boring boilerplate for the new interface. ###
159+
##########################################################################
160+
161+
function AbstractMCMC.sample(
162+
model::Model, alg::LDFCompatibleAlgorithm, N::Integer; kwargs...
163+
)
164+
return AbstractMCMC.sample(Random.default_rng(), model, alg, N; kwargs...)
165+
end
166+
167+
function AbstractMCMC.sample(
168+
ldf::LogDensityFunction, alg::LDFCompatibleAlgorithm, N::Integer; kwargs...
169+
)
170+
return AbstractMCMC.sample(Random.default_rng(), ldf, alg, N; kwargs...)
171+
end
172+
173+
function AbstractMCMC.sample(
174+
model::Model, spl::Sampler{<:LDFCompatibleAlgorithm}, N::Integer; kwargs...
175+
)
176+
return AbstractMCMC.sample(Random.default_rng(), model, spl, N; kwargs...)
177+
end
178+
179+
function AbstractMCMC.sample(
180+
ldf::LogDensityFunction, spl::Sampler{<:LDFCompatibleAlgorithm}, N::Integer; kwargs...
181+
)
182+
return AbstractMCMC.sample(Random.default_rng(), ldf, spl, N; kwargs...)
183+
end
184+
185+
function AbstractMCMC.sample(
186+
rng::Random.AbstractRNG,
187+
ldf::LogDensityFunction,
188+
alg::LDFCompatibleAlgorithm,
189+
N::Integer;
190+
kwargs...,
191+
)
192+
return AbstractMCMC.sample(rng, ldf, Sampler(alg), N; kwargs...)
193+
end
194+
195+
function AbstractMCMC.sample(
196+
rng::Random.AbstractRNG,
197+
model::Model,
198+
alg::LDFCompatibleAlgorithm,
199+
N::Integer;
200+
kwargs...,
201+
)
202+
return AbstractMCMC.sample(rng, model, Sampler(alg), N; kwargs...)
203+
end
204+
205+
function AbstractMCMC.sample(
206+
rng::Random.AbstractRNG,
207+
model::Model,
208+
spl::Sampler{<:LDFCompatibleAlgorithm},
209+
N::Integer;
210+
kwargs...,
211+
)
212+
initial_params = get(kwargs, :initial_params, nothing)
213+
link = requires_unconstrained_space(spl)
214+
vi = initialise_varinfo(rng, model, spl, initial_params, link)
215+
ldf = LogDensityFunction(model, vi; adtype=get_adtype(spl))
216+
return AbstractMCMC.sample(rng, ldf, spl, N; kwargs...)
217+
end
218+
219+
function AbstractMCMC.sample(
220+
model::Model,
221+
alg::LDFCompatibleAlgorithm,
222+
ensemble::AbstractMCMC.AbstractMCMCEnsemble,
223+
N::Integer,
224+
n_chains::Integer;
225+
kwargs...,
226+
)
227+
return AbstractMCMC.sample(
228+
Random.default_rng(), model, alg, ensemble, N, n_chains; kwargs...
229+
)
230+
end
231+
232+
function AbstractMCMC.sample(
233+
ldf::LogDensityFunction,
234+
alg::LDFCompatibleAlgorithm,
235+
ensemble::AbstractMCMC.AbstractMCMCEnsemble,
236+
N::Integer,
237+
n_chains::Integer;
238+
kwargs...,
239+
)
240+
return AbstractMCMC.sample(
241+
Random.default_rng(), ldf, alg, ensemble, N, n_chains; kwargs...
242+
)
243+
end
244+
245+
function AbstractMCMC.sample(
246+
model::Model,
247+
spl::Sampler{<:LDFCompatibleAlgorithm},
248+
ensemble::AbstractMCMC.AbstractMCMCEnsemble,
249+
N::Integer,
250+
n_chains::Integer;
251+
kwargs...,
252+
)
253+
return AbstractMCMC.sample(
254+
Random.default_rng(), model, spl, ensemble, N, n_chains; kwargs...
255+
)
256+
end
257+
258+
function AbstractMCMC.sample(
259+
ldf::LogDensityFunction,
260+
spl::Sampler{<:LDFCompatibleAlgorithm},
261+
ensemble::AbstractMCMC.AbstractMCMCEnsemble,
262+
N::Integer,
263+
n_chains::Integer;
264+
kwargs...,
265+
)
266+
return AbstractMCMC.sample(
267+
Random.default_rng(), ldf, spl, ensemble, N, n_chains; kwargs...
268+
)
269+
end
270+
271+
function AbstractMCMC.sample(
272+
rng::Random.AbstractRNG,
273+
ldf::LogDensityFunction,
274+
alg::LDFCompatibleAlgorithm,
275+
ensemble::AbstractMCMC.AbstractMCMCEnsemble,
276+
N::Integer,
277+
n_chains::Integer;
278+
kwargs...,
279+
)
280+
return AbstractMCMC.sample(rng, ldf, Sampler(alg), ensemble, N, n_chains; kwargs...)
281+
end
282+
283+
function AbstractMCMC.sample(
284+
rng::Random.AbstractRNG,
285+
model::Model,
286+
alg::LDFCompatibleAlgorithm,
287+
ensemble::AbstractMCMC.AbstractMCMCEnsemble,
288+
N::Integer,
289+
n_chains::Integer;
290+
kwargs...,
291+
)
292+
return AbstractMCMC.sample(rng, model, Sampler(alg), ensemble, N, n_chains; kwargs...)
293+
end
294+
295+
function AbstractMCMC.sample(
296+
rng::Random.AbstractRNG,
297+
model::Model,
298+
spl::LDFCompatibleSampler,
299+
ensemble::AbstractMCMC.AbstractMCMCEnsemble,
300+
N::Integer,
301+
n_chains::Integer;
302+
kwargs...,
303+
)
304+
initial_params = get(kwargs, :initial_params, nothing)
305+
link = requires_unconstrained_space(spl)
306+
vi = initialise_varinfo(rng, model, spl, initial_params, link)
307+
ldf = LogDensityFunction(model, vi; adtype=get_adtype(spl))
308+
return AbstractMCMC.sample(rng, ldf, spl, ensemble, N, n_chains; kwargs...)
309+
end
310+
311+
########################################################
312+
# DEPRECATED SAMPLE METHODS #
313+
########################################################
314+
# All the code below should eventually be removed. #
315+
# We need to keep it here for now so that the #
316+
# inference algorithms that _haven't_ yet been updated #
317+
# to take LogDensityFunction still work. #
318+
########################################################
4319

5320
function AbstractMCMC.sample(
6321
model::AbstractModel, alg::InferenceAlgorithm, N::Integer; kwargs...

0 commit comments

Comments
 (0)