Skip to content

Commit 782cfae

Browse files
committed
Add interface functions for InferenceAlgorithm
1 parent 557a306 commit 782cfae

File tree

1 file changed

+40
-1
lines changed

1 file changed

+40
-1
lines changed

src/mcmc/algorithm.jl

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,43 @@ this wrapping occurs automatically.
1111
"""
1212
abstract type InferenceAlgorithm end
1313

14-
DynamicPPL.default_chain_type(sampler::Sampler{<:InferenceAlgorithm}) = MCMCChains.Chains
14+
"""
15+
update_sample_kwargs(spl::AbstractSampler, N::Integer, kwargs)
16+
update_sample_kwargs(spl::InferenceAlgorithm, N::Integer, kwargs)
17+
Some InferenceAlgorithm implementations carry additional information about
18+
the keyword arguments that should be passed to `AbstractMCMC.sample`. This
19+
function provides a hook for them to update the default keyword arguments.
20+
The default implementation is for no changes to be made to `kwargs`.
21+
"""
22+
function update_sample_kwargs(spl::Sampler{<:InferenceAlgorithm}, N::Integer, kwargs)
23+
return update_sample_kwargs(spl.alg, N, kwargs)
24+
end
25+
update_sample_kwargs(::AbstractSampler, N::Integer, kwargs) = kwargs
26+
update_sample_kwargs(::InferenceAlgorithm, N::Integer, kwargs) = kwargs
27+
28+
"""
29+
get_adtype(spl::AbstractSampler)
30+
get_adtype(spl::InferenceAlgorithm)
31+
Return the automatic differentiation (AD) backend to use for the sampler.
32+
This is needed for constructing a LogDensityFunction.
33+
By default, returns nothing, i.e. the LogDensityFunction that is constructed
34+
will not know how to calculate its gradients.
35+
If the sampler or algorithm requires gradient information, then this function
36+
must return an `ADTypes.AbstractADType`.
37+
"""
38+
get_adtype(::AbstractSampler) = nothing
39+
get_adtype(::InferenceAlgorithm) = nothing
40+
get_adtype(spl::Sampler{<:InferenceAlgorithm}) = get_adtype(spl.alg)
41+
42+
"""
43+
requires_unconstrained_space(sampler::AbstractSampler)
44+
requires_unconstrained_space(sampler::InferenceAlgorithm)
45+
Return `true` if the sampler / algorithm requires unconstrained space, and
46+
`false` otherwise. This is used to determine whether the initial VarInfo
47+
should be linked. Defaults to true.
48+
"""
49+
requires_unconstrained_space(::AbstractSampler) = true
50+
requires_unconstrained_space(::InferenceAlgorithm) = true
51+
function requires_unconstrained_space(spl::Sampler{<:InferenceAlgorithm})
52+
return requires_unconstrained_space(spl.alg)
53+
end

0 commit comments

Comments
 (0)