Skip to content

Commit a736e70

Browse files
committed
Finish implementing distributions methods
1 parent 9291e07 commit a736e70

File tree

4 files changed

+209
-14
lines changed

4 files changed

+209
-14
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,17 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1515

1616
[weakdeps]
1717
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
18+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1819

1920
[extensions]
20-
AbstractPPLDistributionsExt = ["Distributions"]
21+
AbstractPPLDistributionsExt = ["Distributions", "LinearAlgebra"]
2122

2223
[compat]
2324
AbstractMCMC = "2, 3, 4, 5"
2425
Accessors = "0.1"
2526
DensityInterface = "0.4"
2627
Distributions = "0.25"
28+
LinearAlgebra = "<0.0.1, 1.11"
2729
JSON = "0.19 - 0.21"
2830
Random = "1.6"
2931
StatsBase = "0.32, 0.33, 0.34"

ext/AbstractPPLDistributionsExt.jl

Lines changed: 151 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,23 +2,112 @@ module AbstractPPLDistributionsExt
22

33
using AbstractPPL: AbstractPPL, VarName, Accessors
44
using Distributions: Distributions
5+
using LinearAlgebra: Cholesky, LowerTriangular, UpperTriangular
6+
7+
#=
8+
This section is copied from Accessors.jl's documentation:
9+
https://juliaobjects.github.io/Accessors.jl/stable/examples/custom_macros/
10+
11+
It defines a wrapper that, when called with `set`, mutates the original value
12+
rather than returning a new value. We need this because the non-mutating optics
13+
don't work for triangular matrices (and hence LKJCholesky): see
14+
https://github.com/JuliaObjects/Accessors.jl/issues/203
15+
=#
16+
struct Lens!{L}
17+
pure::L
18+
end
19+
(l::Lens!)(o) = l.pure(o)
20+
function Accessors.set(o, l::Lens!{<:ComposedFunction}, val)
21+
o_inner = l.pure.inner(o)
22+
return Accessors.set(o_inner, Lens!(l.pure.outer), val)
23+
end
24+
function Accessors.set(o, l::Lens!{Accessors.PropertyLens{prop}}, val) where {prop}
25+
setproperty!(o, prop, val)
26+
return o
27+
end
28+
function Accessors.set(o, l::Lens!{<:Accessors.IndexLens}, val)
29+
o[l.pure.indices...] = val
30+
return o
31+
end
32+
33+
"""
34+
get_optics(dist::MultivariateDistribution)
35+
get_optics(dist::MatrixDistribution)
36+
get_optics(dist::LKJCholesky)
37+
38+
Return a complete set of optics for each element of the type returned by `rand(dist)`.
39+
"""
40+
function get_optics(
41+
dist::Union{Distributions.MultivariateDistribution,Distributions.MatrixDistribution}
42+
)
43+
indices = CartesianIndices(size(dist))
44+
return map(idx -> Accessors.IndexLens(idx.I), indices)
45+
end
46+
function get_optics(dist::Distributions.LKJCholesky)
47+
is_up = dist.uplo == 'U'
48+
cartesian_indices = filter(CartesianIndices(size(dist))) do cartesian_index
49+
i, j = cartesian_index.I
50+
is_up ? i <= j : i >= j
51+
end
52+
# there is an additional layer as we need to access `.L` or `.U` before we
53+
# can index into it
54+
field_lens = is_up ? (Accessors.@o _.U) : (Accessors.@o _.L)
55+
return map(idx -> Accessors.IndexLens(idx.I) ∘ field_lens, cartesian_indices)
56+
end
57+
58+
"""
59+
make_empty_value(dist::MultivariateDistribution)
60+
make_empty_value(dist::MatrixDistribution)
61+
make_empty_value(dist::LKJCholesky)
62+
63+
Construct a fresh value filled with zeros that corresponds to the size of `dist`.
64+
65+
For all distributions that this function accepts, it should hold that
66+
`o(make_empty_value(dist))` is zero for all `o` in `get_optics(dist)`.
67+
"""
68+
function make_empty_value(
69+
dist::Union{Distributions.MultivariateDistribution,Distributions.MatrixDistribution}
70+
)
71+
return zeros(size(dist))
72+
end
73+
function make_empty_value(dist::Distributions.LKJCholesky)
74+
if dist.uplo == 'U'
75+
return Cholesky(UpperTriangular(zeros(size(dist))))
76+
else
77+
return Cholesky(LowerTriangular(zeros(size(dist))))
78+
end
79+
end
580

