Skip to content

Commit 49f6988

Browse files
committed
make Hamiltonian directly an AbstractSampler
1 parent 85b1997 commit 49f6988

File tree

6 files changed

+46
-146
lines changed

6 files changed

+46
-146
lines changed

src/mcmc/Inference.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ getlogevidence(transitions, sampler, state) = missing
248248
function AbstractMCMC.bundle_samples(
249249
ts::Vector{<:Union{AbstractTransition,AbstractVarInfo}},
250250
model_or_ldf::Union{DynamicPPL.Model,DynamicPPL.LogDensityFunction},
251-
spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior,RepeatSampler},
251+
spl::AbstractSampler,
252252
state,
253253
chain_type::Type{MCMCChains.Chains};
254254
save_state=false,
@@ -316,7 +316,7 @@ end
316316
function AbstractMCMC.bundle_samples(
317317
ts::Vector{<:Union{AbstractTransition,AbstractVarInfo}},
318318
model_or_ldf::Union{DynamicPPL.Model,DynamicPPL.LogDensityFunction},
319-
spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior,RepeatSampler},
319+
spl::AbstractSampler,
320320
state,
321321
chain_type::Type{Vector{NamedTuple}};
322322
kwargs...,

src/mcmc/abstractmcmc.jl

Lines changed: 6 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,7 @@
2222
# Because this is a pain to implement all at once, we do it for one sampler at a time.
2323
# This type tells us which samplers have been 'updated' to the new interface.
2424

25-
# TODO: Eventually, we want to broaden this to InferenceAlgorithm
26-
const LDFCompatibleAlgorithm = Union{Hamiltonian}
27-
# TODO: Eventually, we want to broaden this to
28-
# Union{Sampler{<:InferenceAlgorithm},RepeatSampler}.
29-
const LDFCompatibleSampler = Union{Sampler{<:LDFCompatibleAlgorithm}}
25+
const LDFCompatibleSampler = Union{Hamiltonian}
3026

