Skip to content

Commit f59ed0d

Browse files
committed
Add interface functions for InferenceAlgorithm
1 parent 557a306 commit f59ed0d

File tree

1 file changed

+42
-0
lines changed

1 file changed

+42
-0
lines changed

src/mcmc/algorithm.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,46 @@ this wrapping occurs automatically.
1111
"""
1212
abstract type InferenceAlgorithm end
1313

14+
# TODO(penelopeysm): remove
1415
DynamicPPL.default_chain_type(sampler::Sampler{<:InferenceAlgorithm}) = MCMCChains.Chains
16+
17+
"""
18+
update_sample_kwargs(spl::AbstractSampler, N::Integer, kwargs)
19+
update_sample_kwargs(spl::InferenceAlgorithm, N::Integer, kwargs)
20+
Some InferenceAlgorithm implementations carry additional information about
21+
the keyword arguments that should be passed to `AbstractMCMC.sample`. This
22+
function provides a hook for them to update the default keyword arguments.
23+
The default implementation is for no changes to be made to `kwargs`.
24+
"""
25+
function update_sample_kwargs(spl::Sampler{<:InferenceAlgorithm}, N::Integer, kwargs)
26+
return update_sample_kwargs(spl.alg, N, kwargs)
27+
end
28+
update_sample_kwargs(::AbstractSampler, N::Integer, kwargs) = kwargs
29+
update_sample_kwargs(::InferenceAlgorithm, N::Integer, kwargs) = kwargs
30+
31+
"""
32+
get_adtype(spl::AbstractSampler)
33+
get_adtype(spl::InferenceAlgorithm)
34+
Return the automatic differentiation (AD) backend to use for the sampler.
35+
This is needed for constructing a LogDensityFunction.
36+
By default, returns nothing, i.e. the LogDensityFunction that is constructed
37+
will not know how to calculate its gradients.
38+
If the sampler or algorithm requires gradient information, then this function
39+
must return an `ADTypes.AbstractADType`.
40+
"""
41+
get_adtype(::AbstractSampler) = nothing
42+
get_adtype(::InferenceAlgorithm) = nothing
43+
get_adtype(spl::Sampler{<:InferenceAlgorithm}) = get_adtype(spl.alg)
44+
45+
"""
46+
requires_unconstrained_space(sampler::AbstractSampler)
47+
requires_unconstrained_space(sampler::InferenceAlgorithm)
48+
Return `true` if the sampler / algorithm requires unconstrained space, and
49+
`false` otherwise. This is used to determine whether the initial VarInfo
50+
should be linked. Defaults to true.
51+
"""
52+
requires_unconstrained_space(::AbstractSampler) = true
53+
requires_unconstrained_space(::InferenceAlgorithm) = true
54+
function requires_unconstrained_space(spl::Sampler{<:InferenceAlgorithm})
55+
return requires_unconstrained_space(spl.alg)
56+
end

0 commit comments

Comments
 (0)