Skip to content

Commit 60ba25d

Browse files
authored
Merge pull request #102 from Julia-Tempering/generalize-initialization
add a DistributionLogPotential
2 parents c0d052c + f1b3a9b commit 60ba25d

File tree

6 files changed

+79
-2
lines changed

6 files changed

+79
-2
lines changed

src/Pigeons.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ include("includes.jl")
6666
export pigeons, Inputs, PT,
6767
# for running jobs:
6868
ChildProcess, MPI,
69+
# references:
70+
DistributionLogPotential,
6971
# targets:
7072
TuringLogPotential, StanLogPotential,
7173
# some examples

src/includes.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ include("paths/InterpolatedLogPotential.jl")
2626
include("paths/InterpolatingPath.jl")
2727
include("variational/variational.jl")
2828
include("pt/Inputs.jl")
29+
include("targets/DistributionLogPotential.jl")
2930
include("targets/StreamTarget.jl")
3031
include("pt/Shared.jl")
3132
include("swap/swap_graphs.jl")

src/replicas/replicas.jl

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ end
8787
function _create_locals(my_global_indices, inputs::Inputs, shared::Shared, ::Nothing)
8888
master_rng = SplittableRandom(inputs.seed)
8989
split_rngs = split_slice(my_global_indices, master_rng)
90-
states = [initialization(inputs.target, split_rngs[i], my_global_indices[i]) for i in eachindex(split_rngs)]
90+
states = [initialization(inputs, split_rngs[i], my_global_indices[i]) for i in eachindex(split_rngs)]
9191
recorders = [create_recorders(inputs, shared) for i in eachindex(split_rngs)]
9292
return Replica.(
9393
states,
@@ -96,3 +96,25 @@ function _create_locals(my_global_indices, inputs::Inputs, shared::Shared, ::Not
9696
recorders,
9797
my_global_indices) # <- replica indices
9898
end
99+
100+
# default method: defer to user-provided method for their target
101+
initialization(inp::Inputs, args...) = initialization(inp.target, args...)
102+
103+
# generic method for distribution-type references: sample iid for all replicas
104+
function initialization(
105+
inp::Inputs{T, V, E, R},
106+
rng::AbstractRNG,
107+
::Int
108+
) where {T, V, E, R <: DistributionLogPotential}
109+
rand(rng, inp.reference.dist)
110+
end
111+
112+
# for univariate references we need to wrap in vector
113+
function initialization(
114+
inp::Inputs{T, V, E, R},
115+
rng::AbstractRNG,
116+
::Int
117+
) where {T, V, E, R <: DistributionLogPotential{<:UnivariateDistribution}}
118+
[rand(rng, inp.reference.dist)]
119+
end
120+
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
"""
2+
Provides a reference type for Pigeons based on an encapsulated Distribution type.
3+
$FIELDS
4+
"""
5+
@auto struct DistributionLogPotential{D<:Distribution}
6+
"""The encapsulated distribution."""
7+
dist::D
8+
end
9+
10+
# evaluate the log density: general case
11+
(ref::DistributionLogPotential)(x) = logpdf(ref.dist, x)
12+
13+
# univariate case
14+
(ref::DistributionLogPotential{<:UnivariateDistribution})(x) = logpdf(ref.dist, first(x))
15+
16+
# iid sampling
17+
# general case
18+
function sample_iid!(ref::DistributionLogPotential, replica, shared)
19+
rand!(replica.rng, ref.dist, replica.state)
20+
end
21+
22+
# univariate case
23+
function sample_iid!(ref::DistributionLogPotential{D}, replica, shared) where {D<:UnivariateDistribution}
24+
replica.state[begin] = rand(replica.rng, ref.dist)
25+
end

src/targets/target.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,4 +92,4 @@ sample_iid!(reference_log_potential::InterpolatedLogPotential, replica, shared)
9292
sample_iid!(reference_log_potential.path.ref, replica, shared)
9393
else
9494
error()
95-
end
95+
end

test/test_DistributionLogPotential.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
using MCMCChains
2+
3+
@testset "DistributionLogPotential" begin
4+
@testset "Multivariate" begin
5+
function unid_log_potential(x; n_trials=100, n_successes=50)
6+
p1, p2 = x
7+
((0 <= p1 <= 1) && (0 <= p2 <= 1)) || return typeof(p1)(-Inf)
8+
p = p1 * p2
9+
return n_successes*log(p) + (n_trials-n_successes)*log1p(-p)
10+
end
11+
ref_dist = product_distribution([Uniform(), Uniform()])
12+
pt = pigeons(
13+
target = unid_log_potential,
14+
reference = DistributionLogPotential(ref_dist),
15+
record = [traces]
16+
)
17+
@show Chains(sample_array(pt), variable_names(pt))
18+
end
19+
@testset "Univariate" begin
20+
pt = pigeons(
21+
target = (x -> logpdf(Normal(3,1), x[begin])),
22+
reference = DistributionLogPotential(Normal(-3,1)),
23+
record = [traces]
24+
)
25+
@show Chains(sample_array(pt), variable_names(pt))
26+
end
27+
end

0 commit comments

Comments
 (0)