diff --git a/src/Pigeons.jl b/src/Pigeons.jl index 9fe5f3956..4fcc0c35c 100644 --- a/src/Pigeons.jl +++ b/src/Pigeons.jl @@ -63,7 +63,7 @@ const use_auto_exec_folder = "" include("includes.jl") -export pigeons, Inputs, PT, +export pigeons, Inputs, PT, BreadCrumbs, # for running jobs: ChildProcess, MPI, # targets: diff --git a/src/api.jl b/src/api.jl index 632ee7978..ea01e00ee 100644 --- a/src/api.jl +++ b/src/api.jl @@ -17,3 +17,6 @@ pigeons(; on = ThisProcess(), args...) = pigeons(Inputs(; args...), on) pigeons(pt_arguments, ::ThisProcess) = pigeons(PT(pt_arguments)) + +pigeons(bc::BreadCrumbs; on = ThisProcess(), args...) = + pigeons(Inputs(; target=BreadCrumbsTarget(bc), args...), on) diff --git a/src/includes.jl b/src/includes.jl index 2d8bfe148..5b1786834 100644 --- a/src/includes.jl +++ b/src/includes.jl @@ -18,6 +18,7 @@ include("recorders/NonReproducible.jl") include("recorders/LogSum.jl") include("recorders/DiskRecorder.jl") include("recorders/@record_if_requested!.jl") +include("pt/BreadCrumbs.jl") include("pt/report.jl") include("pt/plots.jl") include("pt/Iterators.jl") diff --git a/src/pt/BreadCrumbs.jl b/src/pt/BreadCrumbs.jl new file mode 100644 index 000000000..1764051f7 --- /dev/null +++ b/src/pt/BreadCrumbs.jl @@ -0,0 +1,62 @@ +""" +A struct that provides a basic, user-friendly interface to Pigeons. Only two inputs +are required, in positional order: +$FIELDS + +!!! note + + The PT state is initialized using a random sample from the reference. +""" +struct BreadCrumbs{TRefDist <: Distributions.Distribution, TTarget} + """A function that evaluates the target log potential""" + target_log_potential::TTarget + """A Distributions.jl distribution used as reference""" + reference_distribution::TRefDist +end + +# Target for a BreadCrumbs input +struct BreadCrumbsTarget{TBC <: BreadCrumbs} + bc::TBC +end +function (bct::BreadCrumbsTarget)(x) + return if insupport(bct.bc.reference_distribution, x) + bct.bc.target_log_potential(x) + else + eltype(bct.bc.reference_distribution)(-Inf) + end +end +default_explorer(::BreadCrumbsTarget) = SliceSampler() + +# initialization +# general case +function initialization(bct::BreadCrumbsTarget, rng::AbstractRNG, ::Int) + rand(rng, bct.bc.reference_distribution) +end +# univariate case: need to wrap in vector to make the state mutable +function initialization( + bct::TBCT, + rng::AbstractRNG, + ::Int + ) where {TRD<:Distributions.UnivariateDistribution, TBC<:BreadCrumbs{TRD}, TBCT<:BreadCrumbsTarget{TBC}} + [rand(rng, bct.bc.reference_distribution)] +end + +# reference for a BreadCrumbs input +struct BreadCrumbsReference{TBC <: BreadCrumbs} + bc::TBC +end +(bcr::BreadCrumbsReference)(x) = logpdf(bcr.bc.reference_distribution, x) +default_reference(bct::BreadCrumbsTarget) = BreadCrumbsReference(bct.bc) + +# sampling from the reference +# general case +sample_iid!(bcr::BreadCrumbsReference, replica, shared) = + rand!(replica.rng, bcr.bc.reference_distribution, replica.state) +# univariate case +function sample_iid!( + bcr::TBCR, + replica, + shared + ) where {TRD<:Distributions.UnivariateDistribution, TBC<:BreadCrumbs{TRD}, TBCR<:BreadCrumbsReference{TBC}} + replica.state[] = rand(rng, bcr.bc.reference_distribution) +end diff --git a/test/test_BreadCrumbs.jl b/test/test_BreadCrumbs.jl new file mode 100644 index 000000000..ee8e72e02 --- /dev/null +++ b/test/test_BreadCrumbs.jl @@ -0,0 +1,17 @@ +using MCMCChains + +@testset "Multivariate BreadCrumbs" begin + function unid_log_potential(x; n_trials=100, n_successes=50) + p = prod(x) + return n_successes*log(p) + (n_trials-n_successes)*log1p(-p) + end + ref_dist = product_distribution(Uniform(), Uniform()) + pt = pigeons( + BreadCrumbs(unid_log_potential, ref_dist), + n_rounds = 12, + record = [traces] + ) + + # collect the statistics and convert to MCMCChains' Chains + samples = Chains(sample_array(pt), variable_names(pt)) +end \ No newline at end of file diff --git a/test/test_blang_bridge.jl b/test/test_blang_bridge.jl index 983bc9cf7..f66ae3c5d 100644 --- a/test/test_blang_bridge.jl +++ b/test/test_blang_bridge.jl @@ -51,6 +51,6 @@ end # NB: 10 chains runs out of memory in CI... reducing number of chains n_restarts = n_tempered_restarts(pt) global_barrier = Pigeons.global_barrier(pt.shared.tempering) - @test n_restarts > 180 + @test n_restarts > 160 @test abs(global_barrier - 0.7) < 0.1 end \ No newline at end of file