Skip to content

Commit 332c64a

Browse files
committed
Make hasvalue and getvalue use the most specific value
1 parent 07975a2 commit 332c64a

File tree

5 files changed

+175
-75
lines changed

5 files changed

+175
-75
lines changed

HISTORY.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,16 @@
1+
## 0.12.1
2+
3+
Minimum compatibility has been bumped to Julia 1.10.
4+
5+
Added the new functions `hasvalue(container::T, ::VarName[, ::Distribution])` and `getvalue(container::T, ::VarName[, ::Distribution])`, where `T` is either `NamedTuple` or `AbstractDict{<:VarName}`.
6+
7+
These functions check whether a given `VarName` has a value in the given `NamedTuple` or `AbstractDict`, and return the value if it exists.
8+
9+
The optional `Distribution` argument allows one to reconstruct a full value from its component indices.
10+
For example, if `container` has `x[1]` and `x[2]`, then `hasvalue(container, @varname(x), dist)` will return true if `size(dist) == (2,)` (for example, `MvNormal(zeros(2), I)`).
11+
12+
These functions (without the `Distribution` argument) were previously in DynamicPPL.jl (albeit unexported).
13+
114
## 0.12.0
215

316
### VarName constructors

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ uuid = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
33
keywords = ["probablistic programming"]
44
license = "MIT"
55
desc = "Common interfaces for probabilistic programming"
6-
version = "0.12.0"
6+
version = "0.12.1"
77

88
[deps]
99
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/hasvalue.jl

Lines changed: 31 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -117,26 +117,27 @@ function getvalue(vals::AbstractDict{<:VarName}, vn::VarName{sym}) where {sym}
117117
# First we check if `vn` is present as is.
118118
haskey(vals, vn) && return vals[vn]
119119

120-
# Otherwise, we start by testing the bare `vn` (e.g., if `vn` is `x[1][2]`,
121-
# we start by checking if `x` is present). We will then keep adding optics
122-
# to `test_optic`, either until we find a key that is present, or until we
123-
# run out of optics to test (which is determined by _inner(test_optic) ==
124-
# identity).
125-
test_vn = VarName{sym}()
126-
test_optic = getoptic(vn)
127-
128-
while _inner(test_optic) != identity
120+
# Otherwise, we start by testing the `vn` one level up (e.g., if `vn` is
121+
# `x[1][2]`, we start by checking if `x[1]` is present, then `x`). We will
122+
# then keep removing optics from `test_optic`, either until we find a key
123+
# that is present, or until we run out of optics to test (which happens
124+
# after getoptic(test_vn) == identity).
125+
o = getoptic(vn)
126+
test_vn = VarName{sym}(_init(o))
127+
test_optic = _last(o)
128+
129+
while true
129130
if haskey(vals, test_vn) && canview(test_optic, vals[test_vn])
130131
return test_optic(vals[test_vn])
131132
else
132-
# Move the innermost optic into test_vn
133-
test_optic_outer = _outer(test_optic)
134-
test_optic_inner = _inner(test_optic)
135-
test_vn = VarName{sym}(test_optic_inner getoptic(test_vn))
136-
test_optic = test_optic_outer
133+
# Try to move the outermost optic from test_vn into test_optic.
134+
# If test_vn is already an identity, we can't, so we stop.
135+
o = getoptic(test_vn)
136+
o == identity && error("getvalue: $(vn) was not found in the values provided")
137+
test_vn = VarName{sym}(_init(o))
138+
test_optic = normalise(_last(o) test_optic)
137139
end
138140
end
139-
return error("getvalue: $(vn) was not found in the values provided")
140141
end
141142

