diff --git a/HISTORY.md b/HISTORY.md index 65a678f..551d94a 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,3 +1,16 @@ +## 0.12.1 + +Minimum compatibility has been bumped to Julia 1.10. + +Added the new functions `hasvalue(container::T, ::VarName[, ::Distribution])` and `getvalue(container::T, ::VarName[, ::Distribution])`, where `T` is either `NamedTuple` or `AbstractDict{<:VarName}`. + +These functions check whether a given `VarName` has a value in the given `NamedTuple` or `AbstractDict`, and return the value if it exists. + +The optional `Distribution` argument allows one to reconstruct a full value from its component indices. +For example, if `container` has `x[1]` and `x[2]`, then `hasvalue(container, @varname(x), dist)` will return true if `size(dist) == (2,)` (for example, `MvNormal(zeros(2), I)`). + +These functions (without the `Distribution` argument) were previously in DynamicPPL.jl (albeit unexported). + ## 0.12.0 ### VarName constructors diff --git a/Project.toml b/Project.toml index ecccd56..d6c5c15 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" keywords = ["probablistic programming"] license = "MIT" desc = "Common interfaces for probabilistic programming" -version = "0.12.0" +version = "0.12.1" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" @@ -13,11 +13,20 @@ JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +[weakdeps] +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + +[extensions] +AbstractPPLDistributionsExt = ["Distributions", "LinearAlgebra"] + [compat] AbstractMCMC = "2, 3, 4, 5" Accessors = "0.1" DensityInterface = "0.4" +Distributions = "0.25" +LinearAlgebra = "<0.0.1, 1.10" JSON = "0.19 - 0.21" Random = "1.6" StatsBase = "0.32, 0.33, 0.34" -julia = "~1.6.6, 1.7.3" +julia = "1.10" diff --git a/docs/Project.toml b/docs/Project.toml index dfa65cd..15b2ec4 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,2 +1,5 @@ [deps] +AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/docs/make.jl b/docs/make.jl index 33bf21b..d3dbe83 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,12 +1,14 @@ using Documenter using AbstractPPL +# trigger DistributionsExt loading +using Distributions, LinearAlgebra # Doctest setup DocMeta.setdocmeta!(AbstractPPL, :DocTestSetup, :(using AbstractPPL); recursive=true) makedocs(; sitename="AbstractPPL", - modules=[AbstractPPL], + modules=[AbstractPPL, Base.get_extension(AbstractPPL, :AbstractPPLDistributionsExt)], pages=["Home" => "index.md", "API" => "api.md"], checkdocs=:exports, doctest=false, diff --git a/docs/src/api.md b/docs/src/api.md index e7705da..9c05c47 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -21,6 +21,13 @@ prefix unprefix ``` +## Extracting values corresponding to a VarName + +```@docs +hasvalue +getvalue +``` + ## VarName serialisation ```@docs diff --git a/ext/AbstractPPLDistributionsExt.jl b/ext/AbstractPPLDistributionsExt.jl new file mode 100644 index 0000000..8eb920a --- /dev/null +++ b/ext/AbstractPPLDistributionsExt.jl @@ -0,0 +1,276 @@ +module AbstractPPLDistributionsExt + +using AbstractPPL: AbstractPPL, VarName, Accessors +using Distributions: Distributions +using LinearAlgebra: Cholesky, LowerTriangular, UpperTriangular + +#= +This section is copied from Accessors.jl's documentation: +https://juliaobjects.github.io/Accessors.jl/stable/examples/custom_macros/ + +It defines a wrapper that, when called with `set`, mutates the original value +rather than returning a new value. We need this because the non-mutating optics +don't work for triangular matrices (and hence LKJCholesky): see +https://github.com/JuliaObjects/Accessors.jl/issues/203 +=# +struct Lens!{L} + pure::L +end +(l::Lens!)(o) = l.pure(o) +function Accessors.set(o, l::Lens!{<:ComposedFunction}, val) + o_inner = l.pure.inner(o) + return Accessors.set(o_inner, Lens!(l.pure.outer), val) +end +function Accessors.set(o, l::Lens!{Accessors.PropertyLens{prop}}, val) where {prop} + setproperty!(o, prop, val) + return o +end +function Accessors.set(o, l::Lens!{<:Accessors.IndexLens}, val) + o[l.pure.indices...] = val + return o +end + +""" + get_optics(dist::MultivariateDistribution) + get_optics(dist::MatrixDistribution) + get_optics(dist::LKJCholesky) + +Return a complete set of optics for each element of the type returned by `rand(dist)`. +""" +function get_optics( + dist::Union{Distributions.MultivariateDistribution,Distributions.MatrixDistribution} +) + indices = CartesianIndices(size(dist)) + return map(idx -> Accessors.IndexLens(idx.I), indices) +end +function get_optics(dist::Distributions.LKJCholesky) + is_up = dist.uplo == 'U' + cartesian_indices = filter(CartesianIndices(size(dist))) do cartesian_index + i, j = cartesian_index.I + is_up ? i <= j : i >= j + end + # there is an additional layer as we need to access `.L` or `.U` before we + # can index into it + field_lens = is_up ? (Accessors.@o _.U) : (Accessors.@o _.L) + return map(idx -> Accessors.IndexLens(idx.I) ∘ field_lens, cartesian_indices) +end + +""" + make_empty_value(dist::MultivariateDistribution) + make_empty_value(dist::MatrixDistribution) + make_empty_value(dist::LKJCholesky) + +Construct a fresh value filled with zeros that corresponds to the size of `dist`. + +For all distributions that this function accepts, it should hold that +`o(make_empty_value(dist))` is zero for all `o` in `get_optics(dist)`. +""" +function make_empty_value( + dist::Union{Distributions.MultivariateDistribution,Distributions.MatrixDistribution} +) + return zeros(size(dist)) +end +function make_empty_value(dist::Distributions.LKJCholesky) + if dist.uplo == 'U' + return Cholesky(UpperTriangular(zeros(size(dist)))) + else + return Cholesky(LowerTriangular(zeros(size(dist)))) + end +end + +""" + hasvalue( + vals::Union{AbstractDict,NamedTuple}, + vn::VarName, + dist::Distribution; + error_on_incomplete::Bool=false + ) + +Check if `vals` contains values for `vn` that is compatible with the +distribution `dist`. + +This is a more general version of `hasvalue(vals, vn)`, in that even if +`vn` itself is not inside `vals`, it further checks if `vals` contains +sub-values of `vn` that can be used to reconstruct `vn` given `dist`. + +The `error_on_incomplete` flag can be used to detect cases where _some_ of +the values needed for `vn` are present, but others are not. This may help +to detect invalid cases where the user has provided e.g. data of the wrong +shape. + +Note that this check is only possible if a Dict is passed, because the key type +of a NamedTuple (i.e., Symbol) is not rich enough to carry indexing +information. If this method is called with a NamedTuple, it will just defer +to `hasvalue(vals, vn)`. + +For example: + +```jldoctest; setup=:(using Distributions, LinearAlgebra)) +julia> d = Dict(@varname(x[1]) => 1.0, @varname(x[2]) => 2.0); + +julia> hasvalue(d, @varname(x), MvNormal(zeros(2), I)) +true + +julia> hasvalue(d, @varname(x), MvNormal(zeros(3), I)) +false + +julia> hasvalue(d, @varname(x), MvNormal(zeros(3), I); error_on_incomplete=true) +ERROR: hasvalue: only partial values for `x` found in the values provided +[...] +``` +""" +function AbstractPPL.hasvalue( + vals::NamedTuple, + vn::VarName, + dist::Distributions.Distribution; + error_on_incomplete::Bool=false, +) + # NamedTuples can't have such complicated hierarchies, so it's safe to + # defer to the simpler `hasvalue(vals, vn)`. + return hasvalue(vals, vn) +end +function AbstractPPL.hasvalue( + vals::AbstractDict, + vn::VarName, + dist::Distributions.Distribution; + error_on_incomplete::Bool=false, +) + @warn "`hasvalue(vals, vn, dist)` is not implemented for $(typeof(dist)); falling back to `hasvalue(vals, vn)`." + return AbstractPPL.hasvalue(vals, vn) +end +function AbstractPPL.hasvalue( + vals::AbstractDict, + vn::VarName, + ::Distributions.UnivariateDistribution; + error_on_incomplete::Bool=false, +) + # TODO(penelopeysm): We could also implement a check for the type to catch + # invalid values. Unsure if that is worth it. It may be easier to just let + # the user handle it. + return AbstractPPL.hasvalue(vals, vn) +end +function AbstractPPL.hasvalue( + vals::AbstractDict{<:VarName}, + vn::VarName{sym}, + dist::Union{ + Distributions.MultivariateDistribution, + Distributions.MatrixDistribution, + Distributions.LKJCholesky, + }; + error_on_incomplete::Bool=false, +) where {sym} + # If `vn` is present as-is, then we are good + AbstractPPL.hasvalue(vals, vn) && return true + # If not, then we need to check inside `vals` to see if a subset of + # `vals` is enough to reconstruct `vn`. For example, if `vals` contains + # `x[1]` and `x[2]`, and `dist` is `MvNormal(zeros(2), I)`, then we + # can reconstruct `x`. If `dist` is `MvNormal(zeros(3), I)`, then we + # can't. + # To do this, we get the size of the distribution and iterate over all + # possible indices. If every index can be found in `subsumed_keys`, then we + # can return true. + optics = get_optics(dist) + original_optic = AbstractPPL.getoptic(vn) + expected_vns = map(o -> VarName{sym}(o ∘ original_optic), optics) + if all(sub_vn -> AbstractPPL.hasvalue(vals, sub_vn), expected_vns) + return true + else + if error_on_incomplete && + any(sub_vn -> AbstractPPL.hasvalue(vals, sub_vn), expected_vns) + error("hasvalue: only partial values for `$vn` found in the values provided") + end + return false + end +end + +""" + getvalue( + vals::Union{AbstractDict,NamedTuple}, + vn::VarName, + dist::Distribution + ) + +Retrieve the value of `vn` from `vals`, using the distribution `dist` to +reconstruct the value if necessary. + +This is a more general version of `getvalue(vals, vn)`, in that even if `vn` +itself is not inside `vals`, it can still reconstruct the value of `vn` +from sub-values of `vn` that are present in `vals`. + +Note that this reconstruction is only possible if a Dict is passed, because the +key type of a NamedTuple (i.e., Symbol) is not rich enough to carry indexing +information. If this method is called with a NamedTuple, it will just defer +to `getvalue(vals, vn)`. + +For example: + +```jldoctest; setup=:(using Distributions, LinearAlgebra)) +julia> d = Dict(@varname(x[1]) => 1.0, @varname(x[2]) => 2.0); + +julia> getvalue(d, @varname(x), MvNormal(zeros(2), I)) +2-element Vector{Float64}: + 1.0 + 2.0 + +julia> # Use `hasvalue` to check for this case before calling `getvalue`. + getvalue(d, @varname(x), MvNormal(zeros(3), I)) +ERROR: getvalue: `x` was not found in the values provided +[...] +``` +""" +function AbstractPPL.getvalue( + vals::NamedTuple, vn::VarName, dist::Distributions.Distribution +) + # NamedTuples can't have such complicated hierarchies, so it's safe to + # defer to the simpler `getvalue(vals, vn)`. + return getvalue(vals, vn) +end +function AbstractPPL.getvalue( + vals::AbstractDict, vn::VarName, dist::Distributions.Distribution; +) + @warn "`getvalue(vals, vn, dist)` is not implemented for $(typeof(dist)); falling back to `getvalue(vals, vn)`." + return AbstractPPL.getvalue(vals, vn) +end +function AbstractPPL.getvalue( + vals::AbstractDict, vn::VarName, ::Distributions.UnivariateDistribution; +) + # TODO(penelopeysm): We could also implement a check for the type to catch + # invalid values. Unsure if that is worth it. It may be easier to just let + # the user handle it. + return AbstractPPL.getvalue(vals, vn) +end +function AbstractPPL.getvalue( + vals::AbstractDict{<:VarName}, + vn::VarName{sym}, + dist::Union{ + Distributions.MultivariateDistribution, + Distributions.MatrixDistribution, + Distributions.LKJCholesky, + }; +) where {sym} + # If `vn` is present as-is, then we can just return that + AbstractPPL.hasvalue(vals, vn) && return AbstractPPL.getvalue(vals, vn) + # If not, then we need to start looking inside `vals`, in exactly the + # same way we did for `hasvalue`. + optics = get_optics(dist) + original_optic = AbstractPPL.getoptic(vn) + expected_vns = map(o -> VarName{sym}(o ∘ original_optic), optics) + if all(sub_vn -> AbstractPPL.hasvalue(vals, sub_vn), expected_vns) + # Reconstruct the value index by index. + value = make_empty_value(dist) + for (o, sub_vn) in zip(optics, expected_vns) + # Retrieve the value of this given index + sub_value = AbstractPPL.getvalue(vals, sub_vn) + # Set it inside the value we're reconstructing. + # Note: `o` is normally non-mutating. We have to wrap it in `Lens!` + # to make it mutating, because Cholesky distributions are broken + # by https://github.com/JuliaObjects/Accessors.jl/issues/203. + Accessors.set(value, Lens!(o), sub_value) + end + return value + else + error("getvalue: $(vn) was not found in the values provided") + end +end + +end diff --git a/src/AbstractPPL.jl b/src/AbstractPPL.jl index 40da231..dd495db 100644 --- a/src/AbstractPPL.jl +++ b/src/AbstractPPL.jl @@ -16,7 +16,9 @@ export VarName, varname_to_string, string_to_varname, prefix, - unprefix + unprefix, + getvalue, + hasvalue # Abstract model functions export AbstractProbabilisticProgram, @@ -29,5 +31,6 @@ include("varname.jl") include("abstractmodeltrace.jl") include("abstractprobprog.jl") include("evaluate.jl") +include("hasvalue.jl") end # module diff --git a/src/hasvalue.jl b/src/hasvalue.jl new file mode 100644 index 0000000..1e44c37 --- /dev/null +++ b/src/hasvalue.jl @@ -0,0 +1,262 @@ +""" + canview(optic, container) + +Return `true` if `optic` can be used to view `container`, and `false` otherwise. + +# Examples +```jldoctest; setup=:(using Accessors) +julia> AbstractPPL.canview(@o(_.a), (a = 1.0, )) +true + +julia> AbstractPPL.canview(@o(_.a), (b = 1.0, )) # property `a` does not exist +false + +julia> AbstractPPL.canview(@o(_.a[1]), (a = [1.0, 2.0], )) +true + +julia> AbstractPPL.canview(@o(_.a[3]), (a = [1.0, 2.0], )) # out of bounds +false +``` +""" +canview(optic, container) = false +canview(::typeof(identity), _) = true +function canview(::Accessors.PropertyLens{field}, x) where {field} + return hasproperty(x, field) +end + +# `IndexLens`: only relevant if `x` supports indexing. +canview(optic::Accessors.IndexLens, x) = false +function canview(optic::Accessors.IndexLens, x::AbstractArray) + return checkbounds(Bool, x, optic.indices...) +end + +# `ComposedFunction`: check that we can view `.inner` and `.outer`, but using +# value extracted using `.inner`. +function canview(optic::ComposedFunction, x) + return canview(optic.inner, x) && canview(optic.outer, optic.inner(x)) +end + +""" + getvalue(vals::NamedTuple, vn::VarName) + getvalue(vals::AbstractDict{<:VarName}, vn::VarName) + +Return the value(s) in `vals` represented by `vn`. + +# Examples + +For `NamedTuple`: + +```jldoctest +julia> vals = (x = [1.0],); + +julia> getvalue(vals, @varname(x)) # same as `getindex` +1-element Vector{Float64}: + 1.0 + +julia> getvalue(vals, @varname(x[1])) # different from `getindex` +1.0 + +julia> getvalue(vals, @varname(x[2])) +ERROR: getvalue: x[2] was not found in the values provided +[...] +``` + +For `AbstractDict`: + +```jldoctest +julia> vals = Dict(@varname(x) => [1.0]); + +julia> getvalue(vals, @varname(x)) # same as `getindex` +1-element Vector{Float64}: + 1.0 + +julia> getvalue(vals, @varname(x[1])) # different from `getindex` +1.0 + +julia> getvalue(vals, @varname(x[2])) +ERROR: getvalue: x[2] was not found in the values provided +[...] +``` + +In the `AbstractDict` case we can also have keys such as `v[1]`: + +```jldoctest +julia> vals = Dict(@varname(x[1]) => [1.0,]); + +julia> getvalue(vals, @varname(x[1])) # same as `getindex` +1-element Vector{Float64}: + 1.0 + +julia> getvalue(vals, @varname(x[1][1])) # different from `getindex` +1.0 + +julia> getvalue(vals, @varname(x[1][2])) +ERROR: getvalue: x[1][2] was not found in the values provided +[...] + +julia> getvalue(vals, @varname(x[2][1])) +ERROR: getvalue: x[2][1] was not found in the values provided +[...] +``` + +Dictionaries can present ambiguous cases where the same variable is specified +twice at different levels. In such a situation, `getvalue` attempts to find an +exact match, and if that fails it returns the value with the most specific key. + +!!! note + It is the user's responsibility to avoid such cases by ensuring that the + dictionary passed in does not contain the same value specified multiple + times. + +```jldoctest +julia> vals = Dict(@varname(x) => [[1.0]], @varname(x[1]) => [2.0]); + +julia> # Here, the `x[1]` key is not used because `x` is an exact match. + getvalue(vals, @varname(x)) +1-element Vector{Vector{Float64}}: + [1.0] + +julia> # Likewise, the `x` key is not used because `x[1]` is an exact match. + getvalue(vals, @varname(x[1])) +1-element Vector{Float64}: + 2.0 + +julia> # No exact match, so the most specific key, i.e. `x[1]`, is used. + getvalue(vals, @varname(x[1][1])) +2.0 +``` +""" +function getvalue(vals::NamedTuple, vn::VarName{sym}) where {sym} + optic = getoptic(vn) + if haskey(vals, sym) && canview(optic, getproperty(vals, sym)) + return optic(vals[sym]) + else + error("getvalue: $(vn) was not found in the values provided") + end +end +# For the Dict case, it is more complicated. There are two cases: +# 1. `vn` itself is already a key of `vals` (the easy case) +# 2. `vn` is not a key of `vals`, but some parent of `vn` is a key of `vals` +# (the harder case). For example, if `vn` is `x[1][2]`, then we need to +# check if either `x` or `x[1]` is a key of `vals`, and if so, whether +# we can index into the corresponding value. +function getvalue(vals::AbstractDict{<:VarName}, vn::VarName{sym}) where {sym} + # First we check if `vn` is present as is. + haskey(vals, vn) && return vals[vn] + + # Otherwise, we start by testing the `vn` one level up (e.g., if `vn` is + # `x[1][2]`, we start by checking if `x[1]` is present, then `x`). We will + # then keep removing optics from `test_optic`, either until we find a key + # that is present, or until we run out of optics to test (which happens + # after getoptic(test_vn) == identity). + o = getoptic(vn) + test_vn = VarName{sym}(_init(o)) + test_optic = _last(o) + + while true + if haskey(vals, test_vn) && canview(test_optic, vals[test_vn]) + return test_optic(vals[test_vn]) + else + # Try to move the outermost optic from test_vn into test_optic. + # If test_vn is already an identity, we can't, so we stop. + o = getoptic(test_vn) + o == identity && error("getvalue: $(vn) was not found in the values provided") + test_vn = VarName{sym}(_init(o)) + test_optic = normalise(_last(o) ∘ test_optic) + end + end +end + +""" + hasvalue(vals::NamedTuple, vn::VarName) + hasvalue(vals::AbstractDict{<:VarName}, vn::VarName) + +Determine whether `vals` contains a value for a given `vn`. + +# Examples +With `x` as a `NamedTuple`: + +```jldoctest +julia> hasvalue((x = 1.0, ), @varname(x)) +true + +julia> hasvalue((x = 1.0, ), @varname(x[1])) +false + +julia> hasvalue((x = [1.0],), @varname(x)) +true + +julia> hasvalue((x = [1.0],), @varname(x[1])) +true + +julia> hasvalue((x = [1.0],), @varname(x[2])) +false +``` + +With `x` as a `AbstractDict`: + +```jldoctest +julia> hasvalue(Dict(@varname(x) => 1.0, ), @varname(x)) +true + +julia> hasvalue(Dict(@varname(x) => 1.0, ), @varname(x[1])) +false + +julia> hasvalue(Dict(@varname(x) => [1.0]), @varname(x)) +true + +julia> hasvalue(Dict(@varname(x) => [1.0]), @varname(x[1])) +true + +julia> hasvalue(Dict(@varname(x) => [1.0]), @varname(x[2])) +false +``` + +In the `AbstractDict` case we can also have keys such as `v[1]`: + +```jldoctest +julia> vals = Dict(@varname(x[1]) => [1.0,]); + +julia> hasvalue(vals, @varname(x[1])) # same as `haskey` +true + +julia> hasvalue(vals, @varname(x[1][1])) # different from `haskey` +true + +julia> hasvalue(vals, @varname(x[1][2])) +false + +julia> hasvalue(vals, @varname(x[2][1])) +false +``` +""" +function hasvalue(vals::NamedTuple, vn::VarName{sym}) where {sym} + return haskey(vals, sym) && canview(getoptic(vn), getproperty(vals, sym)) +end +function hasvalue(vals::AbstractDict{<:VarName}, vn::VarName{sym}) where {sym} + # First we check if `vn` is present as is. + haskey(vals, vn) && return true + + # Otherwise, we start by testing the `vn` one level up (e.g., if `vn` is + # `x[1][2]`, we start by checking if `x[1]` is present, then `x`). We will + # then keep removing optics from `test_optic`, either until we find a key + # that is present, or until we run out of optics to test (which happens + # after getoptic(test_vn) == identity). + o = getoptic(vn) + test_vn = VarName{sym}(_init(o)) + test_optic = _last(o) + + while true + if haskey(vals, test_vn) && canview(test_optic, vals[test_vn]) + return true + else + # Try to move the outermost optic from test_vn into test_optic. + # If test_vn is already an identity, we can't, so we stop. + o = getoptic(test_vn) + o == identity && return false + test_vn = VarName{sym}(_init(o)) + test_optic = normalise(_last(o) ∘ test_optic) + end + end + return false +end diff --git a/src/varname.jl b/src/varname.jl index 83f6f6f..fca1004 100644 --- a/src/varname.jl +++ b/src/varname.jl @@ -963,69 +963,147 @@ string_to_varname(str::AbstractString) = dict_to_varname(JSON.parse(str)) ### Prefixing and unprefixing """ - _strip_identity(optic) + _head(optic) -Remove identity lenses from composed optics. -""" -_strip_identity(::Base.ComposedFunction{typeof(identity),typeof(identity)}) = identity -function _strip_identity(o::Base.ComposedFunction{Outer,typeof(identity)}) where {Outer} - return _strip_identity(o.outer) -end -function _strip_identity(o::Base.ComposedFunction{typeof(identity),Inner}) where {Inner} - return _strip_identity(o.inner) -end -_strip_identity(o::Base.ComposedFunction) = o -_strip_identity(o::Accessors.PropertyLens) = o -_strip_identity(o::Accessors.IndexLens) = o -_strip_identity(o::typeof(identity)) = o +Get the innermost layer of an optic. -""" - _inner(optic) +For all (normalised) optics, we have that `normalise(_tail(optic) ∘ +_head(optic) == optic)`. -Get the innermost (non-identity) layer of an optic. +!!! note + Does not perform optic normalisation on the input. You may wish to call + `normalise(optic)` before using this function if the optic you are passing + was not obtained from a VarName. ```jldoctest; setup=:(using Accessors) -julia> AbstractPPL._inner(Accessors.@o _.a.b.c) +julia> AbstractPPL._head(Accessors.@o _.a.b.c) (@o _.a) -julia> AbstractPPL._inner(Accessors.@o _[1][2][3]) +julia> AbstractPPL._head(Accessors.@o _[1][2][3]) (@o _[1]) -julia> AbstractPPL._inner(Accessors.@o _) +julia> AbstractPPL._head(Accessors.@o _.a) +(@o _.a) + +julia> AbstractPPL._head(Accessors.@o _[1]) +(@o _[1]) + +julia> AbstractPPL._head(Accessors.@o _) identity (generic function with 1 method) ``` """ -_inner(o::Base.ComposedFunction{Outer,Inner}) where {Outer,Inner} = o.inner -_inner(o::Accessors.PropertyLens) = o -_inner(o::Accessors.IndexLens) = o -_inner(o::typeof(identity)) = o +_head(o::ComposedFunction{Outer,Inner}) where {Outer,Inner} = o.inner +_head(o::Accessors.PropertyLens) = o +_head(o::Accessors.IndexLens) = o +_head(::typeof(identity)) = identity """ - _outer(optic) + _tail(optic) + +Get everything but the innermost layer of an optic. -Get the outer layer of an optic. +For all (normalised) optics, we have that `normalise(_tail(optic) ∘ +_head(optic) == optic)`. + +!!! note + Does not perform optic normalisation on the input. You may wish to call + `normalise(optic)` before using this function if the optic you are passing + was not obtained from a VarName. ```jldoctest; setup=:(using Accessors) -julia> AbstractPPL._outer(Accessors.@o _.a.b.c) +julia> AbstractPPL._tail(Accessors.@o _.a.b.c) (@o _.b.c) -julia> AbstractPPL._outer(Accessors.@o _[1][2][3]) +julia> AbstractPPL._tail(Accessors.@o _[1][2][3]) (@o _[2][3]) -julia> AbstractPPL._outer(Accessors.@o _.a) +julia> AbstractPPL._tail(Accessors.@o _.a) identity (generic function with 1 method) -julia> AbstractPPL._outer(Accessors.@o _[1]) +julia> AbstractPPL._tail(Accessors.@o _[1]) identity (generic function with 1 method) -julia> AbstractPPL._outer(Accessors.@o _) +julia> AbstractPPL._tail(Accessors.@o _) +identity (generic function with 1 method) +``` +""" +_tail(o::ComposedFunction{Outer,Inner}) where {Outer,Inner} = o.outer +_tail(::Accessors.PropertyLens) = identity +_tail(::Accessors.IndexLens) = identity +_tail(::typeof(identity)) = identity + +""" + _last(optic) + +Get the outermost layer of an optic. + +For all (normalised) optics, we have that `normalise(_last(optic) ∘ +_init(optic)) == optic`. + +!!! note + Does not perform optic normalisation on the input. You may wish to call + `normalise(optic)` before using this function if the optic you are passing + was not obtained from a VarName. + +```jldoctest; setup=:(using Accessors) +julia> AbstractPPL._last(Accessors.@o _.a.b.c) +(@o _.c) + +julia> AbstractPPL._last(Accessors.@o _[1][2][3]) +(@o _[3]) + +julia> AbstractPPL._last(Accessors.@o _.a) +(@o _.a) + +julia> AbstractPPL._last(Accessors.@o _[1]) +(@o _[1]) + +julia> AbstractPPL._last(Accessors.@o _) identity (generic function with 1 method) ``` """ -_outer(o::Base.ComposedFunction{Outer,Inner}) where {Outer,Inner} = o.outer -_outer(::Accessors.PropertyLens) = identity -_outer(::Accessors.IndexLens) = identity -_outer(::typeof(identity)) = identity +_last(o::ComposedFunction{Outer,Inner}) where {Outer,Inner} = _last(o.outer) +_last(o::Accessors.PropertyLens) = o +_last(o::Accessors.IndexLens) = o +_last(::typeof(identity)) = identity + +""" + _init(optic) + +Get everything but the outermost layer of an optic. + +For all (normalised) optics, we have that `normalise(_last(optic) ∘ +_init(optic)) == optic`. + +!!! note + Does not perform optic normalisation on the input. You may wish to call + `normalise(optic)` before using this function if the optic you are passing + was not obtained from a VarName. + +```jldoctest; setup=:(using Accessors) +julia> AbstractPPL._init(Accessors.@o _.a.b.c) +(@o _.a.b) + +julia> AbstractPPL._init(Accessors.@o _[1][2][3]) +(@o _[1][2]) + +julia> AbstractPPL._init(Accessors.@o _.a) +identity (generic function with 1 method) + +julia> AbstractPPL._init(Accessors.@o _[1]) +identity (generic function with 1 method) + +julia> AbstractPPL._init(Accessors.@o _) +identity (generic function with 1 method) +""" +# This one needs normalise because it's going 'against' the direction of the +# linked list (otherwise you will end up with identities scattered throughout) +function _init(o::ComposedFunction{Outer,Inner}) where {Outer,Inner} + return normalise(_init(o.outer) ∘ o.inner) +end +_init(::Accessors.PropertyLens) = identity +_init(::Accessors.IndexLens) = identity +_init(::typeof(identity)) = identity """ optic_to_vn(optic) @@ -1058,11 +1136,11 @@ function optic_to_vn(::Accessors.PropertyLens{sym}) where {sym} return VarName{sym}() end function optic_to_vn( - o::Base.ComposedFunction{Outer,Accessors.PropertyLens{sym}} + o::ComposedFunction{Outer,Accessors.PropertyLens{sym}} ) where {Outer,sym} return VarName{sym}(o.outer) end -optic_to_vn(o::Base.ComposedFunction) = optic_to_vn(normalise(o)) +optic_to_vn(o::ComposedFunction) = optic_to_vn(normalise(o)) function optic_to_vn(@nospecialize(o)) msg = "optic_to_vn: could not convert optic `$o` to a VarName" throw(ArgumentError(msg)) @@ -1077,14 +1155,14 @@ function unprefix_optic(optic, optic_prefix) optic = normalise(optic) optic_prefix = normalise(optic_prefix) # strip one layer of the optic and check for equality - inner = _inner(optic) - inner_prefix = _inner(optic_prefix) - if inner != inner_prefix + head = _head(optic) + head_prefix = _head(optic_prefix) + if head != head_prefix msg = "could not remove prefix $(optic_prefix) from optic $(optic)" throw(ArgumentError(msg)) end # recurse - return unprefix_optic(_outer(optic), _outer(optic_prefix)) + return unprefix_optic(_tail(optic), _tail(optic_prefix)) end """ diff --git a/test/Project.toml b/test/Project.toml index dbc641b..23fb280 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,7 +1,9 @@ [deps] Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" InvertedIndices = "41ab1584-1d38-5bbf-9106-f11c6c58b48f" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/hasvalue.jl b/test/hasvalue.jl new file mode 100644 index 0000000..18d8f2e --- /dev/null +++ b/test/hasvalue.jl @@ -0,0 +1,181 @@ +@testset "base getvalue + hasvalue" begin + @testset "basic NamedTuple" begin + nt = (a=[1], b=2, c=(x=3,), d=[1.0 0.5; 0.5 1.0]) + @test hasvalue(nt, @varname(a)) + @test getvalue(nt, @varname(a)) == [1] + @test hasvalue(nt, @varname(a[1])) + @test getvalue(nt, @varname(a[1])) == 1 + @test hasvalue(nt, @varname(b)) + @test getvalue(nt, @varname(b)) == 2 + @test hasvalue(nt, @varname(c)) + @test getvalue(nt, @varname(c)) == (x=3,) + @test hasvalue(nt, @varname(c.x)) + @test getvalue(nt, @varname(c.x)) == 3 + @test hasvalue(nt, @varname(d)) + @test getvalue(nt, @varname(d)) == [1.0 0.5; 0.5 1.0] + @test hasvalue(nt, @varname(d[1, 1])) + @test getvalue(nt, @varname(d[1, 1])) == 1.0 + @test hasvalue(nt, @varname(d[1, 2])) + @test getvalue(nt, @varname(d[1, 2])) == 0.5 + @test hasvalue(nt, @varname(d[2, 1])) + @test getvalue(nt, @varname(d[2, 1])) == 0.5 + @test hasvalue(nt, @varname(d[2, 2])) + @test getvalue(nt, @varname(d[2, 2])) == 1.0 + @test hasvalue(nt, @varname(d[3])) # linear indexing works.... + @test getvalue(nt, @varname(d[3])) == 0.5 + @test !hasvalue(nt, @varname(nope)) + @test !hasvalue(nt, @varname(a[2])) + @test !hasvalue(nt, @varname(a[1][1])) + @test !hasvalue(nt, @varname(c.x[1])) + @test !hasvalue(nt, @varname(c.y)) + @test !hasvalue(nt, @varname(d[1, 3])) + @test !hasvalue(nt, @varname(d[3, :])) + end + + @testset "basic Dict" begin + # same tests as for NamedTuple + d = Dict( + @varname(a) => [1], + @varname(b) => 2, + @varname(c) => (x=3,), + @varname(d) => [1.0 0.5; 0.5 1.0], + ) + @test hasvalue(d, @varname(a)) + @test getvalue(d, @varname(a)) == [1] + @test hasvalue(d, @varname(a[1])) + @test getvalue(d, @varname(a[1])) == 1 + @test hasvalue(d, @varname(b)) + @test getvalue(d, @varname(b)) == 2 + @test hasvalue(d, @varname(c)) + @test getvalue(d, @varname(c)) == (x=3,) + @test hasvalue(d, @varname(c.x)) + @test getvalue(d, @varname(c.x)) == 3 + @test hasvalue(d, @varname(d)) + @test getvalue(d, @varname(d)) == [1.0 0.5; 0.5 1.0] + @test hasvalue(d, @varname(d[1, 1])) + @test getvalue(d, @varname(d[1, 1])) == 1.0 + @test hasvalue(d, @varname(d[1, 2])) + @test getvalue(d, @varname(d[1, 2])) == 0.5 + @test hasvalue(d, @varname(d[2, 1])) + @test getvalue(d, @varname(d[2, 1])) == 0.5 + @test hasvalue(d, @varname(d[2, 2])) + @test getvalue(d, @varname(d[2, 2])) == 1.0 + @test hasvalue(d, @varname(d[3])) # linear indexing works.... + @test getvalue(d, @varname(d[3])) == 0.5 + @test !hasvalue(d, @varname(nope)) + @test !hasvalue(d, @varname(a[2])) + @test !hasvalue(d, @varname(a[1][1])) + @test !hasvalue(d, @varname(c.x[1])) + @test !hasvalue(d, @varname(c.y)) + @test !hasvalue(d, @varname(d[1, 3])) + end + + @testset "Dict with non-identity varname keys" begin + d = Dict( + @varname(a[1]) => [1.0, 2.0], + @varname(b.x) => [3.0], + @varname(c[2]) => (a=4.0, b=5.0), + ) + @test hasvalue(d, @varname(a[1])) + @test getvalue(d, @varname(a[1])) == [1.0, 2.0] + @test hasvalue(d, @varname(a[1][1])) + @test getvalue(d, @varname(a[1][1])) == 1.0 + @test hasvalue(d, @varname(a[1][2])) + @test getvalue(d, @varname(a[1][2])) == 2.0 + @test hasvalue(d, @varname(b.x)) + @test getvalue(d, @varname(b.x)) == [3.0] + @test hasvalue(d, @varname(b.x[1])) + @test getvalue(d, @varname(b.x[1])) == 3.0 + @test hasvalue(d, @varname(c[2])) + @test getvalue(d, @varname(c[2])) == (a=4.0, b=5.0) + @test hasvalue(d, @varname(c[2].a)) + @test getvalue(d, @varname(c[2].a)) == 4.0 + @test hasvalue(d, @varname(c[2].b)) + @test getvalue(d, @varname(c[2].b)) == 5.0 + @test !hasvalue(d, @varname(a)) + @test !hasvalue(d, @varname(a[2])) + @test !hasvalue(d, @varname(b.y)) + @test !hasvalue(d, @varname(b.x[2])) + @test !hasvalue(d, @varname(c[1])) + @test !hasvalue(d, @varname(c[2].x)) + end + + @testset "Dict with redundancy" begin + d1 = Dict(@varname(x) => [[[[1.0]]]]) + d2 = Dict(@varname(x[1]) => [[[2.0]]]) + d3 = Dict(@varname(x[1][1]) => [[3.0]]) + d4 = Dict(@varname(x[1][1][1]) => [4.0]) + d5 = Dict(@varname(x[1][1][1][1]) => 5.0) + + d = Dict{VarName,Any}() + for (new_dict, expected_value) in + zip((d1, d2, d3, d4, d5), (1.0, 2.0, 3.0, 4.0, 5.0)) + d = merge(d, new_dict) + @test hasvalue(d, @varname(x[1][1][1][1])) + @test getvalue(d, @varname(x[1][1][1][1])) == expected_value + # for good measure + @test !hasvalue(d, @varname(x[1][1][1][2])) + @test !hasvalue(d, @varname(x[1][1][2][1])) + @test !hasvalue(d, @varname(x[1][2][1][1])) + @test !hasvalue(d, @varname(x[2][1][1][1])) + end + end +end + +@testset "with Distributions: getvalue + hasvalue" begin + using Distributions + using LinearAlgebra + + @testset "univariate" begin + d = Dict(@varname(x) => 1.0, @varname(y) => [[2.0]]) + @test hasvalue(d, @varname(x), Normal()) + @test getvalue(d, @varname(x), Normal()) == 1.0 + @test hasvalue(d, @varname(y[1][1]), Normal()) + @test getvalue(d, @varname(y[1][1]), Normal()) == 2.0 + end + + @testset "multivariate + matrix" begin + d = Dict(@varname(x[1]) => 1.0, @varname(x[2]) => 2.0) + @test hasvalue(d, @varname(x), MvNormal(zeros(1), I)) + @test getvalue(d, @varname(x), MvNormal(zeros(1), I)) == [1.0] + @test hasvalue(d, @varname(x), MvNormal(zeros(2), I)) + @test getvalue(d, @varname(x), MvNormal(zeros(2), I)) == [1.0, 2.0] + @test !hasvalue(d, @varname(x), MvNormal(zeros(3), I)) + @test_throws ErrorException hasvalue( + d, @varname(x), MvNormal(zeros(3), I); error_on_incomplete=true + ) + # If none of the varnames match, it should just return false instead of erroring + @test !hasvalue(d, @varname(y), MvNormal(zeros(2), I); error_on_incomplete=true) + end + + @testset "LKJCholesky :upside_down_smile:" begin + # yes, this isn't a valid Cholesky sample, but whatever + d = Dict( + @varname(x.L[1, 1]) => 1.0, + @varname(x.L[2, 1]) => 2.0, + @varname(x.L[2, 2]) => 3.0, + ) + @test hasvalue(d, @varname(x), LKJCholesky(2, 1.0)) + @test getvalue(d, @varname(x), LKJCholesky(2, 1.0)) == + Cholesky(LowerTriangular([1.0 0.0; 2.0 3.0])) + @test !hasvalue(d, @varname(x), LKJCholesky(3, 1.0)) + @test_throws ErrorException hasvalue( + d, @varname(x), LKJCholesky(3, 1.0); error_on_incomplete=true + ) + @test !hasvalue(d, @varname(y), LKJCholesky(3, 1.0); error_on_incomplete=true) + + d = Dict( + @varname(x.U[1, 1]) => 1.0, + @varname(x.U[1, 2]) => 2.0, + @varname(x.U[2, 2]) => 3.0, + ) + @test hasvalue(d, @varname(x), LKJCholesky(2, 1.0, :U)) + @test getvalue(d, @varname(x), LKJCholesky(2, 1.0, :U)) == + Cholesky(UpperTriangular([1.0 2.0; 0.0 3.0])) + @test !hasvalue(d, @varname(x), LKJCholesky(3, 1.0, :U)) + @test_throws ErrorException hasvalue( + d, @varname(x), LKJCholesky(3, 1.0, :U); error_on_incomplete=true + ) + @test !hasvalue(d, @varname(y), LKJCholesky(3, 1.0, :U); error_on_incomplete=true) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 8be65eb..b71a812 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,6 +8,7 @@ const GROUP = get(ENV, "GROUP", "All") if GROUP == "All" || GROUP == "Tests" include("varname.jl") include("abstractprobprog.jl") + include("hasvalue.jl") end if GROUP == "All" || GROUP == "Doctests" diff --git a/test/varname.jl b/test/varname.jl index 40260da..2fbff9b 100644 --- a/test/varname.jl +++ b/test/varname.jl @@ -252,6 +252,54 @@ end @test string_to_varname(varname_to_string(vn)) == vn end + @testset "head, tail, init, last" begin + @testset "specification" begin + @test AbstractPPL._head(@o _.a.b.c) == @o _.a + @test AbstractPPL._tail(@o _.a.b.c) == @o _.b.c + @test AbstractPPL._init(@o _.a.b.c) == @o _.a.b + @test AbstractPPL._last(@o _.a.b.c) == @o _.c + + @test AbstractPPL._head(@o _[1][2][3]) == @o _[1] + @test AbstractPPL._tail(@o _[1][2][3]) == @o _[2][3] + @test AbstractPPL._init(@o _[1][2][3]) == @o _[1][2] + @test AbstractPPL._last(@o _[1][2][3]) == @o _[3] + + @test AbstractPPL._head(@o _.a) == @o _.a + @test AbstractPPL._tail(@o _.a) == identity + @test AbstractPPL._init(@o _.a) == identity + @test AbstractPPL._last(@o _.a) == @o _.a + + @test AbstractPPL._head(@o _[1]) == @o _[1] + @test AbstractPPL._tail(@o _[1]) == identity + @test AbstractPPL._init(@o _[1]) == identity + @test AbstractPPL._last(@o _[1]) == @o _[1] + + @test AbstractPPL._head(identity) == identity + @test AbstractPPL._tail(identity) == identity + @test AbstractPPL._init(identity) == identity + @test AbstractPPL._last(identity) == identity + end + + @testset "composition" begin + varnames = ( + @varname(x), + @varname(x[1]), + @varname(x.a), + @varname(x.a.b), + @varname(x[1].a), + ) + for vn in varnames + optic = getoptic(vn) + @test AbstractPPL.normalise( + AbstractPPL._last(optic) ∘ AbstractPPL._init(optic) + ) == optic + @test AbstractPPL.normalise( + AbstractPPL._tail(optic) ∘ AbstractPPL._head(optic) + ) == optic + end + end + end + @testset "prefix and unprefix" begin @testset "basic cases" begin @test prefix(@varname(y), @varname(x)) == @varname(x.y)