Skip to content

BreadCrumbs interface: an easier way to feed pigeons #99

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Pigeons.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ const use_auto_exec_folder = ""

include("includes.jl")

export pigeons, Inputs, PT,
export pigeons, Inputs, PT, BreadCrumbs,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good to export it, but then the PR should contain documentation (in code and in the website).

# for running jobs:
ChildProcess, MPI,
# targets:
Expand Down
3 changes: 3 additions & 0 deletions src/api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,6 @@ pigeons(; on = ThisProcess(), args...) =
pigeons(Inputs(; args...), on)

pigeons(pt_arguments, ::ThisProcess) = pigeons(PT(pt_arguments))

pigeons(bc::BreadCrumbs; on = ThisProcess(), args...) =
pigeons(Inputs(; target=BreadCrumbsTarget(bc), args...), on)
1 change: 1 addition & 0 deletions src/includes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ include("recorders/NonReproducible.jl")
include("recorders/LogSum.jl")
include("recorders/DiskRecorder.jl")
include("recorders/@record_if_requested!.jl")
include("pt/BreadCrumbs.jl")
include("pt/report.jl")
include("pt/plots.jl")
include("pt/Iterators.jl")
Expand Down
62 changes: 62 additions & 0 deletions src/pt/BreadCrumbs.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""
A struct that provides a basic, user-friendly interface to Pigeons. Only two inputs
are required, in positional order:
$FIELDS

!!! note

The PT state is initialized using a random sample from the reference.
"""
struct BreadCrumbs{TRefDist <: Distributions.Distribution, TTarget}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you use @auto instead? Otherwise the standard out gets cluttered with pages of mostly useless type information when showing the stack trace.

Also ref should be more general. We don't want Distributions to only work in BreadCrumbs. It should be seamless with the reference = .. method as well. (unless we remove the reference = .. method; which I don't think is a good idea since it works well for PPLs).

Conversely, things that can be fed into reference = .. should work in the target_log_potential argument. Similarly things that gets fed into target = should work in target_log_potential (I think this is already the case?).

"""A function that evaluates the target log potential"""
target_log_potential::TTarget
"""A Distributions.jl distribution used as reference"""
reference_distribution::TRefDist
end

# Target for a BreadCrumbs input
struct BreadCrumbsTarget{TBC <: BreadCrumbs}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems superfluous?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I.e. instead can we write the dispatches on Distributions types directly?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah! I think you're right! At least for the reference. I don't think people are thinking of passing the loglikelihood as a Distribution on the data given the parameter. That would also be very restrictive given that Distributions.jl only has simple models.

bc::TBC
end
function (bct::BreadCrumbsTarget)(x)
return if insupport(bct.bc.reference_distribution, x)
bct.bc.target_log_potential(x)
else
eltype(bct.bc.reference_distribution)(-Inf)
end
end
default_explorer(::BreadCrumbsTarget) = SliceSampler()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SliceSampler() is already the global default, no need to have that line.


# initialization
# general case
function initialization(bct::BreadCrumbsTarget, rng::AbstractRNG, ::Int)
rand(rng, bct.bc.reference_distribution)
end
# univariate case: need to wrap in vector to make the state mutable
function initialization(
bct::TBCT,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isn't this equivalent to bct::Distributions.UnivariateDistribution? etc for the other arguments

rng::AbstractRNG,
::Int
) where {TRD<:Distributions.UnivariateDistribution, TBC<:BreadCrumbs{TRD}, TBCT<:BreadCrumbsTarget{TBC}}
[rand(rng, bct.bc.reference_distribution)]
end

# reference for a BreadCrumbs input
struct BreadCrumbsReference{TBC <: BreadCrumbs}
bc::TBC
end
(bcr::BreadCrumbsReference)(x) = logpdf(bcr.bc.reference_distribution, x)
default_reference(bct::BreadCrumbsTarget) = BreadCrumbsReference(bct.bc)

# sampling from the reference
# general case
sample_iid!(bcr::BreadCrumbsReference, replica, shared) =
rand!(replica.rng, bcr.bc.reference_distribution, replica.state)
# univariate case
function sample_iid!(
bcr::TBCR,
replica,
shared
) where {TRD<:Distributions.UnivariateDistribution, TBC<:BreadCrumbs{TRD}, TBCR<:BreadCrumbsReference{TBC}}
replica.state[] = rand(rng, bcr.bc.reference_distribution)
end
17 changes: 17 additions & 0 deletions test/test_BreadCrumbs.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
using MCMCChains

@testset "Multivariate BreadCrumbs" begin
function unid_log_potential(x; n_trials=100, n_successes=50)
p = prod(x)
return n_successes*log(p) + (n_trials-n_successes)*log1p(-p)
end
ref_dist = product_distribution(Uniform(), Uniform())
pt = pigeons(
BreadCrumbs(unid_log_potential, ref_dist),
n_rounds = 12,
record = [traces]
)

# collect the statistics and convert to MCMCChains' Chains
samples = Chains(sample_array(pt), variable_names(pt))
end
2 changes: 1 addition & 1 deletion test/test_blang_bridge.jl
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@alexandrebouchard unrelated to this PR but I had to change this limit so that the test would pass. Are you ok with lowering this limit?

Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,6 @@ end
# NB: 10 chains runs out of memory in CI... reducing number of chains
n_restarts = n_tempered_restarts(pt)
global_barrier = Pigeons.global_barrier(pt.shared.tempering)
@test n_restarts > 180
@test n_restarts > 160
@test abs(global_barrier - 0.7) < 0.1
end