Skip to content

Commit 151a429

Browse files
committed
make Hamiltonian directly an AbstractSampler
1 parent 71a8cf2 commit 151a429

File tree

4 files changed

+41
-141
lines changed

4 files changed

+41
-141
lines changed

src/mcmc/Inference.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ getlogevidence(transitions, sampler, state) = missing
248248
function AbstractMCMC.bundle_samples(
249249
ts::Vector{<:Union{AbstractTransition,AbstractVarInfo}},
250250
model_or_ldf::Union{DynamicPPL.Model,DynamicPPL.LogDensityFunction},
251-
spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior,RepeatSampler},
251+
spl::AbstractSampler,
252252
state,
253253
chain_type::Type{MCMCChains.Chains};
254254
save_state=false,
@@ -316,7 +316,7 @@ end
316316
function AbstractMCMC.bundle_samples(
317317
ts::Vector{<:Union{AbstractTransition,AbstractVarInfo}},
318318
model_or_ldf::Union{DynamicPPL.Model,DynamicPPL.LogDensityFunction},
319-
spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior,RepeatSampler},
319+
spl::AbstractSampler,
320320
state,
321321
chain_type::Type{Vector{NamedTuple}};
322322
kwargs...,

src/mcmc/abstractmcmc.jl

Lines changed: 6 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,7 @@
2222
# Because this is a pain to implement all at once, we do it for one sampler at a time.
2323
# This type tells us which samplers have been 'updated' to the new interface.
2424

25-
# TODO: Eventually, we want to broaden this to InferenceAlgorithm
26-
const LDFCompatibleAlgorithm = Union{Hamiltonian}
27-
# TODO: Eventually, we want to broaden this to
28-
# Union{Sampler{<:InferenceAlgorithm},RepeatSampler}.
29-
const LDFCompatibleSampler = Union{Sampler{<:LDFCompatibleAlgorithm}}
25+
const LDFCompatibleSampler = Union{Hamiltonian}
3026

