Skip to content

Commit 1d47fe2

Browse files
authored
Cleanup some dead code (#243)
1 parent f07d31b commit 1d47fe2

File tree

2 files changed

+14
-4
lines changed

2 files changed

+14
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "JuliaBUGS"
22
uuid = "ba9fb4c0-828e-4473-b6a1-cd2560fee5bf"
3-
version = "0.7.1"
3+
version = "0.7.2"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/model.jl

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -431,9 +431,8 @@ function check_var_group(var_group::Vector{<:VarName}, model::BUGSModel)
431431
end
432432

433433
function AbstractPPL.evaluate!!(rng::Random.AbstractRNG, model::BUGSModel)
434-
(; evaluation_env, g) = model
435-
vi = deepcopy(evaluation_env)
436434
logp = 0.0
435+
evaluation_env = deepcopy(model.evaluation_env)
437436
for (i, vn) in enumerate(model.flattened_graph_node_data.sorted_nodes)
438437
is_stochastic = model.flattened_graph_node_data.is_stochastic_vals[i]
439438
node_function = model.flattened_graph_node_data.node_function_vals[i]
@@ -444,7 +443,16 @@ function AbstractPPL.evaluate!!(rng::Random.AbstractRNG, model::BUGSModel)
444443
else
445444
dist = node_function(model.evaluation_env, loop_vars)
446445
value = rand(rng, dist) # just sample from the prior
447-
logp += logpdf(dist, value)
446+
if model.transformed
447+
# see below for why we need to transform the value
448+
value_transformed = Bijectors.transform(Bijectors.bijector(dist), value)
449+
logp +=
450+
Distributions.logpdf(dist, value) + Bijectors.logabsdetjac(
451+
Bijectors.inverse(Bijectors.bijector(dist)), value_transformed
452+
)
453+
else
454+
logp += Distributions.logpdf(dist, value)
455+
end
448456
evaluation_env = setindex!!(evaluation_env, value, vn)
449457
end
450458
end
@@ -467,6 +475,8 @@ function AbstractPPL.evaluate!!(model::BUGSModel)
467475
if model.transformed
468476
# although the values stored in `evaluation_env` are in their original space,
469477
# here we behave as accepting a vector of parameters in the transformed space
478+
# this is so that we have consistent logp values between
479+
# (1) set values in original space then evaluate (2) directly evaluate with the values in transformed space
470480
value_transformed = Bijectors.transform(Bijectors.bijector(dist), value)
471481
logp +=
472482
Distributions.logpdf(dist, value) + Bijectors.logabsdetjac(

0 commit comments

Comments
 (0)