@@ -2,23 +2,112 @@ module AbstractPPLDistributionsExt
2
2
3
3
using AbstractPPL: AbstractPPL, VarName, Accessors
4
4
using Distributions: Distributions
5
+ using LinearAlgebra: Cholesky, LowerTriangular, UpperTriangular
6
+
7
+ #=
8
+ This section is copied from Accessors.jl's documentation:
9
+ https://juliaobjects.github.io/Accessors.jl/stable/examples/custom_macros/
10
+
11
+ It defines a wrapper that, when called with `set`, mutates the original value
12
+ rather than returning a new value. We need this because the non-mutating optics
13
+ don't work for triangular matrices (and hence LKJCholesky): see
14
+ https://github.com/JuliaObjects/Accessors.jl/issues/203
15
+ =#
16
+ struct Lens!{L}
17
+ pure:: L
18
+ end
19
+ (l:: Lens! )(o) = l. pure (o)
20
+ function Accessors. set (o, l:: Lens!{<:ComposedFunction} , val)
21
+ o_inner = l. pure. inner (o)
22
+ return Accessors. set (o_inner, Lens! (l. pure. outer), val)
23
+ end
24
+ function Accessors. set (o, l:: Lens!{Accessors.PropertyLens{prop}} , val) where {prop}
25
+ setproperty! (o, prop, val)
26
+ return o
27
+ end
28
+ function Accessors. set (o, l:: Lens!{<:Accessors.IndexLens} , val)
29
+ o[l. pure. indices... ] = val
30
+ return o
31
+ end
32
+
33
+ """
34
+ get_optics(dist::MultivariateDistribution)
35
+ get_optics(dist::MatrixDistribution)
36
+ get_optics(dist::LKJCholesky)
37
+
38
+ Return a complete set of optics for each element of the type returned by `rand(dist)`.
39
+ """
40
+ function get_optics (
41
+ dist:: Union{Distributions.MultivariateDistribution,Distributions.MatrixDistribution}
42
+ )
43
+ indices = CartesianIndices (size (dist))
44
+ return map (idx -> Accessors. IndexLens (idx. I), indices)
45
+ end
46
+ function get_optics (dist:: Distributions.LKJCholesky )
47
+ is_up = dist. uplo == ' U'
48
+ cartesian_indices = filter (CartesianIndices (size (dist))) do cartesian_index
49
+ i, j = cartesian_index. I
50
+ is_up ? i <= j : i >= j
51
+ end
52
+ # there is an additional layer as we need to access `.L` or `.U` before we
53
+ # can index into it
54
+ field_lens = is_up ? (Accessors. @o _. U) : (Accessors. @o _. L)
55
+ return map (idx -> Accessors. IndexLens (idx. I) ∘ field_lens, cartesian_indices)
56
+ end
57
+
58
+ """
59
+ make_empty_value(dist::MultivariateDistribution)
60
+ make_empty_value(dist::MatrixDistribution)
61
+ make_empty_value(dist::LKJCholesky)
62
+
63
+ Construct a fresh value filled with zeros that corresponds to the size of `dist`.
64
+
65
+ For all distributions that this function accepts, it should hold that
66
+ `o(make_empty_value(dist))` is zero for all `o` in `get_optics(dist)`.
67
+ """
68
+ function make_empty_value (
69
+ dist:: Union{Distributions.MultivariateDistribution,Distributions.MatrixDistribution}
70
+ )
71
+ return zeros (size (dist))
72
+ end
73
+ function make_empty_value (dist:: Distributions.LKJCholesky )
74
+ if dist. uplo == ' U'
75
+ return Cholesky (UpperTriangular (zeros (size (dist))))
76
+ else
77
+ return Cholesky (LowerTriangular (zeros (size (dist))))
78
+ end
79
+ end
5
80
6
81
# TODO (penelopeysm): Figure out tuple / namedtuple distributions, and LKJCholesky (grr)
7
82
function AbstractPPL. hasvalue (
8
- vals:: AbstractDict , vn:: VarName , dist:: Distributions.Distribution
83
+ vals:: AbstractDict ,
84
+ vn:: VarName ,
85
+ dist:: Distributions.Distribution ;
86
+ error_on_incomplete:: Bool = false ,
9
87
)
10
88
@warn " `hasvalue(vals, vn, dist)` is not implemented for $(typeof (dist)) ; falling back to `hasvalue(vals, vn)`."
11
89
return AbstractPPL. hasvalue (vals, vn)
12
90
end
13
91
function AbstractPPL. hasvalue (
14
- vals:: AbstractDict , vn:: VarName , :: Distributions.UnivariateDistribution
92
+ vals:: AbstractDict ,
93
+ vn:: VarName ,
94
+ :: Distributions.UnivariateDistribution ;
95
+ error_on_incomplete:: Bool = false ,
15
96
)
97
+ # TODO (penelopeysm): We could also implement a check for the type to catch
98
+ # invalid values. Unsure if that is worth it. It may be easier to just let
99
+ # the user handle it.
16
100
return AbstractPPL. hasvalue (vals, vn)
17
101
end
18
102
function AbstractPPL. hasvalue (
19
103
vals:: AbstractDict{<:VarName} ,
20
104
vn:: VarName{sym} ,
21
- dist:: Union{Distributions.MultivariateDistribution,Distributions.MatrixDistribution} ,
105
+ dist:: Union {
106
+ Distributions. MultivariateDistribution,
107
+ Distributions. MatrixDistribution,
108
+ Distributions. LKJCholesky,
109
+ };
110
+ error_on_incomplete:: Bool = false ,
22
111
) where {sym}
23
112
# If `vn` is present as-is, then we are good
24
113
AbstractPPL. hasvalue (vals, vn) && return true
@@ -30,13 +119,66 @@ function AbstractPPL.hasvalue(
30
119
# To do this, we get the size of the distribution and iterate over all
31
120
# possible indices. If every index can be found in `subsumed_keys`, then we
32
121
# can return true.
33
- sz = size (dist)
34
- for idx in Iterators. product (map (Base. OneTo, sz)... )
35
- new_optic = Accessors. IndexLens (idx) ∘ AbstractPPL. getoptic (vn)
36
- new_vn = VarName {sym} (new_optic)
37
- AbstractPPL. hasvalue (vals, new_vn) || return false
122
+ optics = get_optics (dist)
123
+ original_optic = AbstractPPL. getoptic (vn)
124
+ expected_vns = map (o -> VarName {sym} (o ∘ original_optic), optics)
125
+ if all (sub_vn -> AbstractPPL. hasvalue (vals, sub_vn), expected_vns)
126
+ return true
127
+ else
128
+ if error_on_incomplete &&
129
+ any (sub_vn -> AbstractPPL. hasvalue (vals, sub_vn), expected_vns)
130
+ error (" hasvalue: only partial values for `$vn ` found in the values provided" )
131
+ end
132
+ return false
133
+ end
134
+ end
135
+
136
+ function AbstractPPL. getvalue (
137
+ vals:: AbstractDict , vn:: VarName , dist:: Distributions.Distribution ;
138
+ )
139
+ @warn " `getvalue(vals, vn, dist)` is not implemented for $(typeof (dist)) ; falling back to `getvalue(vals, vn)`."
140
+ return AbstractPPL. getvalue (vals, vn)
141
+ end
142
+ function AbstractPPL. getvalue (
143
+ vals:: AbstractDict , vn:: VarName , :: Distributions.UnivariateDistribution ;
144
+ )
145
+ # TODO (penelopeysm): We could also implement a check for the type to catch
146
+ # invalid values. Unsure if that is worth it. It may be easier to just let
147
+ # the user handle it.
148
+ return AbstractPPL. getvalue (vals, vn)
149
+ end
150
+ function AbstractPPL. getvalue (
151
+ vals:: AbstractDict{<:VarName} ,
152
+ vn:: VarName{sym} ,
153
+ dist:: Union {
154
+ Distributions. MultivariateDistribution,
155
+ Distributions. MatrixDistribution,
156
+ Distributions. LKJCholesky,
157
+ };
158
+ ) where {sym}
159
+ # If `vn` is present as-is, then we can just return that
160
+ AbstractPPL. hasvalue (vals, vn) && return AbstractPPL. getvalue (vals, vn)
161
+ # If not, then we need to start looking inside `vals`, in exactly the
162
+ # same way we did for `hasvalue`.
163
+ optics = get_optics (dist)
164
+ original_optic = AbstractPPL. getoptic (vn)
165
+ expected_vns = map (o -> VarName {sym} (o ∘ original_optic), optics)
166
+ if all (sub_vn -> AbstractPPL. hasvalue (vals, sub_vn), expected_vns)
167
+ # Reconstruct the value index by index.
168
+ value = make_empty_value (dist)
169
+ for (o, sub_vn) in zip (optics, expected_vns)
170
+ # Retrieve the value of this given index
171
+ sub_value = AbstractPPL. getvalue (vals, sub_vn)
172
+ # Set it inside the value we're reconstructing.
173
+ # Note: `o` is normally non-mutating. We have to wrap it in `Lens!`
174
+ # to make it mutating, because Cholesky distributions are broken
175
+ # by https://github.com/JuliaObjects/Accessors.jl/issues/203.
176
+ Accessors. set (value, Lens! (o), sub_value)
177
+ end
178
+ return value
179
+ else
180
+ error (" getvalue: $(vn) was not found in the values provided" )
38
181
end
39
- return true
40
182
end
41
183
42
184
end
0 commit comments