Skip to content

Commit 859e2b9

Browse files
authored
A quick refactoring to reduce repeated code (#251)
use `_tempered_evalaute!!` function introduced in #247 to reduce duplicated code
1 parent bac2171 commit 859e2b9

File tree

2 files changed

+5
-48
lines changed

2 files changed

+5
-48
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "JuliaBUGS"
22
uuid = "ba9fb4c0-828e-4473-b6a1-cd2560fee5bf"
3-
version = "0.7.3"
3+
version = "0.7.4"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/model.jl

Lines changed: 4 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -491,53 +491,10 @@ function AbstractPPL.evaluate!!(model::BUGSModel)
491491
end
492492

493493
function AbstractPPL.evaluate!!(model::BUGSModel, flattened_values::AbstractVector)
494-
var_lengths = if model.transformed
495-
model.transformed_var_lengths
496-
else
497-
model.untransformed_var_lengths
498-
end
499-
500-
evaluation_env = deepcopy(model.evaluation_env)
501-
current_idx = 1
502-
logp = 0.0
503-
for (i, vn) in enumerate(model.flattened_graph_node_data.sorted_nodes)
504-
is_stochastic = model.flattened_graph_node_data.is_stochastic_vals[i]
505-
is_observed = model.flattened_graph_node_data.is_observed_vals[i]
506-
node_function = model.flattened_graph_node_data.node_function_vals[i]
507-
loop_vars = model.flattened_graph_node_data.loop_vars_vals[i]
508-
if !is_stochastic
509-
value = node_function(evaluation_env, loop_vars)
510-
evaluation_env = BangBang.setindex!!(evaluation_env, value, vn)
511-
else
512-
dist = node_function(evaluation_env, loop_vars)
513-
if !is_observed
514-
l = var_lengths[vn]
515-
if model.transformed
516-
b = Bijectors.bijector(dist)
517-
b_inv = Bijectors.inverse(b)
518-
reconstructed_value = reconstruct(
519-
b_inv,
520-
dist,
521-
view(flattened_values, current_idx:(current_idx + l - 1)),
522-
)
523-
value, logjac = Bijectors.with_logabsdet_jacobian(
524-
b_inv, reconstructed_value
525-
)
526-
else
527-
value = reconstruct(
528-
dist, view(flattened_values, current_idx:(current_idx + l - 1))
529-
)
530-
logjac = 0.0
531-
end
532-
current_idx += l
533-
logp += logpdf(dist, value) + logjac
534-
evaluation_env = BangBang.setindex!!(evaluation_env, value, vn)
535-
else
536-
logp += logpdf(dist, AbstractPPL.get(evaluation_env, vn))
537-
end
538-
end
539-
end
540-
return evaluation_env, logp
494+
evaluation_env, (logprior, loglikelihood, tempered_logjoint) = _tempered_evaluate!!(
495+
model, flattened_values; temperature=1.0
496+
)
497+
return evaluation_env, tempered_logjoint
541498
end
542499

543500
"""

0 commit comments

Comments
 (0)