From bfbf727f4e0e0132458fa8337b74ebe7a32d52cc Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 10 Feb 2025 11:38:04 +0000 Subject: [PATCH 1/2] Add getadtype function --- Project.toml | 2 +- src/interface.jl | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 9bb2749e..8c38f156 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" keywords = ["markov chain monte carlo", "probabilistic programming"] license = "MIT" desc = "A lightweight interface for common MCMC methods." -version = "5.6.1" +version = "5.7.0" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" diff --git a/src/interface.jl b/src/interface.jl index b58ced99..2e2ba020 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -19,6 +19,16 @@ then `reduce(chainscat, c)` is called. chainsstack(c) = c chainsstack(c::AbstractVector{<:AbstractChains}) = reduce(chainscat, c) +""" + getadtype(sampler::AbstractSampler) + +If the sampler specifies an automatic differentiation (AD) backend to use, this +function should return the corresponding `ADTypes.AbstractADType`. + +By default, this returns `nothing`. +""" +getadtype(::AbstractSampler) = nothing + """ bundle_samples(samples, model, sampler, state, chain_type[; kwargs...]) From ae297c36b7bbb8c836b5c890338901d3fa8cc823 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 12 Feb 2025 13:00:06 +0000 Subject: [PATCH 2/2] Implement two-argument version --- src/interface.jl | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 2e2ba020..e31324dc 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -20,14 +20,18 @@ chainsstack(c) = c chainsstack(c::AbstractVector{<:AbstractChains}) = reduce(chainscat, c) """ - getadtype(sampler::AbstractSampler) + getadtype(s::AbstractSampler) + getadtype(m::AbstractModel, s::AbstractSampler) -If the sampler specifies an automatic differentiation (AD) backend to use, this -function should return the corresponding `ADTypes.AbstractADType`. +Specify the `ADTypes.AbstractADType` to be used when sampling from model `m` using sampler `s`. + +If the model is not relevant, then the implementation of AbstractSampler can +directly overload the single-argument method `getadtype(s::AbstractSampler)`. By default, this returns `nothing`. """ getadtype(::AbstractSampler) = nothing +getadtype(::AbstractModel, spl::AbstractSampler) = getadtype(spl) """ bundle_samples(samples, model, sampler, state, chain_type[; kwargs...])