142143
"""
@@ -209,23 +210,25 @@ function hasvalue(vals::AbstractDict{<:VarName}, vn::VarName{sym}) where {sym}
209210
# First we check if `vn` is present as is.
210211
haskey(vals, vn) && return true
211212

212-
# Otherwise, we start by testing the bare `vn` (e.g., if `vn` is `x[1][2]`,
213-
# we start by checking if `x` is present). We will then keep adding optics
214-
# to `test_optic`, either until we find a key that is present, or until we
215-
# run out of optics to test (which is determined by _inner(test_optic) ==
216-
# identity).
217-
test_vn = VarName{sym}()
218-
test_optic = getoptic(vn)
213+
# Otherwise, we start by testing the `vn` one level up (e.g., if `vn` is
214+
# `x[1][2]`, we start by checking if `x[1]` is present, then `x`). We will
215+
# then keep removing optics from `test_optic`, either until we find a key
216+
# that is present, or until we run out of optics to test (which happens
217+
# after getoptic(test_vn) == identity).
218+
o = getoptic(vn)
219+
test_vn = VarName{sym}(_init(o))
220+
test_optic = _last(o)
219221

220-
while _inner(test_optic) != identity
222+
while true
221223
if haskey(vals, test_vn) && canview(test_optic, vals[test_vn])
222224
return true
223225
else
224-
# Move the innermost optic into test_vn
225-
test_optic_outer = _outer(test_optic)
226-
test_optic_inner = _inner(test_optic)
227-
test_vn = VarName{sym}(test_optic_inner getoptic(test_vn))
228-
test_optic = test_optic_outer
226+
# Try to move the outermost optic from test_vn into test_optic.
227+
# If test_vn is already an identity, we can't, so we stop.
228+
o = getoptic(test_vn)
229+
o == identity && return false
230+
test_vn = VarName{sym}(_init(o))
231+
test_optic = normalise(_last(o) test_optic)
229232
end
230233
end
231234
return false

src/varname.jl

Lines changed: 104 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -963,69 +963,131 @@ string_to_varname(str::AbstractString) = dict_to_varname(JSON.parse(str))
963963
### Prefixing and unprefixing
964964

965965
"""
966-
_strip_identity(optic)
966+
_head(optic)
967967
968-
Remove identity lenses from composed optics.
969-
"""
970-
_strip_identity(::Base.ComposedFunction{typeof(identity),typeof(identity)}) = identity
971-
function _strip_identity(o::Base.ComposedFunction{Outer,typeof(identity)}) where {Outer}
972-
return _strip_identity(o.outer)
973-
end
974-
function _strip_identity(o::Base.ComposedFunction{typeof(identity),Inner}) where {Inner}
975-
return _strip_identity(o.inner)
976-
end
977-
_strip_identity(o::Base.ComposedFunction) = o
978-
_strip_identity(o::Accessors.PropertyLens) = o
979-
_strip_identity(o::Accessors.IndexLens) = o
980-
_strip_identity(o::typeof(identity)) = o
981-
982-
"""
983-
_inner(optic)
968+
Get the innermost layer of an optic.
984969
985-
Get the innermost (non-identity) layer of an optic.
970+
!!! note
971+
Does not perform optic normalisation. You may wish to call
972+
`normalise(optic)` before using this function if the optic you are passing
973+
was not obtained from a VarName.
986974
987975
```jldoctest; setup=:(using Accessors)
988-
julia> AbstractPPL._inner(Accessors.@o _.a.b.c)
976+
julia> AbstractPPL._head(Accessors.@o _.a.b.c)
989977
(@o _.a)
990978
991-
julia> AbstractPPL._inner(Accessors.@o _[1][2][3])
979+
julia> AbstractPPL._head(Accessors.@o _[1][2][3])
992980
(@o _[1])
993981
994-
julia> AbstractPPL._inner(Accessors.@o _)
982+
julia> AbstractPPL._head(Accessors.@o _.a)
983+
(@o _.a)
984+
985+
julia> AbstractPPL._head(Accessors.@o _[1])
986+
(@o _[1])
987+
988+
julia> AbstractPPL._head(Accessors.@o _)
995989
identity (generic function with 1 method)
996990
```
997991
"""
998-
_inner(o::Base.ComposedFunction{Outer,Inner}) where {Outer,Inner} = o.inner
999-
_inner(o::Accessors.PropertyLens) = o
1000-
_inner(o::Accessors.IndexLens) = o
1001-
_inner(o::typeof(identity)) = o
992+
_head(o::ComposedFunction{Outer,Inner}) where {Outer,Inner} = o.inner
993+
_head(o::Accessors.PropertyLens) = o
994+
_head(o::Accessors.IndexLens) = o
995+
_head(::typeof(identity)) = identity
1002996

1003997
"""
1004-
_outer(optic)
998+
_tail(optic)
999+
1000+
Get everything but the innermost layer of an optic.
10051001
1006-
Get the outer layer of an optic.
1002+
!!! note
1003+
Does not perform optic normalisation. You may wish to call
1004+
`normalise(optic)` before using this function if the optic you are passing
1005+
was not obtained from a VarName.
10071006
10081007
```jldoctest; setup=:(using Accessors)
1009-
julia> AbstractPPL._outer(Accessors.@o _.a.b.c)
1008+
julia> AbstractPPL._tail(Accessors.@o _.a.b.c)
10101009
(@o _.b.c)
10111010
1012-
julia> AbstractPPL._outer(Accessors.@o _[1][2][3])
1011+
julia> AbstractPPL._tail(Accessors.@o _[1][2][3])
10131012
(@o _[2][3])
10141013
1015-
julia> AbstractPPL._outer(Accessors.@o _.a)
1014+
julia> AbstractPPL._tail(Accessors.@o _.a)
10161015
identity (generic function with 1 method)
10171016
1018-
julia> AbstractPPL._outer(Accessors.@o _[1])
1017+
julia> AbstractPPL._tail(Accessors.@o _[1])
10191018
identity (generic function with 1 method)
10201019
1021-
julia> AbstractPPL._outer(Accessors.@o _)
1020+
julia> AbstractPPL._tail(Accessors.@o _)
10221021
identity (generic function with 1 method)
10231022
```
10241023
"""
1025-
_outer(o::Base.ComposedFunction{Outer,Inner}) where {Outer,Inner} = o.outer
1026-
_outer(::Accessors.PropertyLens) = identity
1027-
_outer(::Accessors.IndexLens) = identity
1028-
_outer(::typeof(identity)) = identity
1024+
_tail(o::ComposedFunction{Outer,Inner}) where {Outer,Inner} = o.outer
1025+
_tail(::Accessors.PropertyLens) = identity
1026+
_tail(::Accessors.IndexLens) = identity
1027+
_tail(::typeof(identity)) = identity
1028+
1029+
"""
1030+
_last(optic)
1031+
1032+
Get the outermost layer of an optic.
1033+
1034+
!!! note
1035+
Does not perform optic normalisation. You may wish to call
1036+
`normalise(optic)` before using this function if the optic you are passing
1037+
was not obtained from a VarName.
1038+
1039+
```jldoctest; setup=:(using Accessors)
1040+
julia> AbstractPPL._last(Accessors.@o _.a.b.c)
1041+
(@o _.c)
1042+
1043+
julia> AbstractPPL._last(Accessors.@o _[1][2][3])
1044+
(@o _[3])
1045+
1046+
julia> AbstractPPL._last(Accessors.@o _.a)
1047+
(@o _.a)
1048+
1049+
julia> AbstractPPL._last(Accessors.@o _[1])
1050+
(@o _[1])
1051+
1052+
julia> AbstractPPL._last(Accessors.@o _)
1053+
identity (generic function with 1 method)
1054+
```
1055+
"""
1056+
_last(o::ComposedFunction{Outer,Inner}) where {Outer,Inner} = _last(o.outer)
1057+
_last(o::Accessors.PropertyLens) = o
1058+
_last(o::Accessors.IndexLens) = o
1059+
_last(::typeof(identity)) = identity
1060+
1061+
"""
1062+
_init(optic)
1063+
1064+
Get everything but the outermost layer of an optic.
1065+
1066+
!!! note
1067+
Does not perform optic normalisation. You may wish to call
1068+
`normalise(optic)` before using this function if the optic you are passing
1069+
was not obtained from a VarName.
1070+
1071+
```jldoctest; setup=:(using Accessors)
1072+
julia> AbstractPPL._init(Accessors.@o _.a.b.c)
1073+
(@o _.a.b)
1074+
1075+
julia> AbstractPPL._init(Accessors.@o _[1][2][3])
1076+
(@o _[1][2])
1077+
1078+
julia> AbstractPPL._init(Accessors.@o _.a)
1079+
identity (generic function with 1 method)
1080+
1081+
julia> AbstractPPL._init(Accessors.@o _[1])
1082+
identity (generic function with 1 method)
1083+
1084+
julia> AbstractPPL._init(Accessors.@o _)
1085+
identity (generic function with 1 method)
1086+
"""
1087+
_init(o::ComposedFunction{Outer,Inner}) where {Outer,Inner} = _init(o.outer) o.inner
1088+
_init(::Accessors.PropertyLens) = identity
1089+
_init(::Accessors.IndexLens) = identity
1090+
_init(::typeof(identity)) = identity
10291091

10301092
"""
10311093
optic_to_vn(optic)
@@ -1058,11 +1120,11 @@ function optic_to_vn(::Accessors.PropertyLens{sym}) where {sym}
10581120
return VarName{sym}()
10591121
end
10601122
function optic_to_vn(
1061-
o::Base.ComposedFunction{Outer,Accessors.PropertyLens{sym}}
1123+
o::ComposedFunction{Outer,Accessors.PropertyLens{sym}}
10621124
) where {Outer,sym}
10631125
return VarName{sym}(o.outer)
10641126
end
1065-
optic_to_vn(o::Base.ComposedFunction) = optic_to_vn(normalise(o))
1127+
optic_to_vn(o::ComposedFunction) = optic_to_vn(normalise(o))
10661128
function optic_to_vn(@nospecialize(o))
10671129
msg = "optic_to_vn: could not convert optic `$o` to a VarName"
10681130
throw(ArgumentError(msg))
@@ -1077,14 +1139,14 @@ function unprefix_optic(optic, optic_prefix)
10771139
optic = normalise(optic)
10781140
optic_prefix = normalise(optic_prefix)
10791141
# strip one layer of the optic and check for equality
1080-
inner = _inner(optic)
1081-
inner_prefix = _inner(optic_prefix)
1082-
if inner != inner_prefix
1142+
head = _head(optic)
1143+
head_prefix = _head(optic_prefix)
1144+
if head != head_prefix
10831145
msg = "could not remove prefix $(optic_prefix) from optic $(optic)"
10841146
throw(ArgumentError(msg))
10851147
end
10861148
# recurse
1087-
return unprefix_optic(_outer(optic), _outer(optic_prefix))
1149+
return unprefix_optic(_tail(optic), _tail(optic_prefix))
10881150
end
10891151

10901152
"""

test/hasvalue.jl

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
@testset "base getvalue + hasvalue" begin
2-
@testset "NamedTuple" begin
2+
@testset "basic NamedTuple" begin
33
nt = (a=[1], b=2, c=(x=3,), d=[1.0 0.5; 0.5 1.0])
44
@test hasvalue(nt, @varname(a))
55
@test getvalue(nt, @varname(a)) == [1]
@@ -32,8 +32,8 @@
3232
@test !hasvalue(nt, @varname(d[3, :]))
3333
end
3434

35-
@testset "Dict" begin
36-
# same tests as above
35+
@testset "basic Dict" begin
36+
# same tests as for NamedTuple
3737
d = Dict(
3838
@varname(a) => [1],
3939
@varname(b) => 2,
@@ -68,8 +68,9 @@
6868
@test !hasvalue(d, @varname(c.x[1]))
6969
@test !hasvalue(d, @varname(c.y))
7070
@test !hasvalue(d, @varname(d[1, 3]))
71+
end
7172

72-
# extra ones since Dict can have weird keys
73+
@testset "Dict with non-identity varname keys" begin
7374
d = Dict(
7475
@varname(a[1]) => [1.0, 2.0],
7576
@varname(b.x) => [3.0],
@@ -98,6 +99,27 @@
9899
@test !hasvalue(d, @varname(c[1]))
99100
@test !hasvalue(d, @varname(c[2].x))
100101
end
102+
103+
@testset "Dict with redundancy" begin
104+
d1 = Dict(@varname(x) => [[[[1.0]]]])
105+
d2 = Dict(@varname(x[1]) => [[[2.0]]])
106+
d3 = Dict(@varname(x[1][1]) => [[3.0]])
107+
d4 = Dict(@varname(x[1][1][1]) => [4.0])
108+
d5 = Dict(@varname(x[1][1][1][1]) => 5.0)
109+
110+
d = Dict{VarName,Any}()
111+
for (new_dict, expected_value) in
112+
zip((d1, d2, d3, d4, d5), (1.0, 2.0, 3.0, 4.0, 5.0))
113+
d = merge(d, new_dict)
114+
@test hasvalue(d, @varname(x[1][1][1][1]))
115+
@test getvalue(d, @varname(x[1][1][1][1])) == expected_value
116+
# for good measure
117+
@test !hasvalue(d, @varname(x[1][1][1][2]))
118+
@test !hasvalue(d, @varname(x[1][1][2][1]))
119+
@test !hasvalue(d, @varname(x[1][2][1][1]))
120+
@test !hasvalue(d, @varname(x[2][1][1][1]))
121+
end
122+
end
101123
end
102124

103125
@testset "with Distributions: getvalue + hasvalue" begin

0 commit comments

Comments
 (0)