Skip to content

Commit 646b9d7

Browse files
committed
Add hasvalue for (some) distributions
1 parent 56b9f2c commit 646b9d7

File tree

5 files changed

+160
-91
lines changed

5 files changed

+160
-91
lines changed

Project.toml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,17 @@ JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
1313
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1414
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1515

16+
[weakdeps]
17+
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
18+
19+
[extensions]
20+
AbstractPPLDistributionsExt = ["Distributions"]
21+
1622
[compat]
1723
AbstractMCMC = "2, 3, 4, 5"
1824
Accessors = "0.1"
1925
DensityInterface = "0.4"
26+
Distributions = "0.25"
2027
JSON = "0.19 - 0.21"
2128
Random = "1.6"
2229
StatsBase = "0.32, 0.33, 0.34"

ext/AbstractPPLDistributionsExt.jl

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
module AbstractPPLDistributionsExt
2+
3+
if isdefined(Base, :get_extension)
4+
using AbstractPPL: AbstractPPL, VarName, Accessors
5+
using Distributions: Distributions
6+
else
7+
using ..AbstractPPL: AbstractPPL, VarName, Accessors
8+
using ..Distributions: Distributions
9+
end
10+
11+
# TODO(penelopeysm): Figure out tuple / namedtuple distributions, and LKJCholesky (grr)
12+
function AbstractPPL.hasvalue(
13+
vals::AbstractDict, vn::VarName, dist::Distributions.Distribution
14+
)
15+
@warn "`hasvalue(vals, vn, dist)` is not implemented for $(typeof(dist)); falling back to `hasvalue(vals, vn)`."
16+
return AbstractPPL.hasvalue(vals, vn)
17+
end
18+
function AbstractPPL.hasvalue(
19+
vals::AbstractDict, vn::VarName, ::Distributions.UnivariateDistribution
20+
)
21+
return AbstractPPL.hasvalue(vals, vn)
22+
end
23+
function AbstractPPL.hasvalue(
24+
vals::AbstractDict{<:VarName},
25+
vn::VarName{sym},
26+
dist::Union{Distributions.MultivariateDistribution,Distributions.MatrixDistribution},
27+
) where {sym}
28+
# If `vn` is present as-is, then we are good
29+
AbstractPPL.hasvalue(vals, vn) && return true
30+
# If not, then we need to check inside `vals` to see if a subset of
31+
# `vals` is enough to reconstruct `vn`. For example, if `vals` contains
32+
# `x[1]` and `x[2]`, and `dist` is `MvNormal(zeros(2), I)`, then we
33+
# can reconstruct `x`. If `dist` is `MvNormal(zeros(3), I)`, then we
34+
# can't.
35+
# To do this, we get the size of the distribution and iterate over all
36+
# possible indices. If every index can be found in `subsumed_keys`, then we
37+
# can return true.
38+
sz = size(dist)
39+
for idx in Iterators.product(map(Base.OneTo, sz)...)
40+
new_optic = if AbstractPPL.getoptic(vn) === identity
41+
Accessors.IndexLens(idx)
42+
else
43+
Accessors.IndexLens(idx) AbstractPPL.getoptic(vn)
44+
end
45+
new_vn = VarName{sym}(new_optic)
46+
AbstractPPL.hasvalue(vals, new_vn) || return false
47+
end
48+
return true
49+
end
50+
51+
end

src/hasvalue.jl

