Skip to content

Commit bac2171

Browse files
authored
Add a evaluate function that returns logprior and loglikelihood (#247)
`_tempered_evaluate!!` returns updated `evaluation_env` and a NamedTuple of `logprior`, `loglikelihood` and `tempered_logjoint` (`tempered_logjoint = logprior + temperature * loglikelihood(x)`).
1 parent 290c5ef commit bac2171

File tree

4 files changed

+115
-1
lines changed

4 files changed

+115
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "JuliaBUGS"
22
uuid = "ba9fb4c0-828e-4473-b6a1-cd2560fee5bf"
3-
version = "0.7.2"
3+
version = "0.7.3"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/model.jl

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -539,3 +539,67 @@ function AbstractPPL.evaluate!!(model::BUGSModel, flattened_values::AbstractVect
539539
end
540540
return evaluation_env, logp
541541
end
542+
543+
"""
544+
_tempered_evaluate!!(model::BUGSModel, flattened_values::AbstractVector; temperature=1.0)
545+
546+
Evaluating the model with the given model parameter values, returns updated evaluation environment
547+
and a NamedTuple of logprior, loglikelihood and tempered logjoint (where tempered logjoint is the logjoint
548+
whose loglikelihood component scaled by the given temperature).
549+
"""
550+
function _tempered_evaluate!!(
551+
model::BUGSModel, flattened_values::AbstractVector; temperature=1.0
552+
)
553+
var_lengths = if model.transformed
554+
model.transformed_var_lengths
555+
else
556+
model.untransformed_var_lengths
557+
end
558+
559+
evaluation_env = deepcopy(model.evaluation_env)
560+
current_idx = 1
561+
logprior, loglikelihood = 0.0, 0.0
562+
for (i, vn) in enumerate(model.flattened_graph_node_data.sorted_nodes)
563+
is_stochastic = model.flattened_graph_node_data.is_stochastic_vals[i]
564+
is_observed = model.flattened_graph_node_data.is_observed_vals[i]
565+
node_function = model.flattened_graph_node_data.node_function_vals[i]
566+
loop_vars = model.flattened_graph_node_data.loop_vars_vals[i]
567+
if !is_stochastic
568+
value = node_function(evaluation_env, loop_vars)
569+
evaluation_env = BangBang.setindex!!(evaluation_env, value, vn)
570+
else
571+
dist = node_function(evaluation_env, loop_vars)
572+
if !is_observed
573+
l = var_lengths[vn]
574+
if model.transformed
575+
b = Bijectors.bijector(dist)
576+
b_inv = Bijectors.inverse(b)
577+
reconstructed_value = reconstruct(
578+
b_inv,
579+
dist,
580+
view(flattened_values, current_idx:(current_idx + l - 1)),
581+
)
582+
value, logjac = Bijectors.with_logabsdet_jacobian(
583+
b_inv, reconstructed_value
584+
)
585+
else
586+
value = reconstruct(
587+
dist, view(flattened_values, current_idx:(current_idx + l - 1))
588+
)
589+
logjac = 0.0
590+
end
591+
current_idx += l
592+
logprior += logpdf(dist, value) + logjac
593+
evaluation_env = BangBang.setindex!!(evaluation_env, value, vn)
594+
else
595+
loglikelihood += logpdf(dist, AbstractPPL.get(evaluation_env, vn))
596+
end
597+
end
598+
end
599+
return evaluation_env,
600+
(
601+
logprior=logprior,
602+
loglikelihood=loglikelihood,
603+
tempered_logjoint=logprior + temperature * loglikelihood,
604+
)
605+
end

test/model.jl

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
@testset "logprior and loglikelihood" begin
2+
@testset "Complex model with transformations" begin
3+
model_def = @bugs begin
4+
s[1] ~ InverseGamma(2, 3)
5+
s[2] ~ InverseGamma(2, 3)
6+
m[1] ~ Normal(0, sqrt(s[1]))
7+
m[2] ~ Normal(0, sqrt(s[2]))
8+
x[1:2] ~ MvNormal(m[1:2], Diagonal(s[1:2]))
9+
end
10+
11+
data = (; x=[1.0, 2.0])
12+
13+
model = compile(model_def, data)
14+
15+
params = rand(4)
16+
17+
b = Bijectors.bijector(InverseGamma(2, 3))
18+
b_inv = Bijectors.inverse(b)
19+
20+
log_prior_true = begin
21+
# parameter sorted: s[2], m[2], s[1], m[1]
22+
s1_inversed, logjac1 = Bijectors.with_logabsdet_jacobian(b_inv, params[3])
23+
s2_inversed, logjac2 = Bijectors.with_logabsdet_jacobian(b_inv, params[1])
24+
logpdf(InverseGamma(2, 3), s1_inversed) +
25+
logjac1 +
26+
logpdf(InverseGamma(2, 3), s2_inversed) +
27+
logjac2 +
28+
logpdf(Normal(0, sqrt(s1_inversed)), params[4]) +
29+
logpdf(Normal(0, sqrt(s2_inversed)), params[2])
30+
end
31+
32+
log_likelihood_true = begin
33+
s1_inversed = b_inv(params[3])
34+
s2_inversed = b_inv(params[1])
35+
logpdf(
36+
MvNormal([params[4], params[2]], Diagonal([s1_inversed, s2_inversed])),
37+
data.x,
38+
)
39+
end
40+
41+
_, (logprior, loglikelihood, tempered_logjoint) = JuliaBUGS._tempered_evaluate!!(
42+
model, params; temperature=2.0
43+
)
44+
45+
@test logprior log_prior_true rtol = 1E-6
46+
@test loglikelihood log_likelihood_true rtol = 1E-6
47+
@test tempered_logjoint log_prior_true + 2.0 * log_likelihood_true rtol = 1E-6
48+
end
49+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ end
6060

6161
if test_group == "log_density" || test_group == "all"
6262
include("log_density.jl")
63+
include("model.jl")
6364
end
6465

6566
if test_group == "gibbs" || test_group == "all"

0 commit comments

Comments
 (0)