3127
"""
3228
sample(
@@ -223,54 +219,20 @@ end
223219
### Everything below this is boring boilerplate for the new interface. ###
224220
##########################################################################
225221

226-
function AbstractMCMC.sample(
227-
model::Model, alg::LDFCompatibleAlgorithm, N::Integer; kwargs...
228-
)
229-
return AbstractMCMC.sample(Random.default_rng(), model, alg, N; kwargs...)
230-
end
231-
232-
function AbstractMCMC.sample(
233-
ldf::LogDensityFunction, alg::LDFCompatibleAlgorithm, N::Integer; kwargs...
234-
)
235-
return AbstractMCMC.sample(Random.default_rng(), ldf, alg, N; kwargs...)
236-
end
237-
238-
function AbstractMCMC.sample(
239-
model::Model, spl::Sampler{<:LDFCompatibleAlgorithm}, N::Integer; kwargs...
240-
)
222+
function AbstractMCMC.sample(model::Model, spl::LDFCompatibleSampler, N::Integer; kwargs...)
241223
return AbstractMCMC.sample(Random.default_rng(), model, spl, N; kwargs...)
242224
end
243225

244226
function AbstractMCMC.sample(
245-
ldf::LogDensityFunction, spl::Sampler{<:LDFCompatibleAlgorithm}, N::Integer; kwargs...
227+
ldf::LogDensityFunction, spl::LDFCompatibleSampler, N::Integer; kwargs...
246228
)
247229
return AbstractMCMC.sample(Random.default_rng(), ldf, spl, N; kwargs...)
248230
end
249231

250-
function AbstractMCMC.sample(
251-
rng::Random.AbstractRNG,
252-
ldf::LogDensityFunction,
253-
alg::LDFCompatibleAlgorithm,
254-
N::Integer;
255-
kwargs...,
256-
)
257-
return AbstractMCMC.sample(rng, ldf, Sampler(alg), N; kwargs...)
258-
end
259-
260232
function AbstractMCMC.sample(
261233
rng::Random.AbstractRNG,
262234
model::Model,
263-
alg::LDFCompatibleAlgorithm,
264-
N::Integer;
265-
kwargs...,
266-
)
267-
return AbstractMCMC.sample(rng, model, Sampler(alg), N; kwargs...)
268-
end
269-
270-
function AbstractMCMC.sample(
271-
rng::Random.AbstractRNG,
272-
model::Model,
273-
spl::Sampler{<:LDFCompatibleAlgorithm},
235+
spl::LDFCompatibleSampler,
274236
N::Integer;
275237
check_model::Bool=true,
276238
kwargs...,
@@ -290,33 +252,7 @@ end
290252

291253
function AbstractMCMC.sample(
292254
model::Model,
293-
alg::LDFCompatibleAlgorithm,
294-
ensemble::AbstractMCMC.AbstractMCMCEnsemble,
295-
N::Integer,
296-
n_chains::Integer;
297-
kwargs...,
298-
)
299-
return AbstractMCMC.sample(
300-
Random.default_rng(), model, alg, ensemble, N, n_chains; kwargs...
301-
)
302-
end
303-
304-
function AbstractMCMC.sample(
305-
ldf::LogDensityFunction,
306-
alg::LDFCompatibleAlgorithm,
307-
ensemble::AbstractMCMC.AbstractMCMCEnsemble,
308-
N::Integer,
309-
n_chains::Integer;
310-
kwargs...,
311-
)
312-
return AbstractMCMC.sample(
313-
Random.default_rng(), ldf, alg, ensemble, N, n_chains; kwargs...
314-
)
315-
end
316-
317-
function AbstractMCMC.sample(
318-
model::Model,
319-
spl::Sampler{<:LDFCompatibleAlgorithm},
255+
spl::LDFCompatibleSampler,
320256
ensemble::AbstractMCMC.AbstractMCMCEnsemble,
321257
N::Integer,
322258
n_chains::Integer;
@@ -329,7 +265,7 @@ end
329265

330266
function AbstractMCMC.sample(
331267
ldf::LogDensityFunction,
332-
spl::Sampler{<:LDFCompatibleAlgorithm},
268+
spl::LDFCompatibleSampler,
333269
ensemble::AbstractMCMC.AbstractMCMCEnsemble,
334270
N::Integer,
335271
n_chains::Integer;
@@ -340,30 +276,6 @@ function AbstractMCMC.sample(
340276
)
341277
end
342278

343-
function AbstractMCMC.sample(
344-
rng::Random.AbstractRNG,
345-
ldf::LogDensityFunction,
346-
alg::LDFCompatibleAlgorithm,
347-
ensemble::AbstractMCMC.AbstractMCMCEnsemble,
348-
N::Integer,
349-
n_chains::Integer;
350-
kwargs...,
351-
)
352-
return AbstractMCMC.sample(rng, ldf, Sampler(alg), ensemble, N, n_chains; kwargs...)
353-
end
354-
355-
function AbstractMCMC.sample(
356-
rng::Random.AbstractRNG,
357-
model::Model,
358-
alg::LDFCompatibleAlgorithm,
359-
ensemble::AbstractMCMC.AbstractMCMCEnsemble,
360-
N::Integer,
361-
n_chains::Integer;
362-
kwargs...,
363-
)
364-
return AbstractMCMC.sample(rng, model, Sampler(alg), ensemble, N, n_chains; kwargs...)
365-
end
366-
367279
function AbstractMCMC.sample(
368280
rng::Random.AbstractRNG,
369281
model::Model,

src/mcmc/algorithm.jl

Lines changed: 13 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# TODO(penelopeysm): remove
12
"""
23
InferenceAlgorithm
34
@@ -16,41 +17,30 @@ DynamicPPL.default_chain_type(sampler::Sampler{<:InferenceAlgorithm}) = MCMCChai
1617

