Skip to content

Commit 4bc43a4

Browse files
Fix for #842 (#843)
* fix for #842 * bump patch version * Update src/varinfo.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 5c89efc commit 4bc43a4

File tree

3 files changed

+17
-2
lines changed

3 files changed

+17
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.35.1"
3+
version = "0.35.2"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/varinfo.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,11 @@ function unflatten(vi::VarInfo, x::AbstractVector)
215215
md = unflatten_metadata(vi.metadata, x)
216216
# Note that use of RefValue{eltype(x)} rather than Ref is necessary to deal with cases
217217
# where e.g. x is a type gradient of some AD backend.
218-
return VarInfo(md, Base.RefValue{eltype(x)}(getlogp(vi)), Ref(get_num_produce(vi)))
218+
return VarInfo(
219+
md,
220+
Base.RefValue{float_type_with_fallback(eltype(x))}(getlogp(vi)),
221+
Ref(get_num_produce(vi)),
222+
)
219223
end
220224

221225
# We would call this `unflatten` if not for `unflatten` having a method for NamedTuples in

test/varinfo.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1017,4 +1017,15 @@ end
10171017
@test vi.metadata.b.orders == [2]
10181018
@test DynamicPPL.get_num_produce(vi) == 3
10191019
end
1020+
1021+
@testset "issue #842" begin
1022+
model = DynamicPPL.TestUtils.DEMO_MODELS[1]
1023+
varinfo = VarInfo(model)
1024+
1025+
n = length(varinfo[:])
1026+
# `Bool`.
1027+
@test getlogp(DynamicPPL.unflatten(varinfo, fill(true, n))) isa typeof(float(1))
1028+
# `Int`.
1029+
@test getlogp(DynamicPPL.unflatten(varinfo, fill(1, n))) isa typeof(float(1))
1030+
end
10201031
end

0 commit comments

Comments
 (0)