80
80
81
81
"""
82
82
hasvalue(
83
- vals::AbstractDict,
83
+ vals::Union{ AbstractDict,NamedTuple} ,
84
84
vn::VarName,
85
85
dist::Distribution;
86
86
error_on_incomplete::Bool=false
@@ -98,6 +98,11 @@ the values needed for `vn` are present, but others are not. This may help
98
98
to detect invalid cases where the user has provided e.g. data of the wrong
99
99
shape.
100
100
101
+ Note that this check is only possible if a Dict is passed, because the key type
102
+ of a NamedTuple (i.e., Symbol) is not rich enough to carry indexing
103
+ information. If this method is called with a NamedTuple, it will just defer
104
+ to `hasvalue(vals, vn)`.
105
+
101
106
For example:
102
107
103
108
```jldoctest; setup=:(using Distributions, LinearAlgebra))
@@ -114,6 +119,16 @@ ERROR: hasvalue: only partial values for `x` found in the values provided
114
119
[...]
115
120
```
116
121
"""
122
+ function AbstractPPL. hasvalue (
123
+ vals:: NamedTuple ,
124
+ vn:: VarName ,
125
+ dist:: Distributions.Distribution ;
126
+ error_on_incomplete:: Bool = false ,
127
+ )
128
+ # NamedTuples can't have such complicated hierarchies, so it's safe to
129
+ # defer to the simpler `hasvalue(vals, vn)`.
130
+ return hasvalue (vals, vn)
131
+ end
117
132
function AbstractPPL. hasvalue (
118
133
vals:: AbstractDict ,
119
134
vn:: VarName ,
@@ -169,7 +184,11 @@ function AbstractPPL.hasvalue(
169
184
end
170
185
171
186
"""
172
- getvalue(vals::AbstractDict, vn::VarName, dist::Distribution)
187
+ getvalue(
188
+ vals::Union{AbstractDict,NamedTuple},
189
+ vn::VarName,
190
+ dist::Distribution
191
+ )
173
192
174
193
Retrieve the value of `vn` from `vals`, using the distribution `dist` to
175
194
reconstruct the value if necessary.
@@ -178,6 +197,11 @@ This is a more general version of `getvalue(vals, vn)`, in that even if `vn`
178
197
itself is not inside `vals`, it can still reconstruct the value of `vn`
179
198
from sub-values of `vn` that are present in `vals`.
180
199
200
+ Note that this reconstruction is only possible if a Dict is passed, because the
201
+ key type of a NamedTuple (i.e., Symbol) is not rich enough to carry indexing
202
+ information. If this method is called with a NamedTuple, it will just defer
203
+ to `getvalue(vals, vn)`.
204
+
181
205
For example:
182
206
183
207
```jldoctest; setup=:(using Distributions, LinearAlgebra))
@@ -194,6 +218,15 @@ ERROR: getvalue: `x` was not found in the values provided
194
218
[...]
195
219
```
196
220
"""
221
+ function AbstractPPL. getvalue (
222
+ vals:: NamedTuple ,
223
+ vn:: VarName ,
224
+ dist:: Distributions.Distribution
225
+ )
226
+ # NamedTuples can't have such complicated hierarchies, so it's safe to
227
+ # defer to the simpler `getvalue(vals, vn)`.
228
+ return getvalue (vals, vn)
229
+ end
197
230
function AbstractPPL. getvalue (
198
231
vals:: AbstractDict , vn:: VarName , dist:: Distributions.Distribution ;
199
232
)
0 commit comments