From d3809f7e723223646fa6344afed051f0f99d8b71 Mon Sep 17 00:00:00 2001 From: Miguel Biron-Lattes Date: Fri, 28 Jul 2023 12:36:52 -0700 Subject: [PATCH 1/3] BreadCrumbs interface --- src/Pigeons.jl | 2 +- src/api.jl | 3 +++ src/includes.jl | 1 + src/pt/BreadCrumbs.jl | 56 +++++++++++++++++++++++++++++++++++++++ test/test_BreadCrumbs.jl | 21 +++++++++++++++ test/test_blang_bridge.jl | 2 +- 6 files changed, 83 insertions(+), 2 deletions(-) create mode 100644 src/pt/BreadCrumbs.jl create mode 100644 test/test_BreadCrumbs.jl diff --git a/src/Pigeons.jl b/src/Pigeons.jl index 9fe5f3956..4fcc0c35c 100644 --- a/src/Pigeons.jl +++ b/src/Pigeons.jl @@ -63,7 +63,7 @@ const use_auto_exec_folder = "" include("includes.jl") -export pigeons, Inputs, PT, +export pigeons, Inputs, PT, BreadCrumbs, # for running jobs: ChildProcess, MPI, # targets: diff --git a/src/api.jl b/src/api.jl index 632ee7978..ea01e00ee 100644 --- a/src/api.jl +++ b/src/api.jl @@ -17,3 +17,6 @@ pigeons(; on = ThisProcess(), args...) = pigeons(Inputs(; args...), on) pigeons(pt_arguments, ::ThisProcess) = pigeons(PT(pt_arguments)) + +pigeons(bc::BreadCrumbs; on = ThisProcess(), args...) = + pigeons(Inputs(; target=BreadCrumbsTarget(bc), args...), on) diff --git a/src/includes.jl b/src/includes.jl index 2d8bfe148..5b1786834 100644 --- a/src/includes.jl +++ b/src/includes.jl @@ -18,6 +18,7 @@ include("recorders/NonReproducible.jl") include("recorders/LogSum.jl") include("recorders/DiskRecorder.jl") include("recorders/@record_if_requested!.jl") +include("pt/BreadCrumbs.jl") include("pt/report.jl") include("pt/plots.jl") include("pt/Iterators.jl") diff --git a/src/pt/BreadCrumbs.jl b/src/pt/BreadCrumbs.jl new file mode 100644 index 000000000..4172361fb --- /dev/null +++ b/src/pt/BreadCrumbs.jl @@ -0,0 +1,56 @@ +""" +A struct that provides a basic, user-friendly interface to Pigeons. Only two inputs +are required, in positional order: +$FIELDS + +!!! note + + The PT state is initialized using a random sample from the reference. +""" +struct BreadCrumbs{TRefDist <: Distributions.Distribution, TTarget} + """A function that evaluates the target log potential""" + target_log_potential::TTarget + """A Distributions.jl distribution used as reference""" + reference_distribution::TRefDist +end + +# Target for a BreadCrumbs input +struct BreadCrumbsTarget{TBC <: BreadCrumbs} + bc::TBC +end +(bct::BreadCrumbsTarget)(x) = bct.bc.target_log_potential(x) +default_explorer(::BreadCrumbsTarget) = SliceSampler() + +# initialization +# general case +function initialization(bct::BreadCrumbsTarget, rng::AbstractRNG, ::Int) + rand(rng, bct.bc.reference_distribution) +end +# univariate case: need to wrap in vector to make the state mutable +function initialization( + bct::TBCT, + rng::AbstractRNG, + ::Int + ) where {TRD<:Distributions.UnivariateDistribution, TBC<:BreadCrumbs{TRD}, TBCT<:BreadCrumbsTarget{TBC}} + [rand(rng, bct.bc.reference_distribution)] +end + +# reference for a BreadCrumbs input +struct BreadCrumbsReference{TBC <: BreadCrumbs} + bc::TBC +end +(bcr::BreadCrumbsReference)(x) = logpdf(bcr.bc.reference_distribution, x) +default_reference(bct::BreadCrumbsTarget) = BreadCrumbsReference(bct.bc) + +# sampling from the reference +# general case +sample_iid!(bcr::BreadCrumbsReference, replica, shared) = + rand!(replica.rng, bcr.bc.reference_distribution, replica.state) +# univariate case +function sample_iid!( + bcr::TBCR, + replica, + shared + ) where {TRD<:Distributions.UnivariateDistribution, TBC<:BreadCrumbs{TRD}, TBCR<:BreadCrumbsReference{TBC}} + replica.state[] = rand(rng, bcr.bc.reference_distribution) +end diff --git a/test/test_BreadCrumbs.jl b/test/test_BreadCrumbs.jl new file mode 100644 index 000000000..eb37ba39b --- /dev/null +++ b/test/test_BreadCrumbs.jl @@ -0,0 +1,21 @@ +using MCMCChains + +@testset "Multivariate BreadCrumbs" begin + function unid_log_potential(x; n_trials=100, n_successes=50) + p1, p2 = x + if !(0 < p1 < 1) || !(0 < p2 < 1) + return eltype(x)(-Inf) + end + p = p1 * p2 + return n_successes*log(p) + (n_trials-n_successes)*log1p(-p) + end + ref_dist = product_distribution(Uniform(), Uniform()) + pt = pigeons( + BreadCrumbs(unid_log_potential, ref_dist), + n_rounds = 12, + record = [traces] + ) + + # collect the statistics and convert to MCMCChains' Chains + samples = Chains(sample_array(pt), variable_names(pt)) +end \ No newline at end of file diff --git a/test/test_blang_bridge.jl b/test/test_blang_bridge.jl index 983bc9cf7..f66ae3c5d 100644 --- a/test/test_blang_bridge.jl +++ b/test/test_blang_bridge.jl @@ -51,6 +51,6 @@ end # NB: 10 chains runs out of memory in CI... reducing number of chains n_restarts = n_tempered_restarts(pt) global_barrier = Pigeons.global_barrier(pt.shared.tempering) - @test n_restarts > 180 + @test n_restarts > 160 @test abs(global_barrier - 0.7) < 0.1 end \ No newline at end of file From 8ccaf28f01f7d6721c3ed969c2cc6e266bfc59d4 Mon Sep 17 00:00:00 2001 From: Miguel Biron-Lattes Date: Fri, 28 Jul 2023 13:00:39 -0700 Subject: [PATCH 2/3] check x in support of ref-dist when evaluating target --- src/pt/BreadCrumbs.jl | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/pt/BreadCrumbs.jl b/src/pt/BreadCrumbs.jl index 4172361fb..1764051f7 100644 --- a/src/pt/BreadCrumbs.jl +++ b/src/pt/BreadCrumbs.jl @@ -18,7 +18,13 @@ end struct BreadCrumbsTarget{TBC <: BreadCrumbs} bc::TBC end -(bct::BreadCrumbsTarget)(x) = bct.bc.target_log_potential(x) +function (bct::BreadCrumbsTarget)(x) + return if insupport(bct.bc.reference_distribution, x) + bct.bc.target_log_potential(x) + else + eltype(bct.bc.reference_distribution)(-Inf) + end +end default_explorer(::BreadCrumbsTarget) = SliceSampler() # initialization From e9bb0796e9dec0824f3cd6c8ff2ce401a9696309 Mon Sep 17 00:00:00 2001 From: Miguel Biron-Lattes Date: Fri, 28 Jul 2023 13:01:14 -0700 Subject: [PATCH 3/3] simplify BreadCrumbs test --- test/test_BreadCrumbs.jl | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/test/test_BreadCrumbs.jl b/test/test_BreadCrumbs.jl index eb37ba39b..ee8e72e02 100644 --- a/test/test_BreadCrumbs.jl +++ b/test/test_BreadCrumbs.jl @@ -2,11 +2,7 @@ using MCMCChains @testset "Multivariate BreadCrumbs" begin function unid_log_potential(x; n_trials=100, n_successes=50) - p1, p2 = x - if !(0 < p1 < 1) || !(0 < p2 < 1) - return eltype(x)(-Inf) - end - p = p1 * p2 + p = prod(x) return n_successes*log(p) + (n_trials-n_successes)*log1p(-p) end ref_dist = product_distribution(Uniform(), Uniform())