Skip to content

Commit 56b9f2c

Browse files
committed
Move hasvalue and getvalue to AbstractPPL; reimplement
1 parent 7be9556 commit 56b9f2c

File tree

4 files changed

+338
-1
lines changed

4 files changed

+338
-1
lines changed

src/AbstractPPL.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@ export VarName,
1616
varname_to_string,
1717
string_to_varname,
1818
prefix,
19-
unprefix
19+
unprefix,
20+
getvalue,
21+
hasvalue
2022

2123
# Abstract model functions
2224
export AbstractProbabilisticProgram,
@@ -29,5 +31,6 @@ include("varname.jl")
2931
include("abstractmodeltrace.jl")
3032
include("abstractprobprog.jl")
3133
include("evaluate.jl")
34+
include("hasvalue.jl")
3235

3336
end # module

src/hasvalue.jl

Lines changed: 271 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,271 @@
1+
"""
2+
canview(optic, container)
3+
4+
Return `true` if `optic` can be used to view `container`, and `false` otherwise.
5+
6+
# Examples
7+
```jldoctest; setup=:(using Accessors)
8+
julia> AbstractPPL.canview(@o(_.a), (a = 1.0, ))
9+
true
10+
11+
julia> AbstractPPL.canview(@o(_.a), (b = 1.0, )) # property `a` does not exist
12+
false
13+
14+
julia> AbstractPPL.canview(@o(_.a[1]), (a = [1.0, 2.0], ))
15+
true
16+
17+
julia> AbstractPPL.canview(@o(_.a[3]), (a = [1.0, 2.0], )) # out of bounds
18+
false
19+
```
20+
"""
21+
canview(optic, container) = false
22+
canview(::typeof(identity), _) = true
23+
function canview(::Accessors.PropertyLens{field}, x) where {field}
24+
return hasproperty(x, field)
25+
end
26+
27+
# `IndexLens`: only relevant if `x` supports indexing.
28+
canview(optic::Accessors.IndexLens, x) = false
29+
function canview(optic::Accessors.IndexLens, x::AbstractArray)
30+
return checkbounds(Bool, x, optic.indices...)
31+
end
32+
33+
# `ComposedFunction`: check that we can view `.inner` and `.outer`, but using
34+
# value extracted using `.inner`.
35+
function canview(optic::ComposedFunction, x)
36+
return canview(optic.inner, x) && canview(optic.outer, optic.inner(x))
37+
end
38+
39+
"""
40+
getvalue(vals::NamedTuple, vn::VarName)
41+
getvalue(vals::AbstractDict{<:VarName}, vn::VarName)
42+
43+
Return the value(s) in `vals` represented by `vn`.
44+
45+
# Examples
46+
47+
For `NamedTuple`:
48+
49+
```jldoctest
50+
julia> vals = (x = [1.0],);
51+
52+
julia> getvalue(vals, @varname(x)) # same as `getindex`
53+
1-element Vector{Float64}:
54+
1.0
55+
56+
julia> getvalue(vals, @varname(x[1])) # different from `getindex`
57+
1.0
58+
59+
julia> getvalue(vals, @varname(x[2]))
60+
ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [2]
61+
[...]
62+
```
63+
64+
For `AbstractDict`:
65+
66+
```jldoctest
67+
julia> vals = Dict(@varname(x) => [1.0]);
68+
69+
julia> getvalue(vals, @varname(x)) # same as `getindex`
70+
1-element Vector{Float64}:
71+
1.0
72+
73+
julia> getvalue(vals, @varname(x[1])) # different from `getindex`
74+
1.0
75+
76+
julia> getvalue(vals, @varname(x[2]))
77+
ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [2]
78+
[...]
79+
```
80+
81+
In the `AbstractDict` case we can also have keys such as `v[1]`:
82+
83+
```jldoctest
84+
julia> vals = Dict(@varname(x[1]) => [1.0,]);
85+
86+
julia> getvalue(vals, @varname(x[1])) # same as `getindex`
87+
1-element Vector{Float64}:
88+
1.0
89+
90+
julia> getvalue(vals, @varname(x[1][1])) # different from `getindex`
91+
1.0
92+
93+
julia> getvalue(vals, @varname(x[1][2]))
94+
ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [2]
95+
[...]
96+
97+
julia> getvalue(vals, @varname(x[2][1]))
98+
ERROR: KeyError: key x[2][1] not found
99+
[...]
100+
```
101+
"""
102+
getvalue(vals::NamedTuple, vn::VarName) = get(vals, vn)
103+
getvalue(vals::AbstractDict, vn::VarName) = nested_getindex(vals, vn)
104+
105+
"""
106+
hasvalue(vals::NamedTuple, vn::VarName)
107+
hasvalue(vals::AbstractDict{<:VarName}, vn::VarName)
108+
109+
Determine whether `vals` contains a value for a given `vn`.
110+
111+
# Examples
112+
With `x` as a `NamedTuple`:
113+
114+
```jldoctest
115+
julia> hasvalue((x = 1.0, ), @varname(x))
116+
true
117+
118+
julia> hasvalue((x = 1.0, ), @varname(x[1]))
119+
false
120+
121+
julia> hasvalue((x = [1.0],), @varname(x))
122+
true
123+
124+
julia> hasvalue((x = [1.0],), @varname(x[1]))
125+
true
126+
127+
julia> hasvalue((x = [1.0],), @varname(x[2]))
128+
false
129+
```
130+
131+
With `x` as a `AbstractDict`:
132+
133+
```jldoctest
134+
julia> hasvalue(Dict(@varname(x) => 1.0, ), @varname(x))
135+
true
136+
137+
julia> hasvalue(Dict(@varname(x) => 1.0, ), @varname(x[1]))
138+
false
139+
140+
julia> hasvalue(Dict(@varname(x) => [1.0]), @varname(x))
141+
true
142+
143+
julia> hasvalue(Dict(@varname(x) => [1.0]), @varname(x[1]))
144+
true
145+
146+
julia> hasvalue(Dict(@varname(x) => [1.0]), @varname(x[2]))
147+
false
148+
```
149+
150+
In the `AbstractDict` case we can also have keys such as `v[1]`:
151+
152+
```jldoctest
153+
julia> vals = Dict(@varname(x[1]) => [1.0,]);
154+
155+
julia> hasvalue(vals, @varname(x[1])) # same as `haskey`
156+
true
157+
158+
julia> hasvalue(vals, @varname(x[1][1])) # different from `haskey`
159+
true
160+
161+
julia> hasvalue(vals, @varname(x[1][2]))
162+
false
163+
164+
julia> hasvalue(vals, @varname(x[2][1]))
165+
false
166+
```
167+
"""
168+
function hasvalue(vals::NamedTuple, vn::VarName{sym}) where {sym}
169+
return haskey(vals, sym) && canview(getoptic(vn), getproperty(vals, sym))
170+
end
171+
172+
# For the Dict case, it is more complicated. There are two cases:
173+
# 1. `vn` itself is already a key of `vals` (the easy case)
174+
# 2. `vn` is not a key of `vals`, but some parent of `vn` is a key of `vals`
175+
# (the harder case). For example, if `vn` is `x[1][2]`, then we need to
176+
# check if either `x` or `x[1]` is a key of `vals`, and if so, whether
177+
# we can index into the corresponding value.
178+
function hasvalue(vals::AbstractDict{<:VarName}, vn::VarName{sym}) where {sym}
179+
# First we check if `vn` is present as is.
180+
haskey(vals, vn) && return true
181+
182+
# Otherwise, we start by testing the bare `vn` (e.g., if `vn` is `x[1][2]`,
183+
# we start by checking if `x` is present). We will then keep adding optics
184+
# to `test_optic`, either until we find a key that is present, or until we
185+
# run out of optics to test (which is determined by _inner(test_optic) ==
186+
# identity).
187+
test_vn = VarName{sym}()
188+
test_optic = getoptic(vn)
189+
190+
while _inner(test_optic) != identity
191+
@show test_vn, test_optic
192+
if haskey(vals, test_vn)
193+
@show canview(test_optic, vals[test_vn])
194+
end
195+
if haskey(vals, test_vn) && canview(test_optic, vals[test_vn])
196+
return true
197+
else
198+
# Move the innermost optic into test_vn
199+
test_optic_outer = _outer(test_optic)
200+
test_optic_inner = _inner(test_optic)
201+
test_vn = VarName{sym}(test_optic_inner getoptic(test_vn))
202+
test_optic = test_optic_outer
203+
end
204+
end
205+
return false
206+
end
207+
# TODO(penelopeysm): Figure out tuple / namedtuple distributions, and LKJCholesky (grr)
208+
# function hasvalue(vals::AbstractDict, vn::VarName, dist::Distribution)
209+
# @warn "`hasvalue(vals, vn, dist)` is not implemented for $(typeof(dist)); falling back to `hasvalue(vals, vn)`."
210+
# return hasvalue(vals, vn)
211+
# end
212+
# hasvalue(vals::AbstractDict, vn::VarName, ::UnivariateDistribution) = hasvalue(vals, vn)
213+
# function hasvalue(
214+
# vals::AbstractDict{<:VarName},
215+
# vn::VarName{sym},
216+
# dist::Union{MultivariateDistribution,MatrixDistribution},
217+
# ) where {sym}
218+
# # If `vn` is present as-is, then we are good
219+
# hasvalue(vals, vn) && return true
220+
# # If not, then we need to check inside `vals` to see if a subset of
221+
# # `vals` is enough to reconstruct `vn`. For example, if `vals` contains
222+
# # `x[1]` and `x[2]`, and `dist` is `MvNormal(zeros(2), I)`, then we
223+
# # can reconstruct `x`. If `dist` is `MvNormal(zeros(3), I)`, then we
224+
# # can't.
225+
# # To do this, we get the size of the distribution and iterate over all
226+
# # possible indices. If every index can be found in `subsumed_keys`, then we
227+
# # can return true.
228+
# sz = size(dist)
229+
# for idx in Iterators.product(map(Base.OneTo, sz)...)
230+
# new_optic = if getoptic(vn) === identity
231+
# Accessors.IndexLens(idx)
232+
# else
233+
# Accessors.IndexLens(idx) ∘ getoptic(vn)
234+
# end
235+
# new_vn = VarName{sym}(new_optic)
236+
# hasvalue(vals, new_vn) || return false
237+
# end
238+
# return true
239+
# end
240+
241+
# """
242+
# nested_getindex(values::AbstractDict, vn::VarName)
243+
#
244+
# Return value corresponding to `vn` in `values` by also looking
245+
# in the the actual values of the dict.
246+
# """
247+
# function nested_getindex(values::AbstractDict, vn::VarName)
248+
# maybeval = get(values, vn, nothing)
249+
# if maybeval !== nothing
250+
# return maybeval
251+
# end
252+
#
253+
# # Split the optic into the key / `parent` and the extraction optic / `child`.
254+
# parent, child, issuccess = splitoptic(getoptic(vn)) do optic
255+
# o = optic === nothing ? identity : optic
256+
# haskey(values, VarName(vn, o))
257+
# end
258+
# # When combined with `VarInfo`, `nothing` is equivalent to `identity`.
259+
# keyoptic = parent === nothing ? identity : parent
260+
#
261+
# # If we found a valid split, then we can extract the value.
262+
# if !issuccess
263+
# # At this point we just throw an error since the key could not be found.
264+
# throw(KeyError(vn))
265+
# end
266+
#
267+
# # TODO: Should we also check that we `canview` the extracted `value`
268+
# # rather than just let it fail upon `get` call?
269+
# value = values[VarName(vn, keyoptic)]
270+
# return child(value)
271+
# end