681
# TODO(penelopeysm): Figure out tuple / namedtuple distributions, and LKJCholesky (grr)
782
function AbstractPPL.hasvalue(
8-
vals::AbstractDict, vn::VarName, dist::Distributions.Distribution
83+
vals::AbstractDict,
84+
vn::VarName,
85+
dist::Distributions.Distribution;
86+
error_on_incomplete::Bool=false,
987
)
1088
@warn "`hasvalue(vals, vn, dist)` is not implemented for $(typeof(dist)); falling back to `hasvalue(vals, vn)`."
1189
return AbstractPPL.hasvalue(vals, vn)
1290
end
1391
function AbstractPPL.hasvalue(
14-
vals::AbstractDict, vn::VarName, ::Distributions.UnivariateDistribution
92+
vals::AbstractDict,
93+
vn::VarName,
94+
::Distributions.UnivariateDistribution;
95+
error_on_incomplete::Bool=false,
1596
)
97+
# TODO(penelopeysm): We could also implement a check for the type to catch
98+
# invalid values. Unsure if that is worth it. It may be easier to just let
99+
# the user handle it.
16100
return AbstractPPL.hasvalue(vals, vn)
17101
end
18102
function AbstractPPL.hasvalue(
19103
vals::AbstractDict{<:VarName},
20104
vn::VarName{sym},
21-
dist::Union{Distributions.MultivariateDistribution,Distributions.MatrixDistribution},
105+
dist::Union{
106+
Distributions.MultivariateDistribution,
107+
Distributions.MatrixDistribution,
108+
Distributions.LKJCholesky,
109+
};
110+
error_on_incomplete::Bool=false,
22111
) where {sym}
23112
# If `vn` is present as-is, then we are good
24113
AbstractPPL.hasvalue(vals, vn) && return true
@@ -30,13 +119,66 @@ function AbstractPPL.hasvalue(
30119
# To do this, we get the size of the distribution and iterate over all
31120
# possible indices. If every index can be found in `subsumed_keys`, then we
32121
# can return true.
33-
sz = size(dist)
34-
for idx in Iterators.product(map(Base.OneTo, sz)...)
35-
new_optic = Accessors.IndexLens(idx) ∘ AbstractPPL.getoptic(vn)
36-
new_vn = VarName{sym}(new_optic)
37-
AbstractPPL.hasvalue(vals, new_vn) || return false
122+
optics = get_optics(dist)
123+
original_optic = AbstractPPL.getoptic(vn)
124+
expected_vns = map(o -> VarName{sym}(o ∘ original_optic), optics)
125+
if all(sub_vn -> AbstractPPL.hasvalue(vals, sub_vn), expected_vns)
126+
return true
127+
else
128+
if error_on_incomplete &&
129+
any(sub_vn -> AbstractPPL.hasvalue(vals, sub_vn), expected_vns)
130+
error("hasvalue: only partial values for `$vn` found in the values provided")
131+
end
132+
return false
133+
end
134+
end
135+
136+
function AbstractPPL.getvalue(
137+
vals::AbstractDict, vn::VarName, dist::Distributions.Distribution;
138+
)
139+
@warn "`getvalue(vals, vn, dist)` is not implemented for $(typeof(dist)); falling back to `getvalue(vals, vn)`."
140+
return AbstractPPL.getvalue(vals, vn)
141+
end
142+
function AbstractPPL.getvalue(
143+
vals::AbstractDict, vn::VarName, ::Distributions.UnivariateDistribution;
144+
)
145+
# TODO(penelopeysm): We could also implement a check for the type to catch
146+
# invalid values. Unsure if that is worth it. It may be easier to just let
147+
# the user handle it.
148+
return AbstractPPL.getvalue(vals, vn)
149+
end
150+
function AbstractPPL.getvalue(
151+
vals::AbstractDict{<:VarName},
152+
vn::VarName{sym},
153+
dist::Union{
154+
Distributions.MultivariateDistribution,
155+
Distributions.MatrixDistribution,
156+
Distributions.LKJCholesky,
157+
};
158+
) where {sym}
159+
# If `vn` is present as-is, then we can just return that
160+
AbstractPPL.hasvalue(vals, vn) && return AbstractPPL.getvalue(vals, vn)
161+
# If not, then we need to start looking inside `vals`, in exactly the
162+
# same way we did for `hasvalue`.
163+
optics = get_optics(dist)
164+
original_optic = AbstractPPL.getoptic(vn)
165+
expected_vns = map(o -> VarName{sym}(o ∘ original_optic), optics)
166+
if all(sub_vn -> AbstractPPL.hasvalue(vals, sub_vn), expected_vns)
167+
# Reconstruct the value index by index.
168+
value = make_empty_value(dist)
169+
for (o, sub_vn) in zip(optics, expected_vns)
170+
# Retrieve the value of this given index
171+
sub_value = AbstractPPL.getvalue(vals, sub_vn)
172+
# Set it inside the value we're reconstructing.
173+
# Note: `o` is normally non-mutating. We have to wrap it in `Lens!`
174+
# to make it mutating, because Cholesky distributions are broken
175+
# by https://github.com/JuliaObjects/Accessors.jl/issues/203.
176+
Accessors.set(value, Lens!(o), sub_value)
177+
end
178+
return value
179+
else
180+
error("getvalue: $(vn) was not found in the values provided")
38181
end
39-
return true
40182
end
41183

42184
end

src/varname.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1098,7 +1098,9 @@ identity (generic function with 1 method)
10981098
"""
10991099
# This one needs normalise because it's going 'against' the direction of the
11001100
# linked list (otherwise you will end up with identities scattered throughout)
1101-
_init(o::ComposedFunction{Outer,Inner}) where {Outer,Inner} = normalise(_init(o.outer) ∘ o.inner)
1101+
function _init(o::ComposedFunction{Outer,Inner}) where {Outer,Inner}
1102+
return normalise(_init(o.outer) ∘ o.inner)
1103+
end
11021104
_init(::Accessors.PropertyLens) = identity
11031105
_init(::Accessors.IndexLens) = identity
11041106
_init(::typeof(identity)) = identity

test/hasvalue.jl

Lines changed: 52 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,56 @@ end
126126
using Distributions
127127
using LinearAlgebra
128128

129-
d = Dict(@varname(x[1]) => 1.0, @varname(x[2]) => 2.0)
130-
@test hasvalue(d, @varname(x), MvNormal(zeros(2), I))
131-
@test !hasvalue(d, @varname(x), MvNormal(zeros(3), I))
129+
@testset "univariate" begin
130+
d = Dict(@varname(x) => 1.0, @varname(y) => [[2.0]])
131+
@test hasvalue(d, @varname(x), Normal())
132+
@test getvalue(d, @varname(x), Normal()) == 1.0
133+
@test hasvalue(d, @varname(y[1][1]), Normal())
134+
@test getvalue(d, @varname(y[1][1]), Normal()) == 2.0
135+
end
136+
137+
@testset "multivariate + matrix" begin
138+
d = Dict(@varname(x[1]) => 1.0, @varname(x[2]) => 2.0)
139+
@test hasvalue(d, @varname(x), MvNormal(zeros(1), I))
140+
@test getvalue(d, @varname(x), MvNormal(zeros(1), I)) == [1.0]
141+
@test hasvalue(d, @varname(x), MvNormal(zeros(2), I))
142+
@test getvalue(d, @varname(x), MvNormal(zeros(2), I)) == [1.0, 2.0]
143+
@test !hasvalue(d, @varname(x), MvNormal(zeros(3), I))
144+
@test_throws ErrorException hasvalue(
145+
d, @varname(x), MvNormal(zeros(3), I); error_on_incomplete=true
146+
)
147+
# If none of the varnames match, it should just return false instead of erroring
148+
@test !hasvalue(d, @varname(y), MvNormal(zeros(2), I); error_on_incomplete=true)
149+
end
150+
151+
@testset "LKJCholesky :upside_down_smile:" begin
152+
# yes, this isn't a valid Cholesky sample, but whatever
153+
d = Dict(
154+
@varname(x.L[1, 1]) => 1.0,
155+
@varname(x.L[2, 1]) => 2.0,
156+
@varname(x.L[2, 2]) => 3.0,
157+
)
158+
@test hasvalue(d, @varname(x), LKJCholesky(2, 1.0))
159+
@test getvalue(d, @varname(x), LKJCholesky(2, 1.0)) ==
160+
Cholesky(LowerTriangular([1.0 0.0; 2.0 3.0]))
161+
@test !hasvalue(d, @varname(x), LKJCholesky(3, 1.0))
162+
@test_throws ErrorException hasvalue(
163+
d, @varname(x), LKJCholesky(3, 1.0); error_on_incomplete=true
164+
)
165+
@test !hasvalue(d, @varname(y), LKJCholesky(3, 1.0); error_on_incomplete=true)
166+
167+
d = Dict(
168+
@varname(x.U[1, 1]) => 1.0,
169+
@varname(x.U[1, 2]) => 2.0,
170+
@varname(x.U[2, 2]) => 3.0,
171+
)
172+
@test hasvalue(d, @varname(x), LKJCholesky(2, 1.0, :U))
173+
@test getvalue(d, @varname(x), LKJCholesky(2, 1.0, :U)) ==
174+
Cholesky(UpperTriangular([1.0 2.0; 0.0 3.0]))
175+
@test !hasvalue(d, @varname(x), LKJCholesky(3, 1.0, :U))
176+
@test_throws ErrorException hasvalue(
177+
d, @varname(x), LKJCholesky(3, 1.0, :U); error_on_incomplete=true
178+
)
179+
@test !hasvalue(d, @varname(y), LKJCholesky(3, 1.0, :U); error_on_incomplete=true)
180+
end
132181
end

0 commit comments

Comments
 (0)