1718
"""
1819
update_sample_kwargs(spl::AbstractSampler, N::Integer, kwargs)
19-
update_sample_kwargs(spl::InferenceAlgorithm, N::Integer, kwargs)
20-
Some InferenceAlgorithm implementations carry additional information about
21-
the keyword arguments that should be passed to `AbstractMCMC.sample`. This
22-
function provides a hook for them to update the default keyword arguments.
23-
The default implementation is for no changes to be made to `kwargs`.
24-
"""
25-
function update_sample_kwargs(spl::Sampler{<:InferenceAlgorithm}, N::Integer, kwargs)
26-
return update_sample_kwargs(spl.alg, N, kwargs)
27-
end
20+
21+
Some samplers carry additional information about the keyword arguments that
22+
should be passed to `AbstractMCMC.sample`. This function provides a hook for
23+
them to update the default keyword arguments. The default implementation is for
24+
no changes to be made to `kwargs`.
25+
"""
2826
update_sample_kwargs(::AbstractSampler, N::Integer, kwargs) = kwargs
29-
update_sample_kwargs(::InferenceAlgorithm, N::Integer, kwargs) = kwargs
3027

3128
"""
3229
get_adtype(spl::AbstractSampler)
33-
get_adtype(spl::InferenceAlgorithm)
34-
Return the automatic differentiation (AD) backend to use for the sampler.
35-
This is needed for constructing a LogDensityFunction.
36-
By default, returns nothing, i.e. the LogDensityFunction that is constructed
37-
will not know how to calculate its gradients.
38-
If the sampler or algorithm requires gradient information, then this function
30+
31+
Return the automatic differentiation (AD) backend to use for the sampler. This
32+
is needed for constructing a LogDensityFunction. By default, returns nothing,
33+
i.e. the LogDensityFunction that is constructed will not know how to calculate
34+
its gradients. If the sampler requires gradient information, then this function
3935
must return an `ADTypes.AbstractADType`.
4036
"""
4137
get_adtype(::AbstractSampler) = nothing
42-
get_adtype(::InferenceAlgorithm) = nothing
43-
get_adtype(spl::Sampler{<:InferenceAlgorithm}) = get_adtype(spl.alg)
4438

4539
"""
4640
requires_unconstrained_space(sampler::AbstractSampler)
47-
requires_unconstrained_space(sampler::InferenceAlgorithm)
41+
4842
Return `true` if the sampler / algorithm requires unconstrained space, and
4943
`false` otherwise. This is used to determine whether the initial VarInfo
5044
should be linked. Defaults to true.
5145
"""
5246
requires_unconstrained_space(::AbstractSampler) = true
53-
requires_unconstrained_space(::InferenceAlgorithm) = true
54-
function requires_unconstrained_space(spl::Sampler{<:InferenceAlgorithm})
55-
return requires_unconstrained_space(spl.alg)
56-
end

src/mcmc/hmc.jl

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
# InferenceAlgorithm interface
1+
# AbstractSampler interface for Turing
22

3-
abstract type Hamiltonian <: InferenceAlgorithm end
3+
abstract type Hamiltonian <: AbstractMCMC.AbstractSampler end
44

5-
DynamicPPL.initialsampler(::Sampler{<:Hamiltonian}) = SampleFromUniform()
5+
DynamicPPL.initialsampler(::Hamiltonian) = DynamicPPL.SampleFromUniform()
66
requires_unconstrained_space(::Hamiltonian) = true
77
# TODO(penelopeysm): This is really quite dangerous code because it implicitly
88
# assumes that any concrete type that subtypes `Hamiltonian` has an adtype
@@ -152,7 +152,7 @@ end
152152
function AbstractMCMC.step(
153153
rng::AbstractRNG,
154154
ldf::LogDensityFunction,
155-
spl::Sampler{<:Hamiltonian};
155+
spl::Hamiltonian;
156156
initial_params=nothing,
157157
nadapts=0,
158158
kwargs...,
@@ -165,7 +165,7 @@ function AbstractMCMC.step(
165165
has_initial_params = initial_params !== nothing
166166

167167
# Create a Hamiltonian.
168-
metricT = getmetricT(spl.alg)
168+
metricT = getmetricT(spl)
169169
metric = metricT(length(theta))
170170
lp_func = Base.Fix1(LogDensityProblems.logdensity, ldf)
171171
lp_grad_func = Base.Fix1(LogDensityProblems.logdensity_and_gradient, ldf)
@@ -184,23 +184,23 @@ function AbstractMCMC.step(
184184
log_density_old = getlogp(vi)
185185

186186
# Find good eps if not provided one
187-
if iszero(spl.alg.ϵ)
187+
if iszero(spl.ϵ)
188188
ϵ = AHMC.find_good_stepsize(rng, hamiltonian, theta)
189189
@info "Found initial step size" ϵ
190190
else
191-
ϵ = spl.alg.ϵ
191+
ϵ = spl.ϵ
192192
end
193193

194194
# Generate a kernel.
195-
kernel = make_ahmc_kernel(spl.alg, ϵ)
195+
kernel = make_ahmc_kernel(spl, ϵ)
196196

197197
# Create initial transition and state.
198198
# Already perform one step since otherwise we don't get any statistics.
199199
t = AHMC.transition(rng, hamiltonian, kernel, z)
200200

201201
# Adaptation
202-
adaptor = AHMCAdaptor(spl.alg, hamiltonian.metric; ϵ=ϵ)
203-
if spl.alg isa AdaptiveHamiltonian
202+
adaptor = AHMCAdaptor(spl, hamiltonian.metric; ϵ=ϵ)
203+
if spl isa AdaptiveHamiltonian
204204
hamiltonian, kernel, _ = AHMC.adapt!(
205205
hamiltonian, kernel, adaptor, 1, nadapts, t.z.θ, t.stat.acceptance_rate
206206
)
@@ -224,7 +224,7 @@ end
224224
function AbstractMCMC.step(
225225
rng::Random.AbstractRNG,
226226
ldf::LogDensityFunction,
227-
spl::Sampler{<:Hamiltonian},
227+
spl::Hamiltonian,
228228
state::HMCState;
229229
nadapts=0,
230230
kwargs...,
@@ -236,7 +236,7 @@ function AbstractMCMC.step(
236236

237237
# Adaptation
238238
i = state.i + 1
239-
if spl.alg isa AdaptiveHamiltonian
239+
if spl isa AdaptiveHamiltonian
240240
hamiltonian, kernel, _ = AHMC.adapt!(
241241
hamiltonian,
242242
state.kernel,
@@ -276,7 +276,7 @@ function get_hamiltonian(model, spl, vi, state, n)
276276
# using leafcontext(model.context) so could we just remove the argument
277277
# entirely?)
278278
DynamicPPL.SamplingContext(spl, DynamicPPL.leafcontext(model.context));
279-
adtype=spl.alg.adtype,
279+
adtype=spl.adtype,
280280
)
281281
lp_func = Base.Fix1(LogDensityProblems.logdensity, ldf)
282282
lp_grad_func = Base.Fix1(LogDensityProblems.logdensity_and_gradient, ldf)
@@ -441,17 +441,17 @@ getmetricT(::NUTS{<:Any,metricT}) where {metricT} = metricT
441441
##### HMC core functions
442442
#####
443443

444-
getstepsize(sampler::Sampler{<:Hamiltonian}, state) = sampler.alg.ϵ
445-
getstepsize(sampler::Sampler{<:AdaptiveHamiltonian}, state) = AHMC.getϵ(state.adaptor)
444+
getstepsize(sampler::Hamiltonian, state) = sampler.ϵ
445+
getstepsize(sampler::AdaptiveHamiltonian, state) = AHMC.getϵ(state.adaptor)
446446
function getstepsize(
447-
sampler::Sampler{<:AdaptiveHamiltonian},
447+
sampler::AdaptiveHamiltonian,
448448
state::HMCState{TV,TKernel,THam,PhType,AHMC.Adaptation.NoAdaptation},
449449
) where {TV,TKernel,THam,PhType}
450450
return state.kernel.τ.integrator.ϵ
451451
end
452452

453-
gen_metric(dim::Int, spl::Sampler{<:Hamiltonian}, state) = AHMC.UnitEuclideanMetric(dim)
454-
function gen_metric(dim::Int, spl::Sampler{<:AdaptiveHamiltonian}, state)
453+
gen_metric(dim::Int, spl::Hamiltonian, state) = AHMC.UnitEuclideanMetric(dim)
454+
function gen_metric(dim::Int, spl::AdaptiveHamiltonian, state)
455455
return AHMC.renew(state.hamiltonian.metric, AHMC.getM⁻¹(state.adaptor.pc))
456456
end
457457

@@ -476,13 +476,11 @@ end
476476
####
477477
#### Compiler interface, i.e. tilde operators.
478478
####
479-
function DynamicPPL.assume(
480-
rng, ::Sampler{<:Hamiltonian}, dist::Distribution, vn::VarName, vi
481-
)
479+
function DynamicPPL.assume(rng, ::Hamiltonian, dist::Distribution, vn::VarName, vi)
482480
return DynamicPPL.assume(dist, vn, vi)
483481
end
484482

485-
function DynamicPPL.observe(::Sampler{<:Hamiltonian}, d::Distribution, value, vi)
483+
function DynamicPPL.observe(::Hamiltonian, d::Distribution, value, vi)
486484
return DynamicPPL.observe(d, value, vi)
487485
end
488486

0 commit comments

Comments
 (0)