Skip to content

Commit 1110d30

Browse files
Fixed incorrect calls to to_linked_internal_transform (#726)
* fixed calls to `to_linked_internal_transform` * fixed incorrect call to `acclogp_assume!!` * added test for the branch we were currently imssing * formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent ba6e3b8 commit 1110d30

File tree

2 files changed

+46
-3
lines changed

2 files changed

+46
-3
lines changed

src/context_implementations.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ function assume(
219219
else
220220
r = init(rng, dist, sampler)
221221
if istrans(vi)
222-
f = to_linked_internal_transform(vi, dist)
222+
f = to_linked_internal_transform(vi, vn, dist)
223223
push!!(vi, vn, f(r), dist, sampler)
224224
# By default `push!!` sets the transformed flag to `false`.
225225
settrans!!(vi, true, vn)
@@ -401,7 +401,7 @@ end
401401
# HACK: These methods are only used in the `get_and_set_val!` methods below.
402402
# FIXME: Remove these.
403403
function _link_broadcast_new(vi, vn, dist, r)
404-
b = to_linked_internal_transform(vi, dist)
404+
b = to_linked_internal_transform(vi, vn, dist)
405405
return b(r)
406406
end
407407

@@ -492,7 +492,7 @@ function get_and_set_val!(
492492
push!!.((vi,), vns, _link_broadcast_new.((vi,), vns, dists, r), dists, (spl,))
493493
# NOTE: Need to add the correction.
494494
# FIXME: This is not great.
495-
acclogp_assume!!(vi, sum(logabsdetjac.(link_transform.(dists), r)))
495+
acclogp!!(vi, sum(logabsdetjac.(link_transform.(dists), r)))
496496
# `push!!` sets the trans-flag to `false` by default.
497497
settrans!!.((vi,), true, vns)
498498
else

test/varinfo.jl

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -770,4 +770,47 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
770770
@test any(!Base.Fix1(DynamicPPL.istrans, varinfo_linked), vns_m)
771771
end
772772
end
773+
774+
@testset "sampling from linked varinfo" begin
775+
# `~`
776+
@model function demo(n=1)
777+
x = Vector(undef, n)
778+
for i in eachindex(x)
779+
x[i] ~ Exponential()
780+
end
781+
return x
782+
end
783+
model1 = demo(1)
784+
varinfo1 = DynamicPPL.link!!(VarInfo(model1), model1)
785+
# Sampling from `model2` should hit the `istrans(vi) == true` branches
786+
# because all the existing variables are linked.
787+
model2 = demo(2)
788+
varinfo2 = last(
789+
DynamicPPL.evaluate!!(model2, deepcopy(varinfo1), SamplingContext())
790+
)
791+
for vn in [@varname(x[1]), @varname(x[2])]
792+
@test DynamicPPL.istrans(varinfo2, vn)
793+
end
794+
795+
# `.~`
796+
@model function demo_dot(n=1)
797+
x ~ Exponential()
798+
if n > 1
799+
y = Vector(undef, n - 1)
800+
y .~ Exponential()
801+
end
802+
return x
803+
end
804+
model1 = demo_dot(1)
805+
varinfo1 = DynamicPPL.link!!(DynamicPPL.untyped_varinfo(model1), model1)
806+
# Sampling from `model2` should hit the `istrans(vi) == true` branches
807+
# because all the existing variables are linked.
808+
model2 = demo_dot(2)
809+
varinfo2 = last(
810+
DynamicPPL.evaluate!!(model2, deepcopy(varinfo1), SamplingContext())
811+
)
812+
for vn in [@varname(x), @varname(y[1])]
813+
@test DynamicPPL.istrans(varinfo2, vn)
814+
end
815+
end
773816
end

0 commit comments

Comments
 (0)