Lines changed: 44 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ julia> getvalue(vals, @varname(x[1])) # different from `getindex`
5757
1.0
5858
5959
julia> getvalue(vals, @varname(x[2]))
60-
ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [2]
60+
ERROR: getvalue: x[2] was not found in the values provided
6161
[...]
6262
```
6363
@@ -74,7 +74,7 @@ julia> getvalue(vals, @varname(x[1])) # different from `getindex`
7474
1.0
7575
7676
julia> getvalue(vals, @varname(x[2]))
77-
ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [2]
77+
ERROR: getvalue: x[2] was not found in the values provided
7878
[...]
7979
```
8080
@@ -91,16 +91,53 @@ julia> getvalue(vals, @varname(x[1][1])) # different from `getindex`
9191
1.0
9292
9393
julia> getvalue(vals, @varname(x[1][2]))
94-
ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [2]
94+
ERROR: getvalue: x[1][2] was not found in the values provided
9595
[...]
9696
9797
julia> getvalue(vals, @varname(x[2][1]))
98-
ERROR: KeyError: key x[2][1] not found
98+
ERROR: getvalue: x[2][1] was not found in the values provided
9999
[...]
100100
```
101101
"""
102-
getvalue(vals::NamedTuple, vn::VarName) = get(vals, vn)
103-
getvalue(vals::AbstractDict, vn::VarName) = nested_getindex(vals, vn)
102+
function getvalue(vals::NamedTuple, vn::VarName{sym}) where {sym}
103+
optic = getoptic(vn)
104+
if haskey(vals, sym) && canview(optic, getproperty(vals, sym))
105+
return optic(vals[sym])
106+
else
107+
error("getvalue: $(vn) was not found in the values provided")
108+
end
109+
end
110+
# For the Dict case, it is more complicated. There are two cases:
111+
# 1. `vn` itself is already a key of `vals` (the easy case)
112+
# 2. `vn` is not a key of `vals`, but some parent of `vn` is a key of `vals`
113+
# (the harder case). For example, if `vn` is `x[1][2]`, then we need to
114+
# check if either `x` or `x[1]` is a key of `vals`, and if so, whether
115+
# we can index into the corresponding value.
116+
function getvalue(vals::AbstractDict{<:VarName}, vn::VarName{sym}) where {sym}
117+
# First we check if `vn` is present as is.
118+
haskey(vals, vn) && return vals[vn]
119+
120+
# Otherwise, we start by testing the bare `vn` (e.g., if `vn` is `x[1][2]`,
121+
# we start by checking if `x` is present). We will then keep adding optics
122+
# to `test_optic`, either until we find a key that is present, or until we
123+
# run out of optics to test (which is determined by _inner(test_optic) ==
124+
# identity).
125+
test_vn = VarName{sym}()
126+
test_optic = getoptic(vn)
127+
128+
while _inner(test_optic) != identity
129+
if haskey(vals, test_vn) && canview(test_optic, vals[test_vn])
130+
return test_optic(vals[test_vn])
131+
else
132+
# Move the innermost optic into test_vn
133+
test_optic_outer = _outer(test_optic)
134+
test_optic_inner = _inner(test_optic)
135+
test_vn = VarName{sym}(test_optic_inner getoptic(test_vn))
136+
test_optic = test_optic_outer
137+
end
138+
end
139+
return error("getvalue: $(vn) was not found in the values provided")
140+
end
104141

105142
"""
106143
hasvalue(vals::NamedTuple, vn::VarName)
@@ -168,13 +205,6 @@ false
168205
function hasvalue(vals::NamedTuple, vn::VarName{sym}) where {sym}
169206
return haskey(vals, sym) && canview(getoptic(vn), getproperty(vals, sym))
170207
end
171-
172-
# For the Dict case, it is more complicated. There are two cases:
173-
# 1. `vn` itself is already a key of `vals` (the easy case)
174-
# 2. `vn` is not a key of `vals`, but some parent of `vn` is a key of `vals`
175-
# (the harder case). For example, if `vn` is `x[1][2]`, then we need to
176-
# check if either `x` or `x[1]` is a key of `vals`, and if so, whether
177-
# we can index into the corresponding value.
178208
function hasvalue(vals::AbstractDict{<:VarName}, vn::VarName{sym}) where {sym}
179209
# First we check if `vn` is present as is.
180210
haskey(vals, vn) && return true
@@ -186,12 +216,8 @@ function hasvalue(vals::AbstractDict{<:VarName}, vn::VarName{sym}) where {sym}
186216
# identity).
187217
test_vn = VarName{sym}()
188218
test_optic = getoptic(vn)
189-
219+
190220
while _inner(test_optic) != identity
191-
@show test_vn, test_optic
192-
if haskey(vals, test_vn)
193-
@show canview(test_optic, vals[test_vn])
194-
end
195221
if haskey(vals, test_vn) && canview(test_optic, vals[test_vn])
196222
return true
197223
else
@@ -204,68 +230,3 @@ function hasvalue(vals::AbstractDict{<:VarName}, vn::VarName{sym}) where {sym}
204230
end
205231
return false
206232
end
207-
# TODO(penelopeysm): Figure out tuple / namedtuple distributions, and LKJCholesky (grr)
208-
# function hasvalue(vals::AbstractDict, vn::VarName, dist::Distribution)
209-
# @warn "`hasvalue(vals, vn, dist)` is not implemented for $(typeof(dist)); falling back to `hasvalue(vals, vn)`."
210-
# return hasvalue(vals, vn)
211-
# end
212-
# hasvalue(vals::AbstractDict, vn::VarName, ::UnivariateDistribution) = hasvalue(vals, vn)
213-
# function hasvalue(
214-
# vals::AbstractDict{<:VarName},
215-
# vn::VarName{sym},
216-
# dist::Union{MultivariateDistribution,MatrixDistribution},
217-
# ) where {sym}
218-
# # If `vn` is present as-is, then we are good
219-
# hasvalue(vals, vn) && return true
220-
# # If not, then we need to check inside `vals` to see if a subset of
221-
# # `vals` is enough to reconstruct `vn`. For example, if `vals` contains
222-
# # `x[1]` and `x[2]`, and `dist` is `MvNormal(zeros(2), I)`, then we
223-
# # can reconstruct `x`. If `dist` is `MvNormal(zeros(3), I)`, then we
224-
# # can't.
225-
# # To do this, we get the size of the distribution and iterate over all
226-
# # possible indices. If every index can be found in `subsumed_keys`, then we
227-
# # can return true.
228-
# sz = size(dist)
229-
# for idx in Iterators.product(map(Base.OneTo, sz)...)
230-
# new_optic = if getoptic(vn) === identity
231-
# Accessors.IndexLens(idx)
232-
# else
233-
# Accessors.IndexLens(idx) ∘ getoptic(vn)
234-
# end
235-
# new_vn = VarName{sym}(new_optic)
236-
# hasvalue(vals, new_vn) || return false
237-
# end
238-
# return true
239-
# end
240-
241-
# """
242-
# nested_getindex(values::AbstractDict, vn::VarName)
243-
#
244-
# Return value corresponding to `vn` in `values` by also looking
245-
# in the the actual values of the dict.
246-
# """
247-
# function nested_getindex(values::AbstractDict, vn::VarName)
248-
# maybeval = get(values, vn, nothing)
249-
# if maybeval !== nothing
250-
# return maybeval
251-
# end
252-
#
253-
# # Split the optic into the key / `parent` and the extraction optic / `child`.
254-
# parent, child, issuccess = splitoptic(getoptic(vn)) do optic
255-
# o = optic === nothing ? identity : optic
256-
# haskey(values, VarName(vn, o))
257-
# end
258-
# # When combined with `VarInfo`, `nothing` is equivalent to `identity`.
259-
# keyoptic = parent === nothing ? identity : parent
260-
#
261-
# # If we found a valid split, then we can extract the value.
262-
# if !issuccess
263-
# # At this point we just throw an error since the key could not be found.
264-
# throw(KeyError(vn))
265-
# end
266-
#
267-
# # TODO: Should we also check that we `canview` the extracted `value`
268-
# # rather than just let it fail upon `get` call?
269-
# value = values[VarName(vn, keyoptic)]
270-
# return child(value)
271-
# end

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
[deps]
22
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
3+
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
34
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
45
InvertedIndices = "41ab1584-1d38-5bbf-9106-f11c6c58b48f"
6+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
57
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
68
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
79
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

test/hasvalue.jl

Lines changed: 56 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,28 @@
1-
@testset "hasvalue" begin
1+
@testset "base getvalue + hasvalue" begin
22
@testset "NamedTuple" begin
33
nt = (a=[1], b=2, c=(x=3,), d=[1.0 0.5; 0.5 1.0])
44
@test hasvalue(nt, @varname(a))
5+
@test getvalue(nt, @varname(a)) == [1]
56
@test hasvalue(nt, @varname(a[1]))
7+
@test getvalue(nt, @varname(a[1])) == 1
68
@test hasvalue(nt, @varname(b))
9+
@test getvalue(nt, @varname(b)) == 2
710
@test hasvalue(nt, @varname(c))
11+
@test getvalue(nt, @varname(c)) == (x=3,)
812
@test hasvalue(nt, @varname(c.x))
13+
@test getvalue(nt, @varname(c.x)) == 3
914
@test hasvalue(nt, @varname(d))
15+
@test getvalue(nt, @varname(d)) == [1.0 0.5; 0.5 1.0]
1016
@test hasvalue(nt, @varname(d[1, 1]))
17+
@test getvalue(nt, @varname(d[1, 1])) == 1.0
1118
@test hasvalue(nt, @varname(d[1, 2]))
19+
@test getvalue(nt, @varname(d[1, 2])) == 0.5
1220
@test hasvalue(nt, @varname(d[2, 1]))
21+
@test getvalue(nt, @varname(d[2, 1])) == 0.5
1322
@test hasvalue(nt, @varname(d[2, 2]))
23+
@test getvalue(nt, @varname(d[2, 2])) == 1.0
1424
@test hasvalue(nt, @varname(d[3])) # linear indexing works....
25+
@test getvalue(nt, @varname(d[3])) == 0.5
1526
@test !hasvalue(nt, @varname(nope))
1627
@test !hasvalue(nt, @varname(a[2]))
1728
@test !hasvalue(nt, @varname(a[1][1]))
@@ -23,40 +34,77 @@
2334

2435
@testset "Dict" begin
2536
# same tests as above
26-
d = Dict(@varname(a) => [1],
37+
d = Dict(
38+
@varname(a) => [1],
2739
@varname(b) => 2,
2840
@varname(c) => (x=3,),
29-
@varname(d) => [1.0 0.5; 0.5 1.0])
41+
@varname(d) => [1.0 0.5; 0.5 1.0],
42+
)
3043
@test hasvalue(d, @varname(a))
44+
@test getvalue(d, @varname(a)) == [1]
3145
@test hasvalue(d, @varname(a[1]))
46+
@test getvalue(d, @varname(a[1])) == 1
3247
@test hasvalue(d, @varname(b))
48+
@test getvalue(d, @varname(b)) == 2
3349
@test hasvalue(d, @varname(c))
50+
@test getvalue(d, @varname(c)) == (x=3,)
3451
@test hasvalue(d, @varname(c.x))
52+
@test getvalue(d, @varname(c.x)) == 3
3553
@test hasvalue(d, @varname(d))
54+
@test getvalue(d, @varname(d)) == [1.0 0.5; 0.5 1.0]
3655
@test hasvalue(d, @varname(d[1, 1]))
56+
@test getvalue(d, @varname(d[1, 1])) == 1.0
3757
@test hasvalue(d, @varname(d[1, 2]))
58+
@test getvalue(d, @varname(d[1, 2])) == 0.5
3859
@test hasvalue(d, @varname(d[2, 1]))
60+
@test getvalue(d, @varname(d[2, 1])) == 0.5
3961
@test hasvalue(d, @varname(d[2, 2]))
40-
@test hasvalue(d, @varname(d[3])) # linear indexing works....
62+
@test getvalue(d, @varname(d[2, 2])) == 1.0
63+
@test hasvalue(d, @varname(d[3])) # linear indexing works....
64+
@test getvalue(d, @varname(d[3])) == 0.5
4165
@test !hasvalue(d, @varname(nope))
4266
@test !hasvalue(d, @varname(a[2]))
4367
@test !hasvalue(d, @varname(a[1][1]))
4468
@test !hasvalue(d, @varname(c.x[1]))
4569
@test !hasvalue(d, @varname(c.y))
4670
@test !hasvalue(d, @varname(d[1, 3]))
47-
@test !hasvalue(d, @varname(d[3]))
4871

49-
# extra ones since Dict can have weird key
50-
d = Dict(@varname(a[1]) => [1.0, 2.0],
51-
@varname(b.x) => [3.0])
72+
# extra ones since Dict can have weird keys
73+
d = Dict(
74+
@varname(a[1]) => [1.0, 2.0],
75+
@varname(b.x) => [3.0],
76+
@varname(c[2]) => (a=4.0, b=5.0),
77+
)
5278
@test hasvalue(d, @varname(a[1]))
79+
@test getvalue(d, @varname(a[1])) == [1.0, 2.0]
5380
@test hasvalue(d, @varname(a[1][1]))
81+
@test getvalue(d, @varname(a[1][1])) == 1.0
5482
@test hasvalue(d, @varname(a[1][2]))
83+
@test getvalue(d, @varname(a[1][2])) == 2.0
5584
@test hasvalue(d, @varname(b.x))
85+
@test getvalue(d, @varname(b.x)) == [3.0]
5686
@test hasvalue(d, @varname(b.x[1]))
87+
@test getvalue(d, @varname(b.x[1])) == 3.0
88+
@test hasvalue(d, @varname(c[2]))
89+
@test getvalue(d, @varname(c[2])) == (a=4.0, b=5.0)
90+
@test hasvalue(d, @varname(c[2].a))
91+
@test getvalue(d, @varname(c[2].a)) == 4.0
92+
@test hasvalue(d, @varname(c[2].b))
93+
@test getvalue(d, @varname(c[2].b)) == 5.0
5794
@test !hasvalue(d, @varname(a))
5895
@test !hasvalue(d, @varname(a[2]))
5996
@test !hasvalue(d, @varname(b.y))
6097
@test !hasvalue(d, @varname(b.x[2]))
98+
@test !hasvalue(d, @varname(c[1]))
99+
@test !hasvalue(d, @varname(c[2].x))
61100
end
62101
end
102+
103+
@testset "with Distributions: getvalue + hasvalue" begin
104+
using Distributions
105+
using LinearAlgebra
106+
107+
d = Dict(@varname(x[1]) => 1.0, @varname(x[2]) => 2.0)
108+
@test hasvalue(d, @varname(x), MvNormal(zeros(2), I))
109+
@test !hasvalue(d, @varname(x), MvNormal(zeros(3), I))
110+
end

0 commit comments

Comments
 (0)