Skip to content

Commit 6a01588

Browse files
committed
Implement fallback {has,get}value methods for NamedTuple + Distribution
1 parent a998af6 commit 6a01588

File tree

1 file changed

+33
-2
lines changed

1 file changed

+33
-2
lines changed

ext/AbstractPPLDistributionsExt.jl

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ end
8080

8181
"""
8282
hasvalue(
83-
vals::AbstractDict,
83+
vals::Union{AbstractDict,NamedTuple},
8484
vn::VarName,
8585
dist::Distribution;
8686
error_on_incomplete::Bool=false
@@ -98,6 +98,11 @@ the values needed for `vn` are present, but others are not. This may help
9898
to detect invalid cases where the user has provided e.g. data of the wrong
9999
shape.
100100
101+
Note that this check is only possible if a Dict is passed, because the key type
102+
of a NamedTuple (i.e., Symbol) is not rich enough to carry indexing
103+
information. If this method is called with a NamedTuple, it will just defer
104+
to `hasvalue(vals, vn)`.
105+
101106
For example:
102107
103108
```jldoctest; setup=:(using Distributions, LinearAlgebra))
@@ -114,6 +119,16 @@ ERROR: hasvalue: only partial values for `x` found in the values provided
114119
[...]
115120
```
116121
"""
122+
function AbstractPPL.hasvalue(
123+
vals::NamedTuple,
124+
vn::VarName,
125+
dist::Distributions.Distribution;
126+
error_on_incomplete::Bool=false,
127+
)
128+
# NamedTuples can't have such complicated hierarchies, so it's safe to
129+
# defer to the simpler `hasvalue(vals, vn)`.
130+
return hasvalue(vals, vn)
131+
end
117132
function AbstractPPL.hasvalue(
118133
vals::AbstractDict,
119134
vn::VarName,
@@ -169,7 +184,11 @@ function AbstractPPL.hasvalue(
169184
end
170185

171186
"""
172-
getvalue(vals::AbstractDict, vn::VarName, dist::Distribution)
187+
getvalue(
188+
vals::Union{AbstractDict,NamedTuple},
189+
vn::VarName,
190+
dist::Distribution
191+
)
173192
174193
Retrieve the value of `vn` from `vals`, using the distribution `dist` to
175194
reconstruct the value if necessary.
@@ -178,6 +197,11 @@ This is a more general version of `getvalue(vals, vn)`, in that even if `vn`
178197
itself is not inside `vals`, it can still reconstruct the value of `vn`
179198
from sub-values of `vn` that are present in `vals`.
180199
200+
Note that this reconstruction is only possible if a Dict is passed, because the
201+
key type of a NamedTuple (i.e., Symbol) is not rich enough to carry indexing
202+
information. If this method is called with a NamedTuple, it will just defer
203+
to `getvalue(vals, vn)`.
204+
181205
For example:
182206
183207
```jldoctest; setup=:(using Distributions, LinearAlgebra))
@@ -194,6 +218,13 @@ ERROR: getvalue: `x` was not found in the values provided
194218
[...]
195219
```
196220
"""
221+
function AbstractPPL.getvalue(
222+
vals::NamedTuple, vn::VarName, dist::Distributions.Distribution
223+
)
224+
# NamedTuples can't have such complicated hierarchies, so it's safe to
225+
# defer to the simpler `getvalue(vals, vn)`.
226+
return getvalue(vals, vn)
227+
end
197228
function AbstractPPL.getvalue(
198229
vals::AbstractDict, vn::VarName, dist::Distributions.Distribution;
199230
)

0 commit comments

Comments
 (0)