diff --git a/Project.toml b/Project.toml index 9bb2749e..8c38f156 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" keywords = ["markov chain monte carlo", "probabilistic programming"] license = "MIT" desc = "A lightweight interface for common MCMC methods." -version = "5.6.1" +version = "5.7.0" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" diff --git a/src/interface.jl b/src/interface.jl index b58ced99..e31324dc 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -19,6 +19,20 @@ then `reduce(chainscat, c)` is called. chainsstack(c) = c chainsstack(c::AbstractVector{<:AbstractChains}) = reduce(chainscat, c) +""" + getadtype(s::AbstractSampler) + getadtype(m::AbstractModel, s::AbstractSampler) + +Specify the `ADTypes.AbstractADType` to be used when sampling from model `m` using sampler `s`. + +If the model is not relevant, then the implementation of AbstractSampler can +directly overload the single-argument method `getadtype(s::AbstractSampler)`. + +By default, this returns `nothing`. +""" +getadtype(::AbstractSampler) = nothing +getadtype(::AbstractModel, spl::AbstractSampler) = getadtype(spl) + """ bundle_samples(samples, model, sampler, state, chain_type[; kwargs...])