3127
"""
3228
sample(
@@ -251,54 +247,20 @@ end
251247
### Everything below this is boring boilerplate for the new interface. ###
252248
##########################################################################
253249

254-
function AbstractMCMC.sample(
255-
model::Model, alg::LDFCompatibleAlgorithm, N::Integer; kwargs...
256-
)
257-
return AbstractMCMC.sample(Random.default_rng(), model, alg, N; kwargs...)
258-
end
259-
260-
function AbstractMCMC.sample(
261-
ldf::LogDensityFunction, alg::LDFCompatibleAlgorithm, N::Integer; kwargs...
262-
)
263-
return AbstractMCMC.sample(Random.default_rng(), ldf, alg, N; kwargs...)
264-
end
265-
266-
function AbstractMCMC.sample(
267-
model::Model, spl::Sampler{<:LDFCompatibleAlgorithm}, N::Integer; kwargs...
268-
)
250+
function AbstractMCMC.sample(model::Model, spl::LDFCompatibleSampler, N::Integer; kwargs...)
269251
return AbstractMCMC.sample(Random.default_rng(), model, spl, N; kwargs...)
270252
end
271253

272254
function AbstractMCMC.sample(
273-
ldf::LogDensityFunction, spl::Sampler{<:LDFCompatibleAlgorithm}, N::Integer; kwargs...
255+
ldf::LogDensityFunction, spl::LDFCompatibleSampler, N::Integer; kwargs...
274256
)
275257
return AbstractMCMC.sample(Random.default_rng(), ldf, spl, N; kwargs...)
276258
end
277259

278-
function AbstractMCMC.sample(
279-
rng::Random.AbstractRNG,
280-
ldf::LogDensityFunction,
281-
alg::LDFCompatibleAlgorithm,
282-
N::Integer;
283-
kwargs...,
284-
)
285-
return AbstractMCMC.sample(rng, ldf, Sampler(alg), N; kwargs...)
286-
end
287-
288260
function AbstractMCMC.sample(
289261
rng::Random.AbstractRNG,
290262
model::Model,
291-
alg::LDFCompatibleAlgorithm,
292-
N::Integer;
293-
kwargs...,
294-
)
295-
return AbstractMCMC.sample(rng, model, Sampler(alg), N; kwargs...)
296-
end
297-
298-
function AbstractMCMC.sample(
299-
rng::Random.AbstractRNG,
300-
model::Model,
301-
spl::Sampler{<:LDFCompatibleAlgorithm},
263+
spl::LDFCompatibleSampler,
302264
N::Integer;
303265
check_model::Bool=true,
304266
kwargs...,
@@ -318,33 +280,7 @@ end
318280

319281
function AbstractMCMC.sample(
320282
model::Model,
321-
alg::LDFCompatibleAlgorithm,
322-
ensemble::AbstractMCMC.AbstractMCMCEnsemble,
323-
N::Integer,
324-
n_chains::Integer;
325-
kwargs...,
326-
)
327-
return AbstractMCMC.sample(
328-
Random.default_rng(), model, alg, ensemble, N, n_chains; kwargs...
329-
)
330-
end
331-
332-
function AbstractMCMC.sample(
333-
ldf::LogDensityFunction,
334-
alg::LDFCompatibleAlgorithm,
335-
ensemble::AbstractMCMC.AbstractMCMCEnsemble,
336-
N::Integer,
337-
n_chains::Integer;
338-
kwargs...,
339-
)
340-
return AbstractMCMC.sample(
341-
Random.default_rng(), ldf, alg, ensemble, N, n_chains; kwargs...
342-
)
343-
end
344-
345-
function AbstractMCMC.sample(
346-
model::Model,
347-
spl::Sampler{<:LDFCompatibleAlgorithm},
283+
spl::LDFCompatibleSampler,
348284
ensemble::AbstractMCMC.AbstractMCMCEnsemble,
349285
N::Integer,
350286
n_chains::Integer;
@@ -357,7 +293,7 @@ end
357293

358294
function AbstractMCMC.sample(
359295
ldf::LogDensityFunction,
360-
spl::Sampler{<:LDFCompatibleAlgorithm},
296+
spl::LDFCompatibleSampler,
361297
ensemble::AbstractMCMC.AbstractMCMCEnsemble,
362298
N::Integer,
363299
n_chains::Integer;
@@ -368,30 +304,6 @@ function AbstractMCMC.sample(
368304
)
369305
end
370306

371-
function AbstractMCMC.sample(
372-
rng::Random.AbstractRNG,
373-
ldf::LogDensityFunction,
374-
alg::LDFCompatibleAlgorithm,
375-
ensemble::AbstractMCMC.AbstractMCMCEnsemble,
376-
N::Integer,
377-
n_chains::Integer;
378-
kwargs...,
379-
)
380-
return AbstractMCMC.sample(rng, ldf, Sampler(alg), ensemble, N, n_chains; kwargs...)
381-
end
382-
383-
function AbstractMCMC.sample(
384-
rng::Random.AbstractRNG,
385-
model::Model,
386-
alg::LDFCompatibleAlgorithm,
387-
ensemble::AbstractMCMC.AbstractMCMCEnsemble,
388-
N::Integer,
389-
n_chains::Integer;
390-
kwargs...,
391-
)
392-
return AbstractMCMC.sample(rng, model, Sampler(alg), ensemble, N, n_chains; kwargs...)
393-
end
394-
395307
function AbstractMCMC.sample(
396308
rng::Random.AbstractRNG,
397309
model::Model,

src/mcmc/algorithm.jl

Lines changed: 13 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# TODO(penelopeysm): remove
12
"""
23
InferenceAlgorithm
34
@@ -16,41 +17,30 @@ DynamicPPL.default_chain_type(sampler::Sampler{<:InferenceAlgorithm}) = MCMCChai
1617

