-
Notifications
You must be signed in to change notification settings - Fork 12
BreadCrumbs interface: an easier way to feed pigeons #99
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
""" | ||
A struct that provides a basic, user-friendly interface to Pigeons. Only two inputs | ||
are required, in positional order: | ||
$FIELDS | ||
|
||
!!! note | ||
|
||
The PT state is initialized using a random sample from the reference. | ||
""" | ||
struct BreadCrumbs{TRefDist <: Distributions.Distribution, TTarget} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you use @auto instead? Otherwise the standard out gets cluttered with pages of mostly useless type information when showing the stack trace. Also ref should be more general. We don't want Distributions to only work in BreadCrumbs. It should be seamless with the Conversely, things that can be fed into |
||
"""A function that evaluates the target log potential""" | ||
target_log_potential::TTarget | ||
"""A Distributions.jl distribution used as reference""" | ||
reference_distribution::TRefDist | ||
end | ||
|
||
# Target for a BreadCrumbs input | ||
struct BreadCrumbsTarget{TBC <: BreadCrumbs} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This seems superfluous? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I.e. instead can we write the dispatches on Distributions types directly? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah! I think you're right! At least for the reference. I don't think people are thinking of passing the loglikelihood as a Distribution on the data given the parameter. That would also be very restrictive given that Distributions.jl only has simple models. |
||
bc::TBC | ||
end | ||
function (bct::BreadCrumbsTarget)(x) | ||
return if insupport(bct.bc.reference_distribution, x) | ||
bct.bc.target_log_potential(x) | ||
else | ||
eltype(bct.bc.reference_distribution)(-Inf) | ||
end | ||
end | ||
default_explorer(::BreadCrumbsTarget) = SliceSampler() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. SliceSampler() is already the global default, no need to have that line. |
||
|
||
# initialization | ||
# general case | ||
function initialization(bct::BreadCrumbsTarget, rng::AbstractRNG, ::Int) | ||
rand(rng, bct.bc.reference_distribution) | ||
end | ||
# univariate case: need to wrap in vector to make the state mutable | ||
function initialization( | ||
bct::TBCT, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. isn't this equivalent to |
||
rng::AbstractRNG, | ||
::Int | ||
) where {TRD<:Distributions.UnivariateDistribution, TBC<:BreadCrumbs{TRD}, TBCT<:BreadCrumbsTarget{TBC}} | ||
[rand(rng, bct.bc.reference_distribution)] | ||
end | ||
|
||
# reference for a BreadCrumbs input | ||
struct BreadCrumbsReference{TBC <: BreadCrumbs} | ||
bc::TBC | ||
end | ||
(bcr::BreadCrumbsReference)(x) = logpdf(bcr.bc.reference_distribution, x) | ||
default_reference(bct::BreadCrumbsTarget) = BreadCrumbsReference(bct.bc) | ||
|
||
# sampling from the reference | ||
# general case | ||
sample_iid!(bcr::BreadCrumbsReference, replica, shared) = | ||
rand!(replica.rng, bcr.bc.reference_distribution, replica.state) | ||
# univariate case | ||
function sample_iid!( | ||
bcr::TBCR, | ||
replica, | ||
shared | ||
) where {TRD<:Distributions.UnivariateDistribution, TBC<:BreadCrumbs{TRD}, TBCR<:BreadCrumbsReference{TBC}} | ||
replica.state[] = rand(rng, bcr.bc.reference_distribution) | ||
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
using MCMCChains | ||
|
||
@testset "Multivariate BreadCrumbs" begin | ||
function unid_log_potential(x; n_trials=100, n_successes=50) | ||
p = prod(x) | ||
return n_successes*log(p) + (n_trials-n_successes)*log1p(-p) | ||
end | ||
ref_dist = product_distribution(Uniform(), Uniform()) | ||
pt = pigeons( | ||
BreadCrumbs(unid_log_potential, ref_dist), | ||
n_rounds = 12, | ||
record = [traces] | ||
) | ||
|
||
# collect the statistics and convert to MCMCChains' Chains | ||
samples = Chains(sample_array(pt), variable_names(pt)) | ||
end |
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @alexandrebouchard unrelated to this PR but I had to change this limit so that the test would pass. Are you ok with lowering this limit? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good to export it, but then the PR should contain documentation (in code and in the website).