test/hasvalue.jl

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
@testset "hasvalue" begin
2+
@testset "NamedTuple" begin
3+
nt = (a=[1], b=2, c=(x=3,), d=[1.0 0.5; 0.5 1.0])
4+
@test hasvalue(nt, @varname(a))
5+
@test hasvalue(nt, @varname(a[1]))
6+
@test hasvalue(nt, @varname(b))
7+
@test hasvalue(nt, @varname(c))
8+
@test hasvalue(nt, @varname(c.x))
9+
@test hasvalue(nt, @varname(d))
10+
@test hasvalue(nt, @varname(d[1, 1]))
11+
@test hasvalue(nt, @varname(d[1, 2]))
12+
@test hasvalue(nt, @varname(d[2, 1]))
13+
@test hasvalue(nt, @varname(d[2, 2]))
14+
@test hasvalue(nt, @varname(d[3])) # linear indexing works....
15+
@test !hasvalue(nt, @varname(nope))
16+
@test !hasvalue(nt, @varname(a[2]))
17+
@test !hasvalue(nt, @varname(a[1][1]))
18+
@test !hasvalue(nt, @varname(c.x[1]))
19+
@test !hasvalue(nt, @varname(c.y))
20+
@test !hasvalue(nt, @varname(d[1, 3]))
21+
@test !hasvalue(nt, @varname(d[3, :]))
22+
end
23+
24+
@testset "Dict" begin
25+
# same tests as above
26+
d = Dict(@varname(a) => [1],
27+
@varname(b) => 2,
28+
@varname(c) => (x=3,),
29+
@varname(d) => [1.0 0.5; 0.5 1.0])
30+
@test hasvalue(d, @varname(a))
31+
@test hasvalue(d, @varname(a[1]))
32+
@test hasvalue(d, @varname(b))
33+
@test hasvalue(d, @varname(c))
34+
@test hasvalue(d, @varname(c.x))
35+
@test hasvalue(d, @varname(d))
36+
@test hasvalue(d, @varname(d[1, 1]))
37+
@test hasvalue(d, @varname(d[1, 2]))
38+
@test hasvalue(d, @varname(d[2, 1]))
39+
@test hasvalue(d, @varname(d[2, 2]))
40+
@test hasvalue(d, @varname(d[3])) # linear indexing works....
41+
@test !hasvalue(d, @varname(nope))
42+
@test !hasvalue(d, @varname(a[2]))
43+
@test !hasvalue(d, @varname(a[1][1]))
44+
@test !hasvalue(d, @varname(c.x[1]))
45+
@test !hasvalue(d, @varname(c.y))
46+
@test !hasvalue(d, @varname(d[1, 3]))
47+
@test !hasvalue(d, @varname(d[3]))
48+
49+
# extra ones since Dict can have weird key
50+
d = Dict(@varname(a[1]) => [1.0, 2.0],
51+
@varname(b.x) => [3.0])
52+
@test hasvalue(d, @varname(a[1]))
53+
@test hasvalue(d, @varname(a[1][1]))
54+
@test hasvalue(d, @varname(a[1][2]))
55+
@test hasvalue(d, @varname(b.x))
56+
@test hasvalue(d, @varname(b.x[1]))
57+
@test !hasvalue(d, @varname(a))
58+
@test !hasvalue(d, @varname(a[2]))
59+
@test !hasvalue(d, @varname(b.y))
60+
@test !hasvalue(d, @varname(b.x[2]))
61+
end
62+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ const GROUP = get(ENV, "GROUP", "All")
88
if GROUP == "All" || GROUP == "Tests"
99
include("varname.jl")
1010
include("abstractprobprog.jl")
11+
include("hasvalue.jl")
1112
end
1213

1314
if GROUP == "All" || GROUP == "Doctests"

0 commit comments

Comments
 (0)