Skip to content

Commit bb0c857

Browse files
Add missing tovec methods for NamedTuple and Tuple. (#939)
* Add missing `tovec` methods for `NamedTuple` and `Tuple`. * add a test for `Tuple` and formatting * Update test/model.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent a8a7026 commit bb0c857

File tree

5 files changed

+25
-1
lines changed

5 files changed

+25
-1
lines changed

HISTORY.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
# DynamicPPL Changelog
22

3+
## 0.36.9
4+
5+
Fixed a failure when sampling from `ProductNamedTupleDistribution` due to
6+
missing `tovec` methods for `NamedTuple` and `Tuple`.
7+
38
## 0.36.8
49

510
Made `LogDensityFunction` a subtype of `AbstractMCMC.AbstractModel`.

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.36.8"
3+
version = "0.36.9"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/utils.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,8 @@ to_linked_vec_transform(x) = inverse(from_linked_vec_transform(x))
443443
# or fix `tovec` to flatten the full matrix instead of using `Bijectors.triu_to_vec`.
444444
tovec(x::Real) = [x]
445445
tovec(x::AbstractArray) = vec(x)
446+
tovec(t::Tuple) = mapreduce(tovec, vcat, t)
447+
tovec(nt::NamedTuple) = mapreduce(tovec, vcat, values(nt))
446448
tovec(C::Cholesky) = tovec(Matrix(C.UL))
447449

448450
"""

test/model.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -617,4 +617,15 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
617617
]) 1.0 * xs_test[2] rtol = 0.1
618618
end
619619
end
620+
621+
@testset "ProductNamedTupleDistribution sampling" begin
622+
priors = (a=Normal(), b=Normal())
623+
d = product_distribution(priors)
624+
@model function sample_nt(priors_dist)
625+
x ~ priors_dist
626+
return x
627+
end
628+
model = sample_nt(d)
629+
@test model() isa NamedTuple
630+
end
620631
end

test/utils.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,12 @@
4747
dist = LKJCholesky(2, 1)
4848
x = rand(dist)
4949
@test DynamicPPL.tovec(x) == vec(x.UL)
50+
51+
nt = (a=[1, 2], b=3.0)
52+
@test DynamicPPL.tovec(nt) == [1, 2, 3.0]
53+
54+
t = (2.0, [3.0, 4.0])
55+
@test DynamicPPL.tovec(t) == [2.0, 3.0, 4.0]
5056
end
5157

5258
@testset "unique_syms" begin

0 commit comments

Comments
 (0)