1718
"""
1819
update_sample_kwargs(spl::AbstractSampler, N::Integer, kwargs)
19-
update_sample_kwargs(spl::InferenceAlgorithm, N::Integer, kwargs)
20-
Some InferenceAlgorithm implementations carry additional information about
21-
the keyword arguments that should be passed to `AbstractMCMC.sample`. This
22-
function provides a hook for them to update the default keyword arguments.
23-
The default implementation is for no changes to be made to `kwargs`.
24-
"""
25-
function update_sample_kwargs(spl::Sampler{<:InferenceAlgorithm}, N::Integer, kwargs)
26-
return update_sample_kwargs(spl.alg, N, kwargs)
27-
end
20+
21+
Some samplers carry additional information about the keyword arguments that
22+
should be passed to `AbstractMCMC.sample`. This function provides a hook for
23+
them to update the default keyword arguments. The default implementation is for
24+
no changes to be made to `kwargs`.
25+
"""
2826
update_sample_kwargs(::AbstractSampler, N::Integer, kwargs) = kwargs
29-
update_sample_kwargs(::InferenceAlgorithm, N::Integer, kwargs) = kwargs
3027

3128
"""
3229
get_adtype(spl::AbstractSampler)
33-
get_adtype(spl::InferenceAlgorithm)
34-
Return the automatic differentiation (AD) backend to use for the sampler.
35-
This is needed for constructing a LogDensityFunction.
36-
By default, returns nothing, i.e. the LogDensityFunction that is constructed
37-
will not know how to calculate its gradients.
38-
If the sampler or algorithm requires gradient information, then this function
30+
31+
Return the automatic differentiation (AD) backend to use for the sampler. This
32+
is needed for constructing a LogDensityFunction. By default, returns nothing,
33+
i.e. the LogDensityFunction that is constructed will not know how to calculate
34+
its gradients. If the sampler requires gradient information, then this function
3935
must return an `ADTypes.AbstractADType`.
4036
"""
4137
get_adtype(::AbstractSampler) = nothing
42-
get_adtype(::InferenceAlgorithm) = nothing
43-
get_adtype(spl::Sampler{<:InferenceAlgorithm}) = get_adtype(spl.alg)
4438

4539
"""
4640
requires_unconstrained_space(sampler::AbstractSampler)
47-
requires_unconstrained_space(sampler::InferenceAlgorithm)
41+
4842
Return `true` if the sampler / algorithm requires unconstrained space, and
4943
`false` otherwise. This is used to determine whether the initial VarInfo
5044
should be linked. Defaults to true.
5145
"""
5246
requires_unconstrained_space(::AbstractSampler) = true
53-
requires_unconstrained_space(::InferenceAlgorithm) = true
54-
function requires_unconstrained_space(spl::Sampler{<:InferenceAlgorithm})
55-
return requires_unconstrained_space(spl.alg)
56-
end

src/mcmc/hmc.jl

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
# InferenceAlgorithm interface
1+
# AbstractSampler interface for Turing
22

3-
abstract type Hamiltonian <: InferenceAlgorithm end
3+
abstract type Hamiltonian <: AbstractMCMC.AbstractSampler end
44

5-
DynamicPPL.initialsampler(::Sampler{<:Hamiltonian}) = SampleFromUniform()
5+
DynamicPPL.initialsampler(::Hamiltonian) = DynamicPPL.SampleFromUniform()
66
requires_unconstrained_space(::Hamiltonian) = true
77
# TODO(penelopeysm): This is really quite dangerous code because it implicitly
88
# assumes that any concrete type that subtypes `Hamiltonian` has an adtype
@@ -152,7 +152,7 @@ end
152152
function AbstractMCMC.step(
153153
rng::AbstractRNG,
154154
ldf::LogDensityFunction,
155-
spl::Sampler{<:Hamiltonian};
155+
spl::Hamiltonian;
156156
initial_params=nothing,
157157
nadapts=0,
158158
kwargs...,
@@ -165,7 +165,7 @@ function AbstractMCMC.step(
165165
has_initial_params = initial_params !== nothing
166166

167167
# Create a Hamiltonian.
168-
metricT = getmetricT(spl.alg)
168+
metricT = getmetricT(spl)
169169
metric = metricT(length(theta))
170170
lp_func = Base.Fix1(LogDensityProblems.logdensity, ldf)
171171
lp_grad_func = Base.Fix1(LogDensityProblems.logdensity_and_gradient, ldf)
@@ -184,23 +184,23 @@ function AbstractMCMC.step(
184184
log_density_old = getlogp(vi)
185185

186186
# Find good eps if not provided one
187-
if iszero(spl.alg.ϵ)
187+
if iszero(spl.ϵ)
188188
ϵ = AHMC.find_good_stepsize(rng, hamiltonian, theta)
189189
@info "Found initial step size" ϵ
190190
else
191-
ϵ = spl.alg.ϵ
191+
ϵ = spl.ϵ
192192
end
193193

194194
# Generate a kernel.
195-
kernel = make_ahmc_kernel(spl.alg, ϵ)
195+
kernel = make_ahmc_kernel(spl, ϵ)
196196

197197
# Create initial transition and state.
198198
# Already perform one step since otherwise we don't get any statistics.
199199
t = AHMC.transition(rng, hamiltonian, kernel, z)
200200

201201
# Adaptation
202-
adaptor = AHMCAdaptor(spl.alg, hamiltonian.metric; ϵ=ϵ)
203-
if spl.alg isa AdaptiveHamiltonian
202+
adaptor = AHMCAdaptor(spl, hamiltonian.metric; ϵ=ϵ)
203+
if spl isa AdaptiveHamiltonian
204204
hamiltonian, kernel, _ = AHMC.adapt!(
205205
hamiltonian, kernel, adaptor, 1, nadapts, t.z.θ, t.stat.acceptance_rate
206206
)
@@ -224,7 +224,7 @@ end
224224
function AbstractMCMC.step(
225225
rng::Random.AbstractRNG,
226226
ldf::LogDensityFunction,
227-
spl::Sampler{<:Hamiltonian},
227+
spl::Hamiltonian,
228228
state::HMCState;
229229
nadapts=0,
230230
kwargs...,
@@ -236,7 +236,7 @@ function AbstractMCMC.step(
236236

237237
# Adaptation
238238
i = state.i + 1
239-
if spl.alg isa AdaptiveHamiltonian
239+
if spl isa AdaptiveHamiltonian
240240
hamiltonian, kernel, _ = AHMC.adapt!(
241241
hamiltonian,
242242
state.kernel,
@@ -276,7 +276,7 @@ function get_hamiltonian(model, spl, vi, state, n)
276276
# using leafcontext(model.context) so could we just remove the argument
277277
# entirely?)
278278
DynamicPPL.SamplingContext(spl, DynamicPPL.leafcontext(model.context));
279-
adtype=spl.alg.adtype,
279+
adtype=spl.adtype,
280280
)
281281
lp_func = Base.Fix1(LogDensityProblems.logdensity, ldf)
282282
lp_grad_func = Base.Fix1(LogDensityProblems.logdensity_and_gradient, ldf)
@@ -441,17 +441,17 @@ getmetricT(::NUTS{<:Any,metricT}) where {metricT} = metricT
441441
##### HMC core functions
442442
#####
443443

444-
getstepsize(sampler::Sampler{<:Hamiltonian}, state) = sampler.alg.ϵ
445-
getstepsize(sampler::Sampler{<:AdaptiveHamiltonian}, state) = AHMC.getϵ(state.adaptor)
444+
getstepsize(sampler::Hamiltonian, state) = sampler.ϵ
445+
getstepsize(sampler::AdaptiveHamiltonian, state) = AHMC.getϵ(state.adaptor)
446446
function getstepsize(
447-
sampler::Sampler{<:AdaptiveHamiltonian},
447+
sampler::AdaptiveHamiltonian,
448448
state::HMCState{TV,TKernel,THam,PhType,AHMC.Adaptation.NoAdaptation},
449449
) where {TV,TKernel,THam,PhType}
450450
return state.kernel.τ.integrator.ϵ
451451
end
452452

453-
gen_metric(dim::Int, spl::Sampler{<:Hamiltonian}, state) = AHMC.UnitEuclideanMetric(dim)
454-
function gen_metric(dim::Int, spl::Sampler{<:AdaptiveHamiltonian}, state)
453+
gen_metric(dim::Int, spl::Hamiltonian, state) = AHMC.UnitEuclideanMetric(dim)
454+
function gen_metric(dim::Int, spl::AdaptiveHamiltonian, state)
455455
return AHMC.renew(state.hamiltonian.metric, AHMC.getM⁻¹(state.adaptor.pc))
456456
end
457457

@@ -476,13 +476,11 @@ end
476476
####
477477
#### Compiler interface, i.e. tilde operators.
478478
####
479-
function DynamicPPL.assume(
480-
rng, ::Sampler{<:Hamiltonian}, dist::Distribution, vn::VarName, vi
481-
)
479+
function DynamicPPL.assume(rng, ::Hamiltonian, dist::Distribution, vn::VarName, vi)
482480
return DynamicPPL.assume(dist, vn, vi)
483481
end
484482

485-
function DynamicPPL.observe(::Sampler{<:Hamiltonian}, d::Distribution, value, vi)
483+
function DynamicPPL.observe(::Hamiltonian, d::Distribution, value, vi)
486484
return DynamicPPL.observe(d, value, vi)
487485
end
488486

test/ad.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -245,18 +245,18 @@ end
245245
# the tilde-pipeline and thus change the code executed during model
246246
# evaluation.
247247
@testset "adtype=$adtype" for adtype in ADTYPES
248-
@testset "alg=$alg" for alg in [
248+
@testset "spl=$spl" for spl in [
249249
HMC(0.1, 10; adtype=adtype),
250250
HMCDA(0.8, 0.75; adtype=adtype),
251251
NUTS(1000, 0.8; adtype=adtype),
252252
SGHMC(; learning_rate=0.02, momentum_decay=0.5, adtype=adtype),
253253
SGLD(; stepsize=PolynomialStepsize(0.25), adtype=adtype),
254254
]
255-
@info "Testing AD for $alg"
255+
@info "Testing AD for $spl"
256256

257257
@testset "model=$(model.f)" for model in DEMO_MODELS
258258
rng = StableRNG(123)
259-
ctx = DynamicPPL.SamplingContext(rng, DynamicPPL.Sampler(alg))
259+
ctx = DynamicPPL.SamplingContext(rng, spl)
260260
@test run_ad(model, adtype; context=ctx, test=true, benchmark=false) isa Any
261261
end
262262
end
@@ -283,7 +283,7 @@ end
283283
model, varnames, deepcopy(global_vi)
284284
)
285285
rng = StableRNG(123)
286-
ctx = DynamicPPL.SamplingContext(rng, DynamicPPL.Sampler(HMC(0.1, 10)))
286+
ctx = DynamicPPL.SamplingContext(rng, HMC(0.1, 10))
287287
@test run_ad(model, adtype; context=ctx, test=true, benchmark=false) isa Any
288288
end
289289
end

test/mcmc/Inference.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ using Turing
7676

7777
# run sampler: progress logging should be disabled and
7878
# it should return a Chains object
79-
sampler = Sampler(HMC(0.1, 7))
79+
sampler = HMC(0.1, 7)
8080
chains = sample(StableRNG(seed), gdemo_default, sampler, MCMCThreads(), 10, 4)
8181
@test chains isa MCMCChains.Chains
8282
end

0 commit comments

Comments (0)