@@ -11,4 +11,43 @@ this wrapping occurs automatically.
11
11
"""
12
12
abstract type InferenceAlgorithm end
13
13
14
- DynamicPPL. default_chain_type (sampler:: Sampler{<:InferenceAlgorithm} ) = MCMCChains. Chains
14
+ """
15
+ update_sample_kwargs(spl::AbstractSampler, N::Integer, kwargs)
16
+ update_sample_kwargs(spl::InferenceAlgorithm, N::Integer, kwargs)
17
+ Some InferenceAlgorithm implementations carry additional information about
18
+ the keyword arguments that should be passed to `AbstractMCMC.sample`. This
19
+ function provides a hook for them to update the default keyword arguments.
20
+ The default implementation is for no changes to be made to `kwargs`.
21
+ """
22
+ function update_sample_kwargs (spl:: Sampler{<:InferenceAlgorithm} , N:: Integer , kwargs)
23
+ return update_sample_kwargs (spl. alg, N, kwargs)
24
+ end
25
+ update_sample_kwargs (:: AbstractSampler , N:: Integer , kwargs) = kwargs
26
+ update_sample_kwargs (:: InferenceAlgorithm , N:: Integer , kwargs) = kwargs
27
+
28
+ """
29
+ get_adtype(spl::AbstractSampler)
30
+ get_adtype(spl::InferenceAlgorithm)
31
+ Return the automatic differentiation (AD) backend to use for the sampler.
32
+ This is needed for constructing a LogDensityFunction.
33
+ By default, returns nothing, i.e. the LogDensityFunction that is constructed
34
+ will not know how to calculate its gradients.
35
+ If the sampler or algorithm requires gradient information, then this function
36
+ must return an `ADTypes.AbstractADType`.
37
+ """
38
+ get_adtype (:: AbstractSampler ) = nothing
39
+ get_adtype (:: InferenceAlgorithm ) = nothing
40
+ get_adtype (spl:: Sampler{<:InferenceAlgorithm} ) = get_adtype (spl. alg)
41
+
42
+ """
43
+ requires_unconstrained_space(sampler::AbstractSampler)
44
+ requires_unconstrained_space(sampler::InferenceAlgorithm)
45
+ Return `true` if the sampler / algorithm requires unconstrained space, and
46
+ `false` otherwise. This is used to determine whether the initial VarInfo
47
+ should be linked. Defaults to true.
48
+ """
49
+ requires_unconstrained_space (:: AbstractSampler ) = true
50
+ requires_unconstrained_space (:: InferenceAlgorithm ) = true
51
+ function requires_unconstrained_space (spl:: Sampler{<:InferenceAlgorithm} )
52
+ return requires_unconstrained_space (spl. alg)
53
+ end
0 commit comments