Skip to content

Move hasvalue and getvalue from DynamicPPL; implement extra Distributions-based methods #125

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,16 @@
## 0.12.1

Minimum compatibility has been bumped to Julia 1.10.

Added the new functions `hasvalue(container::T, ::VarName[, ::Distribution])` and `getvalue(container::T, ::VarName[, ::Distribution])`, where `T` is either `NamedTuple` or `AbstractDict{<:VarName}`.

These functions check whether a given `VarName` has a value in the given `NamedTuple` or `AbstractDict`, and return the value if it exists.

The optional `Distribution` argument allows one to reconstruct a full value from its component indices.
For example, if `container` has `x[1]` and `x[2]`, then `hasvalue(container, @varname(x), dist)` will return true if `size(dist) == (2,)` (for example, `MvNormal(zeros(2), I)`).

These functions (without the `Distribution` argument) were previously in DynamicPPL.jl (albeit unexported).

## 0.12.0

### VarName constructors
Expand Down
13 changes: 11 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ uuid = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
keywords = ["probablistic programming"]
license = "MIT"
desc = "Common interfaces for probabilistic programming"
version = "0.12.0"
version = "0.12.1"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand All @@ -13,11 +13,20 @@ JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[weakdeps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[extensions]
AbstractPPLDistributionsExt = ["Distributions", "LinearAlgebra"]

[compat]
AbstractMCMC = "2, 3, 4, 5"
Accessors = "0.1"
DensityInterface = "0.4"
Distributions = "0.25"
LinearAlgebra = "<0.0.1, 1.10"
JSON = "0.19 - 0.21"
Random = "1.6"
StatsBase = "0.32, 0.33, 0.34"
julia = "~1.6.6, 1.7.3"
julia = "1.10"
3 changes: 3 additions & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
[deps]
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
4 changes: 3 additions & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
using Documenter
using AbstractPPL
# trigger DistributionsExt loading
using Distributions, LinearAlgebra

# Doctest setup
DocMeta.setdocmeta!(AbstractPPL, :DocTestSetup, :(using AbstractPPL); recursive=true)

makedocs(;
sitename="AbstractPPL",
modules=[AbstractPPL],
modules=[AbstractPPL, Base.get_extension(AbstractPPL, :AbstractPPLDistributionsExt)],
pages=["Home" => "index.md", "API" => "api.md"],
checkdocs=:exports,
doctest=false,
Expand Down
7 changes: 7 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,13 @@ prefix
unprefix
```

## Extracting values corresponding to a VarName

```@docs
hasvalue
getvalue
```

## VarName serialisation

```@docs
Expand Down
276 changes: 276 additions & 0 deletions ext/AbstractPPLDistributionsExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,276 @@
module AbstractPPLDistributionsExt

using AbstractPPL: AbstractPPL, VarName, Accessors
using Distributions: Distributions
using LinearAlgebra: Cholesky, LowerTriangular, UpperTriangular

#=
This section is copied from Accessors.jl's documentation:
https://juliaobjects.github.io/Accessors.jl/stable/examples/custom_macros/

It defines a wrapper that, when called with `set`, mutates the original value
rather than returning a new value. We need this because the non-mutating optics
don't work for triangular matrices (and hence LKJCholesky): see
https://github.com/JuliaObjects/Accessors.jl/issues/203
=#
struct Lens!{L}
pure::L
end
(l::Lens!)(o) = l.pure(o)

Check warning on line 19 in ext/AbstractPPLDistributionsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/AbstractPPLDistributionsExt.jl#L19

Added line #L19 was not covered by tests
function Accessors.set(o, l::Lens!{<:ComposedFunction}, val)
o_inner = l.pure.inner(o)
return Accessors.set(o_inner, Lens!(l.pure.outer), val)
end
function Accessors.set(o, l::Lens!{Accessors.PropertyLens{prop}}, val) where {prop}
setproperty!(o, prop, val)
return o

Check warning on line 26 in ext/AbstractPPLDistributionsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/AbstractPPLDistributionsExt.jl#L24-L26

Added lines #L24 - L26 were not covered by tests
end
function Accessors.set(o, l::Lens!{<:Accessors.IndexLens}, val)
o[l.pure.indices...] = val
return o
end

"""
get_optics(dist::MultivariateDistribution)
get_optics(dist::MatrixDistribution)
get_optics(dist::LKJCholesky)

Return a complete set of optics for each element of the type returned by `rand(dist)`.
"""
function get_optics(
dist::Union{Distributions.MultivariateDistribution,Distributions.MatrixDistribution}
)
indices = CartesianIndices(size(dist))
return map(idx -> Accessors.IndexLens(idx.I), indices)
end
function get_optics(dist::Distributions.LKJCholesky)
is_up = dist.uplo == 'U'
cartesian_indices = filter(CartesianIndices(size(dist))) do cartesian_index
i, j = cartesian_index.I
is_up ? i <= j : i >= j
end
# there is an additional layer as we need to access `.L` or `.U` before we
# can index into it
field_lens = is_up ? (Accessors.@o _.U) : (Accessors.@o _.L)
return map(idx -> Accessors.IndexLens(idx.I) ∘ field_lens, cartesian_indices)
end

"""
make_empty_value(dist::MultivariateDistribution)
make_empty_value(dist::MatrixDistribution)
make_empty_value(dist::LKJCholesky)

Construct a fresh value filled with zeros that corresponds to the size of `dist`.

For all distributions that this function accepts, it should hold that
`o(make_empty_value(dist))` is zero for all `o` in `get_optics(dist)`.
"""
function make_empty_value(
dist::Union{Distributions.MultivariateDistribution,Distributions.MatrixDistribution}
)
return zeros(size(dist))
end
function make_empty_value(dist::Distributions.LKJCholesky)
if dist.uplo == 'U'
return Cholesky(UpperTriangular(zeros(size(dist))))
else
return Cholesky(LowerTriangular(zeros(size(dist))))
end
end

"""
hasvalue(
vals::Union{AbstractDict,NamedTuple},
vn::VarName,
dist::Distribution;
error_on_incomplete::Bool=false
)

Check if `vals` contains values for `vn` that is compatible with the
distribution `dist`.

This is a more general version of `hasvalue(vals, vn)`, in that even if
`vn` itself is not inside `vals`, it further checks if `vals` contains
sub-values of `vn` that can be used to reconstruct `vn` given `dist`.

The `error_on_incomplete` flag can be used to detect cases where _some_ of
the values needed for `vn` are present, but others are not. This may help
to detect invalid cases where the user has provided e.g. data of the wrong
shape.

Note that this check is only possible if a Dict is passed, because the key type
of a NamedTuple (i.e., Symbol) is not rich enough to carry indexing
information. If this method is called with a NamedTuple, it will just defer
to `hasvalue(vals, vn)`.

For example:

```jldoctest; setup=:(using Distributions, LinearAlgebra))
julia> d = Dict(@varname(x[1]) => 1.0, @varname(x[2]) => 2.0);

julia> hasvalue(d, @varname(x), MvNormal(zeros(2), I))
true

julia> hasvalue(d, @varname(x), MvNormal(zeros(3), I))
false

julia> hasvalue(d, @varname(x), MvNormal(zeros(3), I); error_on_incomplete=true)
ERROR: hasvalue: only partial values for `x` found in the values provided
[...]
```
"""
function AbstractPPL.hasvalue(

Check warning on line 122 in ext/AbstractPPLDistributionsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/AbstractPPLDistributionsExt.jl#L122

Added line #L122 was not covered by tests
vals::NamedTuple,
vn::VarName,
dist::Distributions.Distribution;
error_on_incomplete::Bool=false,
)
# NamedTuples can't have such complicated hierarchies, so it's safe to
# defer to the simpler `hasvalue(vals, vn)`.
return hasvalue(vals, vn)

Check warning on line 130 in ext/AbstractPPLDistributionsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/AbstractPPLDistributionsExt.jl#L130

Added line #L130 was not covered by tests
end
function AbstractPPL.hasvalue(

Check warning on line 132 in ext/AbstractPPLDistributionsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/AbstractPPLDistributionsExt.jl#L132

Added line #L132 was not covered by tests
vals::AbstractDict,
vn::VarName,
dist::Distributions.Distribution;
error_on_incomplete::Bool=false,
)
@warn "`hasvalue(vals, vn, dist)` is not implemented for $(typeof(dist)); falling back to `hasvalue(vals, vn)`."
return AbstractPPL.hasvalue(vals, vn)

Check warning on line 139 in ext/AbstractPPLDistributionsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/AbstractPPLDistributionsExt.jl#L138-L139

Added lines #L138 - L139 were not covered by tests
end
function AbstractPPL.hasvalue(
vals::AbstractDict,
vn::VarName,
::Distributions.UnivariateDistribution;
error_on_incomplete::Bool=false,
)
# TODO(penelopeysm): We could also implement a check for the type to catch
# invalid values. Unsure if that is worth it. It may be easier to just let
# the user handle it.
return AbstractPPL.hasvalue(vals, vn)
end
function AbstractPPL.hasvalue(
vals::AbstractDict{<:VarName},
vn::VarName{sym},
dist::Union{
Distributions.MultivariateDistribution,
Distributions.MatrixDistribution,
Distributions.LKJCholesky,
};
error_on_incomplete::Bool=false,
) where {sym}
# If `vn` is present as-is, then we are good
AbstractPPL.hasvalue(vals, vn) && return true
# If not, then we need to check inside `vals` to see if a subset of
# `vals` is enough to reconstruct `vn`. For example, if `vals` contains
# `x[1]` and `x[2]`, and `dist` is `MvNormal(zeros(2), I)`, then we
# can reconstruct `x`. If `dist` is `MvNormal(zeros(3), I)`, then we
# can't.
# To do this, we get the size of the distribution and iterate over all
# possible indices. If every index can be found in `subsumed_keys`, then we
# can return true.
optics = get_optics(dist)
original_optic = AbstractPPL.getoptic(vn)
expected_vns = map(o -> VarName{sym}(o ∘ original_optic), optics)
if all(sub_vn -> AbstractPPL.hasvalue(vals, sub_vn), expected_vns)
return true
else
if error_on_incomplete &&
any(sub_vn -> AbstractPPL.hasvalue(vals, sub_vn), expected_vns)
error("hasvalue: only partial values for `$vn` found in the values provided")
end
return false
end
end

"""
getvalue(
vals::Union{AbstractDict,NamedTuple},
vn::VarName,
dist::Distribution
)

Retrieve the value of `vn` from `vals`, using the distribution `dist` to
reconstruct the value if necessary.

This is a more general version of `getvalue(vals, vn)`, in that even if `vn`
itself is not inside `vals`, it can still reconstruct the value of `vn`
from sub-values of `vn` that are present in `vals`.

Note that this reconstruction is only possible if a Dict is passed, because the
key type of a NamedTuple (i.e., Symbol) is not rich enough to carry indexing
information. If this method is called with a NamedTuple, it will just defer
to `getvalue(vals, vn)`.

For example:

```jldoctest; setup=:(using Distributions, LinearAlgebra))
julia> d = Dict(@varname(x[1]) => 1.0, @varname(x[2]) => 2.0);

julia> getvalue(d, @varname(x), MvNormal(zeros(2), I))
2-element Vector{Float64}:
1.0
2.0

julia> # Use `hasvalue` to check for this case before calling `getvalue`.
getvalue(d, @varname(x), MvNormal(zeros(3), I))
ERROR: getvalue: `x` was not found in the values provided
[...]
```
"""
function AbstractPPL.getvalue(

Check warning on line 221 in ext/AbstractPPLDistributionsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/AbstractPPLDistributionsExt.jl#L221

Added line #L221 was not covered by tests
vals::NamedTuple, vn::VarName, dist::Distributions.Distribution
)
# NamedTuples can't have such complicated hierarchies, so it's safe to
# defer to the simpler `getvalue(vals, vn)`.
return getvalue(vals, vn)

Check warning on line 226 in ext/AbstractPPLDistributionsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/AbstractPPLDistributionsExt.jl#L226

Added line #L226 was not covered by tests
end
function AbstractPPL.getvalue(

Check warning on line 228 in ext/AbstractPPLDistributionsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/AbstractPPLDistributionsExt.jl#L228

Added line #L228 was not covered by tests
vals::AbstractDict, vn::VarName, dist::Distributions.Distribution;
)
@warn "`getvalue(vals, vn, dist)` is not implemented for $(typeof(dist)); falling back to `getvalue(vals, vn)`."
return AbstractPPL.getvalue(vals, vn)

Check warning on line 232 in ext/AbstractPPLDistributionsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/AbstractPPLDistributionsExt.jl#L231-L232

Added lines #L231 - L232 were not covered by tests
end
function AbstractPPL.getvalue(
vals::AbstractDict, vn::VarName, ::Distributions.UnivariateDistribution;
)
# TODO(penelopeysm): We could also implement a check for the type to catch
# invalid values. Unsure if that is worth it. It may be easier to just let
# the user handle it.
return AbstractPPL.getvalue(vals, vn)
end
function AbstractPPL.getvalue(
vals::AbstractDict{<:VarName},
vn::VarName{sym},
dist::Union{
Distributions.MultivariateDistribution,
Distributions.MatrixDistribution,
Distributions.LKJCholesky,
};
) where {sym}
# If `vn` is present as-is, then we can just return that
AbstractPPL.hasvalue(vals, vn) && return AbstractPPL.getvalue(vals, vn)
# If not, then we need to start looking inside `vals`, in exactly the
# same way we did for `hasvalue`.
optics = get_optics(dist)
original_optic = AbstractPPL.getoptic(vn)
expected_vns = map(o -> VarName{sym}(o ∘ original_optic), optics)
if all(sub_vn -> AbstractPPL.hasvalue(vals, sub_vn), expected_vns)
# Reconstruct the value index by index.
value = make_empty_value(dist)
for (o, sub_vn) in zip(optics, expected_vns)
# Retrieve the value of this given index
sub_value = AbstractPPL.getvalue(vals, sub_vn)
# Set it inside the value we're reconstructing.
# Note: `o` is normally non-mutating. We have to wrap it in `Lens!`
# to make it mutating, because Cholesky distributions are broken
# by https://github.com/JuliaObjects/Accessors.jl/issues/203.
Accessors.set(value, Lens!(o), sub_value)
end
return value
else
error("getvalue: $(vn) was not found in the values provided")

Check warning on line 272 in ext/AbstractPPLDistributionsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/AbstractPPLDistributionsExt.jl#L272

Added line #L272 was not covered by tests
end
end

end
5 changes: 4 additions & 1 deletion src/AbstractPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ export VarName,
varname_to_string,
string_to_varname,
prefix,
unprefix
unprefix,
getvalue,
hasvalue

# Abstract model functions
export AbstractProbabilisticProgram,
Expand All @@ -29,5 +31,6 @@ include("varname.jl")
include("abstractmodeltrace.jl")
include("abstractprobprog.jl")
include("evaluate.jl")
include("hasvalue.jl")

end # module
Loading
Loading