@@ -11,4 +11,46 @@ this wrapping occurs automatically.
11
11
"""
12
12
abstract type InferenceAlgorithm end
13
13
14
+ # TODO (penelopeysm): remove
14
15
DynamicPPL. default_chain_type (sampler:: Sampler{<:InferenceAlgorithm} ) = MCMCChains. Chains
16
+
17
+ """
18
+ update_sample_kwargs(spl::AbstractSampler, N::Integer, kwargs)
19
+ update_sample_kwargs(spl::InferenceAlgorithm, N::Integer, kwargs)
20
+ Some InferenceAlgorithm implementations carry additional information about
21
+ the keyword arguments that should be passed to `AbstractMCMC.sample`. This
22
+ function provides a hook for them to update the default keyword arguments.
23
+ The default implementation is for no changes to be made to `kwargs`.
24
+ """
25
+ function update_sample_kwargs (spl:: Sampler{<:InferenceAlgorithm} , N:: Integer , kwargs)
26
+ return update_sample_kwargs (spl. alg, N, kwargs)
27
+ end
28
+ update_sample_kwargs (:: AbstractSampler , N:: Integer , kwargs) = kwargs
29
+ update_sample_kwargs (:: InferenceAlgorithm , N:: Integer , kwargs) = kwargs
30
+
31
+ """
32
+ get_adtype(spl::AbstractSampler)
33
+ get_adtype(spl::InferenceAlgorithm)
34
+ Return the automatic differentiation (AD) backend to use for the sampler.
35
+ This is needed for constructing a LogDensityFunction.
36
+ By default, returns nothing, i.e. the LogDensityFunction that is constructed
37
+ will not know how to calculate its gradients.
38
+ If the sampler or algorithm requires gradient information, then this function
39
+ must return an `ADTypes.AbstractADType`.
40
+ """
41
+ get_adtype (:: AbstractSampler ) = nothing
42
+ get_adtype (:: InferenceAlgorithm ) = nothing
43
+ get_adtype (spl:: Sampler{<:InferenceAlgorithm} ) = get_adtype (spl. alg)
44
+
45
+ """
46
+ requires_unconstrained_space(sampler::AbstractSampler)
47
+ requires_unconstrained_space(sampler::InferenceAlgorithm)
48
+ Return `true` if the sampler / algorithm requires unconstrained space, and
49
+ `false` otherwise. This is used to determine whether the initial VarInfo
50
+ should be linked. Defaults to true.
51
+ """
52
+ requires_unconstrained_space (:: AbstractSampler ) = true
53
+ requires_unconstrained_space (:: InferenceAlgorithm ) = true
54
+ function requires_unconstrained_space (spl:: Sampler{<:InferenceAlgorithm} )
55
+ return requires_unconstrained_space (spl. alg)
56
+ end
0 commit comments