From 56b9f2cdeee70b14d00a5fb22952b81a0f303ab3 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sat, 5 Jul 2025 22:51:07 +0100 Subject: [PATCH 01/13] Move hasvalue and getvalue to AbstractPPL; reimplement --- src/AbstractPPL.jl | 5 +- src/hasvalue.jl | 271 +++++++++++++++++++++++++++++++++++++++++++++ test/hasvalue.jl | 62 +++++++++++ test/runtests.jl | 1 + 4 files changed, 338 insertions(+), 1 deletion(-) create mode 100644 src/hasvalue.jl create mode 100644 test/hasvalue.jl diff --git a/src/AbstractPPL.jl b/src/AbstractPPL.jl index 40da231..dd495db 100644 --- a/src/AbstractPPL.jl +++ b/src/AbstractPPL.jl @@ -16,7 +16,9 @@ export VarName, varname_to_string, string_to_varname, prefix, - unprefix + unprefix, + getvalue, + hasvalue # Abstract model functions export AbstractProbabilisticProgram, @@ -29,5 +31,6 @@ include("varname.jl") include("abstractmodeltrace.jl") include("abstractprobprog.jl") include("evaluate.jl") +include("hasvalue.jl") end # module diff --git a/src/hasvalue.jl b/src/hasvalue.jl new file mode 100644 index 0000000..c014db8 --- /dev/null +++ b/src/hasvalue.jl @@ -0,0 +1,271 @@ +""" + canview(optic, container) + +Return `true` if `optic` can be used to view `container`, and `false` otherwise. + +# Examples +```jldoctest; setup=:(using Accessors) +julia> AbstractPPL.canview(@o(_.a), (a = 1.0, )) +true + +julia> AbstractPPL.canview(@o(_.a), (b = 1.0, )) # property `a` does not exist +false + +julia> AbstractPPL.canview(@o(_.a[1]), (a = [1.0, 2.0], )) +true + +julia> AbstractPPL.canview(@o(_.a[3]), (a = [1.0, 2.0], )) # out of bounds +false +``` +""" +canview(optic, container) = false +canview(::typeof(identity), _) = true +function canview(::Accessors.PropertyLens{field}, x) where {field} + return hasproperty(x, field) +end + +# `IndexLens`: only relevant if `x` supports indexing. +canview(optic::Accessors.IndexLens, x) = false +function canview(optic::Accessors.IndexLens, x::AbstractArray) + return checkbounds(Bool, x, optic.indices...) +end + +# `ComposedFunction`: check that we can view `.inner` and `.outer`, but using +# value extracted using `.inner`. +function canview(optic::ComposedFunction, x) + return canview(optic.inner, x) && canview(optic.outer, optic.inner(x)) +end + +""" + getvalue(vals::NamedTuple, vn::VarName) + getvalue(vals::AbstractDict{<:VarName}, vn::VarName) + +Return the value(s) in `vals` represented by `vn`. + +# Examples + +For `NamedTuple`: + +```jldoctest +julia> vals = (x = [1.0],); + +julia> getvalue(vals, @varname(x)) # same as `getindex` +1-element Vector{Float64}: + 1.0 + +julia> getvalue(vals, @varname(x[1])) # different from `getindex` +1.0 + +julia> getvalue(vals, @varname(x[2])) +ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [2] +[...] +``` + +For `AbstractDict`: + +```jldoctest +julia> vals = Dict(@varname(x) => [1.0]); + +julia> getvalue(vals, @varname(x)) # same as `getindex` +1-element Vector{Float64}: + 1.0 + +julia> getvalue(vals, @varname(x[1])) # different from `getindex` +1.0 + +julia> getvalue(vals, @varname(x[2])) +ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [2] +[...] +``` + +In the `AbstractDict` case we can also have keys such as `v[1]`: + +```jldoctest +julia> vals = Dict(@varname(x[1]) => [1.0,]); + +julia> getvalue(vals, @varname(x[1])) # same as `getindex` +1-element Vector{Float64}: + 1.0 + +julia> getvalue(vals, @varname(x[1][1])) # different from `getindex` +1.0 + +julia> getvalue(vals, @varname(x[1][2])) +ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [2] +[...] + +julia> getvalue(vals, @varname(x[2][1])) +ERROR: KeyError: key x[2][1] not found +[...] +``` +""" +getvalue(vals::NamedTuple, vn::VarName) = get(vals, vn) +getvalue(vals::AbstractDict, vn::VarName) = nested_getindex(vals, vn) + +""" + hasvalue(vals::NamedTuple, vn::VarName) + hasvalue(vals::AbstractDict{<:VarName}, vn::VarName) + +Determine whether `vals` contains a value for a given `vn`. + +# Examples +With `x` as a `NamedTuple`: + +```jldoctest +julia> hasvalue((x = 1.0, ), @varname(x)) +true + +julia> hasvalue((x = 1.0, ), @varname(x[1])) +false + +julia> hasvalue((x = [1.0],), @varname(x)) +true + +julia> hasvalue((x = [1.0],), @varname(x[1])) +true + +julia> hasvalue((x = [1.0],), @varname(x[2])) +false +``` + +With `x` as a `AbstractDict`: + +```jldoctest +julia> hasvalue(Dict(@varname(x) => 1.0, ), @varname(x)) +true + +julia> hasvalue(Dict(@varname(x) => 1.0, ), @varname(x[1])) +false + +julia> hasvalue(Dict(@varname(x) => [1.0]), @varname(x)) +true + +julia> hasvalue(Dict(@varname(x) => [1.0]), @varname(x[1])) +true + +julia> hasvalue(Dict(@varname(x) => [1.0]), @varname(x[2])) +false +``` + +In the `AbstractDict` case we can also have keys such as `v[1]`: + +```jldoctest +julia> vals = Dict(@varname(x[1]) => [1.0,]); + +julia> hasvalue(vals, @varname(x[1])) # same as `haskey` +true + +julia> hasvalue(vals, @varname(x[1][1])) # different from `haskey` +true + +julia> hasvalue(vals, @varname(x[1][2])) +false + +julia> hasvalue(vals, @varname(x[2][1])) +false +``` +""" +function hasvalue(vals::NamedTuple, vn::VarName{sym}) where {sym} + return haskey(vals, sym) && canview(getoptic(vn), getproperty(vals, sym)) +end + +# For the Dict case, it is more complicated. There are two cases: +# 1. `vn` itself is already a key of `vals` (the easy case) +# 2. `vn` is not a key of `vals`, but some parent of `vn` is a key of `vals` +# (the harder case). For example, if `vn` is `x[1][2]`, then we need to +# check if either `x` or `x[1]` is a key of `vals`, and if so, whether +# we can index into the corresponding value. +function hasvalue(vals::AbstractDict{<:VarName}, vn::VarName{sym}) where {sym} + # First we check if `vn` is present as is. + haskey(vals, vn) && return true + + # Otherwise, we start by testing the bare `vn` (e.g., if `vn` is `x[1][2]`, + # we start by checking if `x` is present). We will then keep adding optics + # to `test_optic`, either until we find a key that is present, or until we + # run out of optics to test (which is determined by _inner(test_optic) == + # identity). + test_vn = VarName{sym}() + test_optic = getoptic(vn) + + while _inner(test_optic) != identity + @show test_vn, test_optic + if haskey(vals, test_vn) + @show canview(test_optic, vals[test_vn]) + end + if haskey(vals, test_vn) && canview(test_optic, vals[test_vn]) + return true + else + # Move the innermost optic into test_vn + test_optic_outer = _outer(test_optic) + test_optic_inner = _inner(test_optic) + test_vn = VarName{sym}(test_optic_inner ∘ getoptic(test_vn)) + test_optic = test_optic_outer + end + end + return false +end +# TODO(penelopeysm): Figure out tuple / namedtuple distributions, and LKJCholesky (grr) +# function hasvalue(vals::AbstractDict, vn::VarName, dist::Distribution) +# @warn "`hasvalue(vals, vn, dist)` is not implemented for $(typeof(dist)); falling back to `hasvalue(vals, vn)`." +# return hasvalue(vals, vn) +# end +# hasvalue(vals::AbstractDict, vn::VarName, ::UnivariateDistribution) = hasvalue(vals, vn) +# function hasvalue( +# vals::AbstractDict{<:VarName}, +# vn::VarName{sym}, +# dist::Union{MultivariateDistribution,MatrixDistribution}, +# ) where {sym} +# # If `vn` is present as-is, then we are good +# hasvalue(vals, vn) && return true +# # If not, then we need to check inside `vals` to see if a subset of +# # `vals` is enough to reconstruct `vn`. For example, if `vals` contains +# # `x[1]` and `x[2]`, and `dist` is `MvNormal(zeros(2), I)`, then we +# # can reconstruct `x`. If `dist` is `MvNormal(zeros(3), I)`, then we +# # can't. +# # To do this, we get the size of the distribution and iterate over all +# # possible indices. If every index can be found in `subsumed_keys`, then we +# # can return true. +# sz = size(dist) +# for idx in Iterators.product(map(Base.OneTo, sz)...) +# new_optic = if getoptic(vn) === identity +# Accessors.IndexLens(idx) +# else +# Accessors.IndexLens(idx) ∘ getoptic(vn) +# end +# new_vn = VarName{sym}(new_optic) +# hasvalue(vals, new_vn) || return false +# end +# return true +# end + +# """ +# nested_getindex(values::AbstractDict, vn::VarName) +# +# Return value corresponding to `vn` in `values` by also looking +# in the the actual values of the dict. +# """ +# function nested_getindex(values::AbstractDict, vn::VarName) +# maybeval = get(values, vn, nothing) +# if maybeval !== nothing +# return maybeval +# end +# +# # Split the optic into the key / `parent` and the extraction optic / `child`. +# parent, child, issuccess = splitoptic(getoptic(vn)) do optic +# o = optic === nothing ? identity : optic +# haskey(values, VarName(vn, o)) +# end +# # When combined with `VarInfo`, `nothing` is equivalent to `identity`. +# keyoptic = parent === nothing ? identity : parent +# +# # If we found a valid split, then we can extract the value. +# if !issuccess +# # At this point we just throw an error since the key could not be found. +# throw(KeyError(vn)) +# end +# +# # TODO: Should we also check that we `canview` the extracted `value` +# # rather than just let it fail upon `get` call? +# value = values[VarName(vn, keyoptic)] +# return child(value) +# end diff --git a/test/hasvalue.jl b/test/hasvalue.jl new file mode 100644 index 0000000..0cf3e95 --- /dev/null +++ b/test/hasvalue.jl @@ -0,0 +1,62 @@ +@testset "hasvalue" begin + @testset "NamedTuple" begin + nt = (a=[1], b=2, c=(x=3,), d=[1.0 0.5; 0.5 1.0]) + @test hasvalue(nt, @varname(a)) + @test hasvalue(nt, @varname(a[1])) + @test hasvalue(nt, @varname(b)) + @test hasvalue(nt, @varname(c)) + @test hasvalue(nt, @varname(c.x)) + @test hasvalue(nt, @varname(d)) + @test hasvalue(nt, @varname(d[1, 1])) + @test hasvalue(nt, @varname(d[1, 2])) + @test hasvalue(nt, @varname(d[2, 1])) + @test hasvalue(nt, @varname(d[2, 2])) + @test hasvalue(nt, @varname(d[3])) # linear indexing works.... + @test !hasvalue(nt, @varname(nope)) + @test !hasvalue(nt, @varname(a[2])) + @test !hasvalue(nt, @varname(a[1][1])) + @test !hasvalue(nt, @varname(c.x[1])) + @test !hasvalue(nt, @varname(c.y)) + @test !hasvalue(nt, @varname(d[1, 3])) + @test !hasvalue(nt, @varname(d[3, :])) + end + + @testset "Dict" begin + # same tests as above + d = Dict(@varname(a) => [1], + @varname(b) => 2, + @varname(c) => (x=3,), + @varname(d) => [1.0 0.5; 0.5 1.0]) + @test hasvalue(d, @varname(a)) + @test hasvalue(d, @varname(a[1])) + @test hasvalue(d, @varname(b)) + @test hasvalue(d, @varname(c)) + @test hasvalue(d, @varname(c.x)) + @test hasvalue(d, @varname(d)) + @test hasvalue(d, @varname(d[1, 1])) + @test hasvalue(d, @varname(d[1, 2])) + @test hasvalue(d, @varname(d[2, 1])) + @test hasvalue(d, @varname(d[2, 2])) + @test hasvalue(d, @varname(d[3])) # linear indexing works.... + @test !hasvalue(d, @varname(nope)) + @test !hasvalue(d, @varname(a[2])) + @test !hasvalue(d, @varname(a[1][1])) + @test !hasvalue(d, @varname(c.x[1])) + @test !hasvalue(d, @varname(c.y)) + @test !hasvalue(d, @varname(d[1, 3])) + @test !hasvalue(d, @varname(d[3])) + + # extra ones since Dict can have weird key + d = Dict(@varname(a[1]) => [1.0, 2.0], + @varname(b.x) => [3.0]) + @test hasvalue(d, @varname(a[1])) + @test hasvalue(d, @varname(a[1][1])) + @test hasvalue(d, @varname(a[1][2])) + @test hasvalue(d, @varname(b.x)) + @test hasvalue(d, @varname(b.x[1])) + @test !hasvalue(d, @varname(a)) + @test !hasvalue(d, @varname(a[2])) + @test !hasvalue(d, @varname(b.y)) + @test !hasvalue(d, @varname(b.x[2])) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 8be65eb..b71a812 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,6 +8,7 @@ const GROUP = get(ENV, "GROUP", "All") if GROUP == "All" || GROUP == "Tests" include("varname.jl") include("abstractprobprog.jl") + include("hasvalue.jl") end if GROUP == "All" || GROUP == "Doctests" From 646b9d723f2d64ec48b0989ca355d8e0bd7a1844 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sat, 5 Jul 2025 23:13:31 +0100 Subject: [PATCH 02/13] Add hasvalue for (some) distributions --- Project.toml | 7 ++ ext/AbstractPPLDistributionsExt.jl | 51 ++++++++++++ src/hasvalue.jl | 127 ++++++++++------------------- test/Project.toml | 2 + test/hasvalue.jl | 64 +++++++++++++-- 5 files changed, 160 insertions(+), 91 deletions(-) create mode 100644 ext/AbstractPPLDistributionsExt.jl diff --git a/Project.toml b/Project.toml index ecccd56..f9213f3 100644 --- a/Project.toml +++ b/Project.toml @@ -13,10 +13,17 @@ JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +[weakdeps] +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" + +[extensions] +AbstractPPLDistributionsExt = ["Distributions"] + [compat] AbstractMCMC = "2, 3, 4, 5" Accessors = "0.1" DensityInterface = "0.4" +Distributions = "0.25" JSON = "0.19 - 0.21" Random = "1.6" StatsBase = "0.32, 0.33, 0.34" diff --git a/ext/AbstractPPLDistributionsExt.jl b/ext/AbstractPPLDistributionsExt.jl new file mode 100644 index 0000000..fe767eb --- /dev/null +++ b/ext/AbstractPPLDistributionsExt.jl @@ -0,0 +1,51 @@ +module AbstractPPLDistributionsExt + +if isdefined(Base, :get_extension) + using AbstractPPL: AbstractPPL, VarName, Accessors + using Distributions: Distributions +else + using ..AbstractPPL: AbstractPPL, VarName, Accessors + using ..Distributions: Distributions +end + +# TODO(penelopeysm): Figure out tuple / namedtuple distributions, and LKJCholesky (grr) +function AbstractPPL.hasvalue( + vals::AbstractDict, vn::VarName, dist::Distributions.Distribution +) + @warn "`hasvalue(vals, vn, dist)` is not implemented for $(typeof(dist)); falling back to `hasvalue(vals, vn)`." + return AbstractPPL.hasvalue(vals, vn) +end +function AbstractPPL.hasvalue( + vals::AbstractDict, vn::VarName, ::Distributions.UnivariateDistribution +) + return AbstractPPL.hasvalue(vals, vn) +end +function AbstractPPL.hasvalue( + vals::AbstractDict{<:VarName}, + vn::VarName{sym}, + dist::Union{Distributions.MultivariateDistribution,Distributions.MatrixDistribution}, +) where {sym} + # If `vn` is present as-is, then we are good + AbstractPPL.hasvalue(vals, vn) && return true + # If not, then we need to check inside `vals` to see if a subset of + # `vals` is enough to reconstruct `vn`. For example, if `vals` contains + # `x[1]` and `x[2]`, and `dist` is `MvNormal(zeros(2), I)`, then we + # can reconstruct `x`. If `dist` is `MvNormal(zeros(3), I)`, then we + # can't. + # To do this, we get the size of the distribution and iterate over all + # possible indices. If every index can be found in `subsumed_keys`, then we + # can return true. + sz = size(dist) + for idx in Iterators.product(map(Base.OneTo, sz)...) + new_optic = if AbstractPPL.getoptic(vn) === identity + Accessors.IndexLens(idx) + else + Accessors.IndexLens(idx) ∘ AbstractPPL.getoptic(vn) + end + new_vn = VarName{sym}(new_optic) + AbstractPPL.hasvalue(vals, new_vn) || return false + end + return true +end + +end diff --git a/src/hasvalue.jl b/src/hasvalue.jl index c014db8..a74b76c 100644 --- a/src/hasvalue.jl +++ b/src/hasvalue.jl @@ -57,7 +57,7 @@ julia> getvalue(vals, @varname(x[1])) # different from `getindex` 1.0 julia> getvalue(vals, @varname(x[2])) -ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [2] +ERROR: getvalue: x[2] was not found in the values provided [...] ``` @@ -74,7 +74,7 @@ julia> getvalue(vals, @varname(x[1])) # different from `getindex` 1.0 julia> getvalue(vals, @varname(x[2])) -ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [2] +ERROR: getvalue: x[2] was not found in the values provided [...] ``` @@ -91,16 +91,53 @@ julia> getvalue(vals, @varname(x[1][1])) # different from `getindex` 1.0 julia> getvalue(vals, @varname(x[1][2])) -ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [2] +ERROR: getvalue: x[1][2] was not found in the values provided [...] julia> getvalue(vals, @varname(x[2][1])) -ERROR: KeyError: key x[2][1] not found +ERROR: getvalue: x[2][1] was not found in the values provided [...] ``` """ -getvalue(vals::NamedTuple, vn::VarName) = get(vals, vn) -getvalue(vals::AbstractDict, vn::VarName) = nested_getindex(vals, vn) +function getvalue(vals::NamedTuple, vn::VarName{sym}) where {sym} + optic = getoptic(vn) + if haskey(vals, sym) && canview(optic, getproperty(vals, sym)) + return optic(vals[sym]) + else + error("getvalue: $(vn) was not found in the values provided") + end +end +# For the Dict case, it is more complicated. There are two cases: +# 1. `vn` itself is already a key of `vals` (the easy case) +# 2. `vn` is not a key of `vals`, but some parent of `vn` is a key of `vals` +# (the harder case). For example, if `vn` is `x[1][2]`, then we need to +# check if either `x` or `x[1]` is a key of `vals`, and if so, whether +# we can index into the corresponding value. +function getvalue(vals::AbstractDict{<:VarName}, vn::VarName{sym}) where {sym} + # First we check if `vn` is present as is. + haskey(vals, vn) && return vals[vn] + + # Otherwise, we start by testing the bare `vn` (e.g., if `vn` is `x[1][2]`, + # we start by checking if `x` is present). We will then keep adding optics + # to `test_optic`, either until we find a key that is present, or until we + # run out of optics to test (which is determined by _inner(test_optic) == + # identity). + test_vn = VarName{sym}() + test_optic = getoptic(vn) + + while _inner(test_optic) != identity + if haskey(vals, test_vn) && canview(test_optic, vals[test_vn]) + return test_optic(vals[test_vn]) + else + # Move the innermost optic into test_vn + test_optic_outer = _outer(test_optic) + test_optic_inner = _inner(test_optic) + test_vn = VarName{sym}(test_optic_inner ∘ getoptic(test_vn)) + test_optic = test_optic_outer + end + end + return error("getvalue: $(vn) was not found in the values provided") +end """ hasvalue(vals::NamedTuple, vn::VarName) @@ -168,13 +205,6 @@ false function hasvalue(vals::NamedTuple, vn::VarName{sym}) where {sym} return haskey(vals, sym) && canview(getoptic(vn), getproperty(vals, sym)) end - -# For the Dict case, it is more complicated. There are two cases: -# 1. `vn` itself is already a key of `vals` (the easy case) -# 2. `vn` is not a key of `vals`, but some parent of `vn` is a key of `vals` -# (the harder case). For example, if `vn` is `x[1][2]`, then we need to -# check if either `x` or `x[1]` is a key of `vals`, and if so, whether -# we can index into the corresponding value. function hasvalue(vals::AbstractDict{<:VarName}, vn::VarName{sym}) where {sym} # First we check if `vn` is present as is. haskey(vals, vn) && return true @@ -186,12 +216,8 @@ function hasvalue(vals::AbstractDict{<:VarName}, vn::VarName{sym}) where {sym} # identity). test_vn = VarName{sym}() test_optic = getoptic(vn) - + while _inner(test_optic) != identity - @show test_vn, test_optic - if haskey(vals, test_vn) - @show canview(test_optic, vals[test_vn]) - end if haskey(vals, test_vn) && canview(test_optic, vals[test_vn]) return true else @@ -204,68 +230,3 @@ function hasvalue(vals::AbstractDict{<:VarName}, vn::VarName{sym}) where {sym} end return false end -# TODO(penelopeysm): Figure out tuple / namedtuple distributions, and LKJCholesky (grr) -# function hasvalue(vals::AbstractDict, vn::VarName, dist::Distribution) -# @warn "`hasvalue(vals, vn, dist)` is not implemented for $(typeof(dist)); falling back to `hasvalue(vals, vn)`." -# return hasvalue(vals, vn) -# end -# hasvalue(vals::AbstractDict, vn::VarName, ::UnivariateDistribution) = hasvalue(vals, vn) -# function hasvalue( -# vals::AbstractDict{<:VarName}, -# vn::VarName{sym}, -# dist::Union{MultivariateDistribution,MatrixDistribution}, -# ) where {sym} -# # If `vn` is present as-is, then we are good -# hasvalue(vals, vn) && return true -# # If not, then we need to check inside `vals` to see if a subset of -# # `vals` is enough to reconstruct `vn`. For example, if `vals` contains -# # `x[1]` and `x[2]`, and `dist` is `MvNormal(zeros(2), I)`, then we -# # can reconstruct `x`. If `dist` is `MvNormal(zeros(3), I)`, then we -# # can't. -# # To do this, we get the size of the distribution and iterate over all -# # possible indices. If every index can be found in `subsumed_keys`, then we -# # can return true. -# sz = size(dist) -# for idx in Iterators.product(map(Base.OneTo, sz)...) -# new_optic = if getoptic(vn) === identity -# Accessors.IndexLens(idx) -# else -# Accessors.IndexLens(idx) ∘ getoptic(vn) -# end -# new_vn = VarName{sym}(new_optic) -# hasvalue(vals, new_vn) || return false -# end -# return true -# end - -# """ -# nested_getindex(values::AbstractDict, vn::VarName) -# -# Return value corresponding to `vn` in `values` by also looking -# in the the actual values of the dict. -# """ -# function nested_getindex(values::AbstractDict, vn::VarName) -# maybeval = get(values, vn, nothing) -# if maybeval !== nothing -# return maybeval -# end -# -# # Split the optic into the key / `parent` and the extraction optic / `child`. -# parent, child, issuccess = splitoptic(getoptic(vn)) do optic -# o = optic === nothing ? identity : optic -# haskey(values, VarName(vn, o)) -# end -# # When combined with `VarInfo`, `nothing` is equivalent to `identity`. -# keyoptic = parent === nothing ? identity : parent -# -# # If we found a valid split, then we can extract the value. -# if !issuccess -# # At this point we just throw an error since the key could not be found. -# throw(KeyError(vn)) -# end -# -# # TODO: Should we also check that we `canview` the extracted `value` -# # rather than just let it fail upon `get` call? -# value = values[VarName(vn, keyoptic)] -# return child(value) -# end diff --git a/test/Project.toml b/test/Project.toml index dbc641b..23fb280 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,7 +1,9 @@ [deps] Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" InvertedIndices = "41ab1584-1d38-5bbf-9106-f11c6c58b48f" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/hasvalue.jl b/test/hasvalue.jl index 0cf3e95..b6e3996 100644 --- a/test/hasvalue.jl +++ b/test/hasvalue.jl @@ -1,17 +1,28 @@ -@testset "hasvalue" begin +@testset "base getvalue + hasvalue" begin @testset "NamedTuple" begin nt = (a=[1], b=2, c=(x=3,), d=[1.0 0.5; 0.5 1.0]) @test hasvalue(nt, @varname(a)) + @test getvalue(nt, @varname(a)) == [1] @test hasvalue(nt, @varname(a[1])) + @test getvalue(nt, @varname(a[1])) == 1 @test hasvalue(nt, @varname(b)) + @test getvalue(nt, @varname(b)) == 2 @test hasvalue(nt, @varname(c)) + @test getvalue(nt, @varname(c)) == (x=3,) @test hasvalue(nt, @varname(c.x)) + @test getvalue(nt, @varname(c.x)) == 3 @test hasvalue(nt, @varname(d)) + @test getvalue(nt, @varname(d)) == [1.0 0.5; 0.5 1.0] @test hasvalue(nt, @varname(d[1, 1])) + @test getvalue(nt, @varname(d[1, 1])) == 1.0 @test hasvalue(nt, @varname(d[1, 2])) + @test getvalue(nt, @varname(d[1, 2])) == 0.5 @test hasvalue(nt, @varname(d[2, 1])) + @test getvalue(nt, @varname(d[2, 1])) == 0.5 @test hasvalue(nt, @varname(d[2, 2])) + @test getvalue(nt, @varname(d[2, 2])) == 1.0 @test hasvalue(nt, @varname(d[3])) # linear indexing works.... + @test getvalue(nt, @varname(d[3])) == 0.5 @test !hasvalue(nt, @varname(nope)) @test !hasvalue(nt, @varname(a[2])) @test !hasvalue(nt, @varname(a[1][1])) @@ -23,40 +34,77 @@ @testset "Dict" begin # same tests as above - d = Dict(@varname(a) => [1], + d = Dict( + @varname(a) => [1], @varname(b) => 2, @varname(c) => (x=3,), - @varname(d) => [1.0 0.5; 0.5 1.0]) + @varname(d) => [1.0 0.5; 0.5 1.0], + ) @test hasvalue(d, @varname(a)) + @test getvalue(d, @varname(a)) == [1] @test hasvalue(d, @varname(a[1])) + @test getvalue(d, @varname(a[1])) == 1 @test hasvalue(d, @varname(b)) + @test getvalue(d, @varname(b)) == 2 @test hasvalue(d, @varname(c)) + @test getvalue(d, @varname(c)) == (x=3,) @test hasvalue(d, @varname(c.x)) + @test getvalue(d, @varname(c.x)) == 3 @test hasvalue(d, @varname(d)) + @test getvalue(d, @varname(d)) == [1.0 0.5; 0.5 1.0] @test hasvalue(d, @varname(d[1, 1])) + @test getvalue(d, @varname(d[1, 1])) == 1.0 @test hasvalue(d, @varname(d[1, 2])) + @test getvalue(d, @varname(d[1, 2])) == 0.5 @test hasvalue(d, @varname(d[2, 1])) + @test getvalue(d, @varname(d[2, 1])) == 0.5 @test hasvalue(d, @varname(d[2, 2])) - @test hasvalue(d, @varname(d[3])) # linear indexing works.... + @test getvalue(d, @varname(d[2, 2])) == 1.0 + @test hasvalue(d, @varname(d[3])) # linear indexing works.... + @test getvalue(d, @varname(d[3])) == 0.5 @test !hasvalue(d, @varname(nope)) @test !hasvalue(d, @varname(a[2])) @test !hasvalue(d, @varname(a[1][1])) @test !hasvalue(d, @varname(c.x[1])) @test !hasvalue(d, @varname(c.y)) @test !hasvalue(d, @varname(d[1, 3])) - @test !hasvalue(d, @varname(d[3])) - # extra ones since Dict can have weird key - d = Dict(@varname(a[1]) => [1.0, 2.0], - @varname(b.x) => [3.0]) + # extra ones since Dict can have weird keys + d = Dict( + @varname(a[1]) => [1.0, 2.0], + @varname(b.x) => [3.0], + @varname(c[2]) => (a=4.0, b=5.0), + ) @test hasvalue(d, @varname(a[1])) + @test getvalue(d, @varname(a[1])) == [1.0, 2.0] @test hasvalue(d, @varname(a[1][1])) + @test getvalue(d, @varname(a[1][1])) == 1.0 @test hasvalue(d, @varname(a[1][2])) + @test getvalue(d, @varname(a[1][2])) == 2.0 @test hasvalue(d, @varname(b.x)) + @test getvalue(d, @varname(b.x)) == [3.0] @test hasvalue(d, @varname(b.x[1])) + @test getvalue(d, @varname(b.x[1])) == 3.0 + @test hasvalue(d, @varname(c[2])) + @test getvalue(d, @varname(c[2])) == (a=4.0, b=5.0) + @test hasvalue(d, @varname(c[2].a)) + @test getvalue(d, @varname(c[2].a)) == 4.0 + @test hasvalue(d, @varname(c[2].b)) + @test getvalue(d, @varname(c[2].b)) == 5.0 @test !hasvalue(d, @varname(a)) @test !hasvalue(d, @varname(a[2])) @test !hasvalue(d, @varname(b.y)) @test !hasvalue(d, @varname(b.x[2])) + @test !hasvalue(d, @varname(c[1])) + @test !hasvalue(d, @varname(c[2].x)) end end + +@testset "with Distributions: getvalue + hasvalue" begin + using Distributions + using LinearAlgebra + + d = Dict(@varname(x[1]) => 1.0, @varname(x[2]) => 2.0) + @test hasvalue(d, @varname(x), MvNormal(zeros(2), I)) + @test !hasvalue(d, @varname(x), MvNormal(zeros(3), I)) +end From 07975a24241458081984d16ea83bc3f2971d7826 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 6 Jul 2025 00:09:45 +0100 Subject: [PATCH 03/13] Bump min Julia to 1.10 --- Project.toml | 2 +- ext/AbstractPPLDistributionsExt.jl | 9 ++------- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/Project.toml b/Project.toml index f9213f3..1ca410d 100644 --- a/Project.toml +++ b/Project.toml @@ -27,4 +27,4 @@ Distributions = "0.25" JSON = "0.19 - 0.21" Random = "1.6" StatsBase = "0.32, 0.33, 0.34" -julia = "~1.6.6, 1.7.3" +julia = "1.10" diff --git a/ext/AbstractPPLDistributionsExt.jl b/ext/AbstractPPLDistributionsExt.jl index fe767eb..fd6b979 100644 --- a/ext/AbstractPPLDistributionsExt.jl +++ b/ext/AbstractPPLDistributionsExt.jl @@ -1,12 +1,7 @@ module AbstractPPLDistributionsExt -if isdefined(Base, :get_extension) - using AbstractPPL: AbstractPPL, VarName, Accessors - using Distributions: Distributions -else - using ..AbstractPPL: AbstractPPL, VarName, Accessors - using ..Distributions: Distributions -end +using AbstractPPL: AbstractPPL, VarName, Accessors +using Distributions: Distributions # TODO(penelopeysm): Figure out tuple / namedtuple distributions, and LKJCholesky (grr) function AbstractPPL.hasvalue( From 332c64a7236b9cdca93572697610a2cca3de778a Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 6 Jul 2025 13:25:48 +0100 Subject: [PATCH 04/13] Make hasvalue and getvalue use the most specific value --- HISTORY.md | 13 +++++ Project.toml | 2 +- src/hasvalue.jl | 59 ++++++++++--------- src/varname.jl | 146 +++++++++++++++++++++++++++++++++-------------- test/hasvalue.jl | 30 ++++++++-- 5 files changed, 175 insertions(+), 75 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index 65a678f..551d94a 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,3 +1,16 @@ +## 0.12.1 + +Minimum compatibility has been bumped to Julia 1.10. + +Added the new functions `hasvalue(container::T, ::VarName[, ::Distribution])` and `getvalue(container::T, ::VarName[, ::Distribution])`, where `T` is either `NamedTuple` or `AbstractDict{<:VarName}`. + +These functions check whether a given `VarName` has a value in the given `NamedTuple` or `AbstractDict`, and return the value if it exists. + +The optional `Distribution` argument allows one to reconstruct a full value from its component indices. +For example, if `container` has `x[1]` and `x[2]`, then `hasvalue(container, @varname(x), dist)` will return true if `size(dist) == (2,)` (for example, `MvNormal(zeros(2), I)`). + +These functions (without the `Distribution` argument) were previously in DynamicPPL.jl (albeit unexported). + ## 0.12.0 ### VarName constructors diff --git a/Project.toml b/Project.toml index 1ca410d..3adbf5e 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" keywords = ["probablistic programming"] license = "MIT" desc = "Common interfaces for probabilistic programming" -version = "0.12.0" +version = "0.12.1" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/src/hasvalue.jl b/src/hasvalue.jl index a74b76c..65319dd 100644 --- a/src/hasvalue.jl +++ b/src/hasvalue.jl @@ -117,26 +117,27 @@ function getvalue(vals::AbstractDict{<:VarName}, vn::VarName{sym}) where {sym} # First we check if `vn` is present as is. haskey(vals, vn) && return vals[vn] - # Otherwise, we start by testing the bare `vn` (e.g., if `vn` is `x[1][2]`, - # we start by checking if `x` is present). We will then keep adding optics - # to `test_optic`, either until we find a key that is present, or until we - # run out of optics to test (which is determined by _inner(test_optic) == - # identity). - test_vn = VarName{sym}() - test_optic = getoptic(vn) - - while _inner(test_optic) != identity + # Otherwise, we start by testing the `vn` one level up (e.g., if `vn` is + # `x[1][2]`, we start by checking if `x[1]` is present, then `x`). We will + # then keep removing optics from `test_optic`, either until we find a key + # that is present, or until we run out of optics to test (which happens + # after getoptic(test_vn) == identity). + o = getoptic(vn) + test_vn = VarName{sym}(_init(o)) + test_optic = _last(o) + + while true if haskey(vals, test_vn) && canview(test_optic, vals[test_vn]) return test_optic(vals[test_vn]) else - # Move the innermost optic into test_vn - test_optic_outer = _outer(test_optic) - test_optic_inner = _inner(test_optic) - test_vn = VarName{sym}(test_optic_inner ∘ getoptic(test_vn)) - test_optic = test_optic_outer + # Try to move the outermost optic from test_vn into test_optic. + # If test_vn is already an identity, we can't, so we stop. + o = getoptic(test_vn) + o == identity && error("getvalue: $(vn) was not found in the values provided") + test_vn = VarName{sym}(_init(o)) + test_optic = normalise(_last(o) ∘ test_optic) end end - return error("getvalue: $(vn) was not found in the values provided") end """ @@ -209,23 +210,25 @@ function hasvalue(vals::AbstractDict{<:VarName}, vn::VarName{sym}) where {sym} # First we check if `vn` is present as is. haskey(vals, vn) && return true - # Otherwise, we start by testing the bare `vn` (e.g., if `vn` is `x[1][2]`, - # we start by checking if `x` is present). We will then keep adding optics - # to `test_optic`, either until we find a key that is present, or until we - # run out of optics to test (which is determined by _inner(test_optic) == - # identity). - test_vn = VarName{sym}() - test_optic = getoptic(vn) + # Otherwise, we start by testing the `vn` one level up (e.g., if `vn` is + # `x[1][2]`, we start by checking if `x[1]` is present, then `x`). We will + # then keep removing optics from `test_optic`, either until we find a key + # that is present, or until we run out of optics to test (which happens + # after getoptic(test_vn) == identity). + o = getoptic(vn) + test_vn = VarName{sym}(_init(o)) + test_optic = _last(o) - while _inner(test_optic) != identity + while true if haskey(vals, test_vn) && canview(test_optic, vals[test_vn]) return true else - # Move the innermost optic into test_vn - test_optic_outer = _outer(test_optic) - test_optic_inner = _inner(test_optic) - test_vn = VarName{sym}(test_optic_inner ∘ getoptic(test_vn)) - test_optic = test_optic_outer + # Try to move the outermost optic from test_vn into test_optic. + # If test_vn is already an identity, we can't, so we stop. + o = getoptic(test_vn) + o == identity && return false + test_vn = VarName{sym}(_init(o)) + test_optic = normalise(_last(o) ∘ test_optic) end end return false diff --git a/src/varname.jl b/src/varname.jl index 83f6f6f..18cfc06 100644 --- a/src/varname.jl +++ b/src/varname.jl @@ -963,69 +963,131 @@ string_to_varname(str::AbstractString) = dict_to_varname(JSON.parse(str)) ### Prefixing and unprefixing """ - _strip_identity(optic) + _head(optic) -Remove identity lenses from composed optics. -""" -_strip_identity(::Base.ComposedFunction{typeof(identity),typeof(identity)}) = identity -function _strip_identity(o::Base.ComposedFunction{Outer,typeof(identity)}) where {Outer} - return _strip_identity(o.outer) -end -function _strip_identity(o::Base.ComposedFunction{typeof(identity),Inner}) where {Inner} - return _strip_identity(o.inner) -end -_strip_identity(o::Base.ComposedFunction) = o -_strip_identity(o::Accessors.PropertyLens) = o -_strip_identity(o::Accessors.IndexLens) = o -_strip_identity(o::typeof(identity)) = o - -""" - _inner(optic) +Get the innermost layer of an optic. -Get the innermost (non-identity) layer of an optic. +!!! note + Does not perform optic normalisation. You may wish to call + `normalise(optic)` before using this function if the optic you are passing + was not obtained from a VarName. ```jldoctest; setup=:(using Accessors) -julia> AbstractPPL._inner(Accessors.@o _.a.b.c) +julia> AbstractPPL._head(Accessors.@o _.a.b.c) (@o _.a) -julia> AbstractPPL._inner(Accessors.@o _[1][2][3]) +julia> AbstractPPL._head(Accessors.@o _[1][2][3]) (@o _[1]) -julia> AbstractPPL._inner(Accessors.@o _) +julia> AbstractPPL._head(Accessors.@o _.a) +(@o _.a) + +julia> AbstractPPL._head(Accessors.@o _[1]) +(@o _[1]) + +julia> AbstractPPL._head(Accessors.@o _) identity (generic function with 1 method) ``` """ -_inner(o::Base.ComposedFunction{Outer,Inner}) where {Outer,Inner} = o.inner -_inner(o::Accessors.PropertyLens) = o -_inner(o::Accessors.IndexLens) = o -_inner(o::typeof(identity)) = o +_head(o::ComposedFunction{Outer,Inner}) where {Outer,Inner} = o.inner +_head(o::Accessors.PropertyLens) = o +_head(o::Accessors.IndexLens) = o +_head(::typeof(identity)) = identity """ - _outer(optic) + _tail(optic) + +Get everything but the innermost layer of an optic. -Get the outer layer of an optic. +!!! note + Does not perform optic normalisation. You may wish to call + `normalise(optic)` before using this function if the optic you are passing + was not obtained from a VarName. ```jldoctest; setup=:(using Accessors) -julia> AbstractPPL._outer(Accessors.@o _.a.b.c) +julia> AbstractPPL._tail(Accessors.@o _.a.b.c) (@o _.b.c) -julia> AbstractPPL._outer(Accessors.@o _[1][2][3]) +julia> AbstractPPL._tail(Accessors.@o _[1][2][3]) (@o _[2][3]) -julia> AbstractPPL._outer(Accessors.@o _.a) +julia> AbstractPPL._tail(Accessors.@o _.a) identity (generic function with 1 method) -julia> AbstractPPL._outer(Accessors.@o _[1]) +julia> AbstractPPL._tail(Accessors.@o _[1]) identity (generic function with 1 method) -julia> AbstractPPL._outer(Accessors.@o _) +julia> AbstractPPL._tail(Accessors.@o _) identity (generic function with 1 method) ``` """ -_outer(o::Base.ComposedFunction{Outer,Inner}) where {Outer,Inner} = o.outer -_outer(::Accessors.PropertyLens) = identity -_outer(::Accessors.IndexLens) = identity -_outer(::typeof(identity)) = identity +_tail(o::ComposedFunction{Outer,Inner}) where {Outer,Inner} = o.outer +_tail(::Accessors.PropertyLens) = identity +_tail(::Accessors.IndexLens) = identity +_tail(::typeof(identity)) = identity + +""" + _last(optic) + +Get the outermost layer of an optic. + +!!! note + Does not perform optic normalisation. You may wish to call + `normalise(optic)` before using this function if the optic you are passing + was not obtained from a VarName. + +```jldoctest; setup=:(using Accessors) +julia> AbstractPPL._last(Accessors.@o _.a.b.c) +(@o _.c) + +julia> AbstractPPL._last(Accessors.@o _[1][2][3]) +(@o _[3]) + +julia> AbstractPPL._last(Accessors.@o _.a) +(@o _.a) + +julia> AbstractPPL._last(Accessors.@o _[1]) +(@o _[1]) + +julia> AbstractPPL._last(Accessors.@o _) +identity (generic function with 1 method) +``` +""" +_last(o::ComposedFunction{Outer,Inner}) where {Outer,Inner} = _last(o.outer) +_last(o::Accessors.PropertyLens) = o +_last(o::Accessors.IndexLens) = o +_last(::typeof(identity)) = identity + +""" + _init(optic) + +Get everything but the outermost layer of an optic. + +!!! note + Does not perform optic normalisation. You may wish to call + `normalise(optic)` before using this function if the optic you are passing + was not obtained from a VarName. + +```jldoctest; setup=:(using Accessors) +julia> AbstractPPL._init(Accessors.@o _.a.b.c) +(@o _.a.b) + +julia> AbstractPPL._init(Accessors.@o _[1][2][3]) +(@o _[1][2]) + +julia> AbstractPPL._init(Accessors.@o _.a) +identity (generic function with 1 method) + +julia> AbstractPPL._init(Accessors.@o _[1]) +identity (generic function with 1 method) + +julia> AbstractPPL._init(Accessors.@o _) +identity (generic function with 1 method) +""" +_init(o::ComposedFunction{Outer,Inner}) where {Outer,Inner} = _init(o.outer) ∘ o.inner +_init(::Accessors.PropertyLens) = identity +_init(::Accessors.IndexLens) = identity +_init(::typeof(identity)) = identity """ optic_to_vn(optic) @@ -1058,11 +1120,11 @@ function optic_to_vn(::Accessors.PropertyLens{sym}) where {sym} return VarName{sym}() end function optic_to_vn( - o::Base.ComposedFunction{Outer,Accessors.PropertyLens{sym}} + o::ComposedFunction{Outer,Accessors.PropertyLens{sym}} ) where {Outer,sym} return VarName{sym}(o.outer) end -optic_to_vn(o::Base.ComposedFunction) = optic_to_vn(normalise(o)) +optic_to_vn(o::ComposedFunction) = optic_to_vn(normalise(o)) function optic_to_vn(@nospecialize(o)) msg = "optic_to_vn: could not convert optic `$o` to a VarName" throw(ArgumentError(msg)) @@ -1077,14 +1139,14 @@ function unprefix_optic(optic, optic_prefix) optic = normalise(optic) optic_prefix = normalise(optic_prefix) # strip one layer of the optic and check for equality - inner = _inner(optic) - inner_prefix = _inner(optic_prefix) - if inner != inner_prefix + head = _head(optic) + head_prefix = _head(optic_prefix) + if head != head_prefix msg = "could not remove prefix $(optic_prefix) from optic $(optic)" throw(ArgumentError(msg)) end # recurse - return unprefix_optic(_outer(optic), _outer(optic_prefix)) + return unprefix_optic(_tail(optic), _tail(optic_prefix)) end """ diff --git a/test/hasvalue.jl b/test/hasvalue.jl index b6e3996..62bdf7d 100644 --- a/test/hasvalue.jl +++ b/test/hasvalue.jl @@ -1,5 +1,5 @@ @testset "base getvalue + hasvalue" begin - @testset "NamedTuple" begin + @testset "basic NamedTuple" begin nt = (a=[1], b=2, c=(x=3,), d=[1.0 0.5; 0.5 1.0]) @test hasvalue(nt, @varname(a)) @test getvalue(nt, @varname(a)) == [1] @@ -32,8 +32,8 @@ @test !hasvalue(nt, @varname(d[3, :])) end - @testset "Dict" begin - # same tests as above + @testset "basic Dict" begin + # same tests as for NamedTuple d = Dict( @varname(a) => [1], @varname(b) => 2, @@ -68,8 +68,9 @@ @test !hasvalue(d, @varname(c.x[1])) @test !hasvalue(d, @varname(c.y)) @test !hasvalue(d, @varname(d[1, 3])) + end - # extra ones since Dict can have weird keys + @testset "Dict with non-identity varname keys" begin d = Dict( @varname(a[1]) => [1.0, 2.0], @varname(b.x) => [3.0], @@ -98,6 +99,27 @@ @test !hasvalue(d, @varname(c[1])) @test !hasvalue(d, @varname(c[2].x)) end + + @testset "Dict with redundancy" begin + d1 = Dict(@varname(x) => [[[[1.0]]]]) + d2 = Dict(@varname(x[1]) => [[[2.0]]]) + d3 = Dict(@varname(x[1][1]) => [[3.0]]) + d4 = Dict(@varname(x[1][1][1]) => [4.0]) + d5 = Dict(@varname(x[1][1][1][1]) => 5.0) + + d = Dict{VarName,Any}() + for (new_dict, expected_value) in + zip((d1, d2, d3, d4, d5), (1.0, 2.0, 3.0, 4.0, 5.0)) + d = merge(d, new_dict) + @test hasvalue(d, @varname(x[1][1][1][1])) + @test getvalue(d, @varname(x[1][1][1][1])) == expected_value + # for good measure + @test !hasvalue(d, @varname(x[1][1][1][2])) + @test !hasvalue(d, @varname(x[1][1][2][1])) + @test !hasvalue(d, @varname(x[1][2][1][1])) + @test !hasvalue(d, @varname(x[2][1][1][1])) + end + end end @testset "with Distributions: getvalue + hasvalue" begin From 049001e7606e711121360cddaa0b60419a06e79b Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 6 Jul 2025 13:33:37 +0100 Subject: [PATCH 05/13] Specify getvalue semantics in docstring --- src/hasvalue.jl | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/src/hasvalue.jl b/src/hasvalue.jl index 65319dd..1e44c37 100644 --- a/src/hasvalue.jl +++ b/src/hasvalue.jl @@ -98,6 +98,33 @@ julia> getvalue(vals, @varname(x[2][1])) ERROR: getvalue: x[2][1] was not found in the values provided [...] ``` + +Dictionaries can present ambiguous cases where the same variable is specified +twice at different levels. In such a situation, `getvalue` attempts to find an +exact match, and if that fails it returns the value with the most specific key. + +!!! note + It is the user's responsibility to avoid such cases by ensuring that the + dictionary passed in does not contain the same value specified multiple + times. + +```jldoctest +julia> vals = Dict(@varname(x) => [[1.0]], @varname(x[1]) => [2.0]); + +julia> # Here, the `x[1]` key is not used because `x` is an exact match. + getvalue(vals, @varname(x)) +1-element Vector{Vector{Float64}}: + [1.0] + +julia> # Likewise, the `x` key is not used because `x[1]` is an exact match. + getvalue(vals, @varname(x[1])) +1-element Vector{Float64}: + 2.0 + +julia> # No exact match, so the most specific key, i.e. `x[1]`, is used. + getvalue(vals, @varname(x[1][1])) +2.0 +``` """ function getvalue(vals::NamedTuple, vn::VarName{sym}) where {sym} optic = getoptic(vn) From e0adba72395dec37a882e11c149f8c2905c2c220 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 6 Jul 2025 13:39:22 +0100 Subject: [PATCH 06/13] Simplify logic (can rely on normalisation) --- ext/AbstractPPLDistributionsExt.jl | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/ext/AbstractPPLDistributionsExt.jl b/ext/AbstractPPLDistributionsExt.jl index fd6b979..1e96c86 100644 --- a/ext/AbstractPPLDistributionsExt.jl +++ b/ext/AbstractPPLDistributionsExt.jl @@ -32,11 +32,7 @@ function AbstractPPL.hasvalue( # can return true. sz = size(dist) for idx in Iterators.product(map(Base.OneTo, sz)...) - new_optic = if AbstractPPL.getoptic(vn) === identity - Accessors.IndexLens(idx) - else - Accessors.IndexLens(idx) ∘ AbstractPPL.getoptic(vn) - end + new_optic = Accessors.IndexLens(idx) ∘ AbstractPPL.getoptic(vn) new_vn = VarName{sym}(new_optic) AbstractPPL.hasvalue(vals, new_vn) || return false end From 9291e07d6f88771ac620865ceddfce88ba7a45af Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 6 Jul 2025 13:53:00 +0100 Subject: [PATCH 07/13] Add tests for composition of head/tail and init/last --- src/varname.jl | 24 +++++++++++++++++++----- test/varname.jl | 48 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 67 insertions(+), 5 deletions(-) diff --git a/src/varname.jl b/src/varname.jl index 18cfc06..65f5e7c 100644 --- a/src/varname.jl +++ b/src/varname.jl @@ -967,8 +967,11 @@ string_to_varname(str::AbstractString) = dict_to_varname(JSON.parse(str)) Get the innermost layer of an optic. +For all (normalised) optics, we have that `normalise(_tail(optic) ∘ +_head(optic) == optic)`. + !!! note - Does not perform optic normalisation. You may wish to call + Does not perform optic normalisation on the input. You may wish to call `normalise(optic)` before using this function if the optic you are passing was not obtained from a VarName. @@ -999,8 +1002,11 @@ _head(::typeof(identity)) = identity Get everything but the innermost layer of an optic. +For all (normalised) optics, we have that `normalise(_tail(optic) ∘ +_head(optic) == optic)`. + !!! note - Does not perform optic normalisation. You may wish to call + Does not perform optic normalisation on the input. You may wish to call `normalise(optic)` before using this function if the optic you are passing was not obtained from a VarName. @@ -1031,8 +1037,11 @@ _tail(::typeof(identity)) = identity Get the outermost layer of an optic. +For all (normalised) optics, we have that `normalise(_last(optic) ∘ +_init(optic)) == optic`. + !!! note - Does not perform optic normalisation. You may wish to call + Does not perform optic normalisation on the input. You may wish to call `normalise(optic)` before using this function if the optic you are passing was not obtained from a VarName. @@ -1063,8 +1072,11 @@ _last(::typeof(identity)) = identity Get everything but the outermost layer of an optic. +For all (normalised) optics, we have that `normalise(_last(optic) ∘ +_init(optic)) == optic`. + !!! note - Does not perform optic normalisation. You may wish to call + Does not perform optic normalisation on the input. You may wish to call `normalise(optic)` before using this function if the optic you are passing was not obtained from a VarName. @@ -1084,7 +1096,9 @@ identity (generic function with 1 method) julia> AbstractPPL._init(Accessors.@o _) identity (generic function with 1 method) """ -_init(o::ComposedFunction{Outer,Inner}) where {Outer,Inner} = _init(o.outer) ∘ o.inner +# This one needs normalise because it's going 'against' the direction of the +# linked list (otherwise you will end up with identities scattered throughout) +_init(o::ComposedFunction{Outer,Inner}) where {Outer,Inner} = normalise(_init(o.outer) ∘ o.inner) _init(::Accessors.PropertyLens) = identity _init(::Accessors.IndexLens) = identity _init(::typeof(identity)) = identity diff --git a/test/varname.jl b/test/varname.jl index 40260da..2fbff9b 100644 --- a/test/varname.jl +++ b/test/varname.jl @@ -252,6 +252,54 @@ end @test string_to_varname(varname_to_string(vn)) == vn end + @testset "head, tail, init, last" begin + @testset "specification" begin + @test AbstractPPL._head(@o _.a.b.c) == @o _.a + @test AbstractPPL._tail(@o _.a.b.c) == @o _.b.c + @test AbstractPPL._init(@o _.a.b.c) == @o _.a.b + @test AbstractPPL._last(@o _.a.b.c) == @o _.c + + @test AbstractPPL._head(@o _[1][2][3]) == @o _[1] + @test AbstractPPL._tail(@o _[1][2][3]) == @o _[2][3] + @test AbstractPPL._init(@o _[1][2][3]) == @o _[1][2] + @test AbstractPPL._last(@o _[1][2][3]) == @o _[3] + + @test AbstractPPL._head(@o _.a) == @o _.a + @test AbstractPPL._tail(@o _.a) == identity + @test AbstractPPL._init(@o _.a) == identity + @test AbstractPPL._last(@o _.a) == @o _.a + + @test AbstractPPL._head(@o _[1]) == @o _[1] + @test AbstractPPL._tail(@o _[1]) == identity + @test AbstractPPL._init(@o _[1]) == identity + @test AbstractPPL._last(@o _[1]) == @o _[1] + + @test AbstractPPL._head(identity) == identity + @test AbstractPPL._tail(identity) == identity + @test AbstractPPL._init(identity) == identity + @test AbstractPPL._last(identity) == identity + end + + @testset "composition" begin + varnames = ( + @varname(x), + @varname(x[1]), + @varname(x.a), + @varname(x.a.b), + @varname(x[1].a), + ) + for vn in varnames + optic = getoptic(vn) + @test AbstractPPL.normalise( + AbstractPPL._last(optic) ∘ AbstractPPL._init(optic) + ) == optic + @test AbstractPPL.normalise( + AbstractPPL._tail(optic) ∘ AbstractPPL._head(optic) + ) == optic + end + end + end + @testset "prefix and unprefix" begin @testset "basic cases" begin @test prefix(@varname(y), @varname(x)) == @varname(x.y) From a736e70c833b064cf79c384fe00425095f420eb8 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 6 Jul 2025 19:47:13 +0100 Subject: [PATCH 08/13] Finish implementing distributions methods --- Project.toml | 4 +- ext/AbstractPPLDistributionsExt.jl | 160 +++++++++++++++++++++++++++-- src/varname.jl | 4 +- test/hasvalue.jl | 55 +++++++++- 4 files changed, 209 insertions(+), 14 deletions(-) diff --git a/Project.toml b/Project.toml index 3adbf5e..129aa20 100644 --- a/Project.toml +++ b/Project.toml @@ -15,15 +15,17 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" [weakdeps] Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [extensions] -AbstractPPLDistributionsExt = ["Distributions"] +AbstractPPLDistributionsExt = ["Distributions", "LinearAlgebra"] [compat] AbstractMCMC = "2, 3, 4, 5" Accessors = "0.1" DensityInterface = "0.4" Distributions = "0.25" +LinearAlgebra = "<0.0.1, 1.11" JSON = "0.19 - 0.21" Random = "1.6" StatsBase = "0.32, 0.33, 0.34" diff --git a/ext/AbstractPPLDistributionsExt.jl b/ext/AbstractPPLDistributionsExt.jl index 1e96c86..02682ce 100644 --- a/ext/AbstractPPLDistributionsExt.jl +++ b/ext/AbstractPPLDistributionsExt.jl @@ -2,23 +2,112 @@ module AbstractPPLDistributionsExt using AbstractPPL: AbstractPPL, VarName, Accessors using Distributions: Distributions +using LinearAlgebra: Cholesky, LowerTriangular, UpperTriangular + +#= +This section is copied from Accessors.jl's documentation: +https://juliaobjects.github.io/Accessors.jl/stable/examples/custom_macros/ + +It defines a wrapper that, when called with `set`, mutates the original value +rather than returning a new value. We need this because the non-mutating optics +don't work for triangular matrices (and hence LKJCholesky): see +https://github.com/JuliaObjects/Accessors.jl/issues/203 +=# +struct Lens!{L} + pure::L +end +(l::Lens!)(o) = l.pure(o) +function Accessors.set(o, l::Lens!{<:ComposedFunction}, val) + o_inner = l.pure.inner(o) + return Accessors.set(o_inner, Lens!(l.pure.outer), val) +end +function Accessors.set(o, l::Lens!{Accessors.PropertyLens{prop}}, val) where {prop} + setproperty!(o, prop, val) + return o +end +function Accessors.set(o, l::Lens!{<:Accessors.IndexLens}, val) + o[l.pure.indices...] = val + return o +end + +""" + get_optics(dist::MultivariateDistribution) + get_optics(dist::MatrixDistribution) + get_optics(dist::LKJCholesky) + +Return a complete set of optics for each element of the type returned by `rand(dist)`. +""" +function get_optics( + dist::Union{Distributions.MultivariateDistribution,Distributions.MatrixDistribution} +) + indices = CartesianIndices(size(dist)) + return map(idx -> Accessors.IndexLens(idx.I), indices) +end +function get_optics(dist::Distributions.LKJCholesky) + is_up = dist.uplo == 'U' + cartesian_indices = filter(CartesianIndices(size(dist))) do cartesian_index + i, j = cartesian_index.I + is_up ? i <= j : i >= j + end + # there is an additional layer as we need to access `.L` or `.U` before we + # can index into it + field_lens = is_up ? (Accessors.@o _.U) : (Accessors.@o _.L) + return map(idx -> Accessors.IndexLens(idx.I) ∘ field_lens, cartesian_indices) +end + +""" + make_empty_value(dist::MultivariateDistribution) + make_empty_value(dist::MatrixDistribution) + make_empty_value(dist::LKJCholesky) + +Construct a fresh value filled with zeros that corresponds to the size of `dist`. + +For all distributions that this function accepts, it should hold that +`o(make_empty_value(dist))` is zero for all `o` in `get_optics(dist)`. +""" +function make_empty_value( + dist::Union{Distributions.MultivariateDistribution,Distributions.MatrixDistribution} +) + return zeros(size(dist)) +end +function make_empty_value(dist::Distributions.LKJCholesky) + if dist.uplo == 'U' + return Cholesky(UpperTriangular(zeros(size(dist)))) + else + return Cholesky(LowerTriangular(zeros(size(dist)))) + end +end # TODO(penelopeysm): Figure out tuple / namedtuple distributions, and LKJCholesky (grr) function AbstractPPL.hasvalue( - vals::AbstractDict, vn::VarName, dist::Distributions.Distribution + vals::AbstractDict, + vn::VarName, + dist::Distributions.Distribution; + error_on_incomplete::Bool=false, ) @warn "`hasvalue(vals, vn, dist)` is not implemented for $(typeof(dist)); falling back to `hasvalue(vals, vn)`." return AbstractPPL.hasvalue(vals, vn) end function AbstractPPL.hasvalue( - vals::AbstractDict, vn::VarName, ::Distributions.UnivariateDistribution + vals::AbstractDict, + vn::VarName, + ::Distributions.UnivariateDistribution; + error_on_incomplete::Bool=false, ) + # TODO(penelopeysm): We could also implement a check for the type to catch + # invalid values. Unsure if that is worth it. It may be easier to just let + # the user handle it. return AbstractPPL.hasvalue(vals, vn) end function AbstractPPL.hasvalue( vals::AbstractDict{<:VarName}, vn::VarName{sym}, - dist::Union{Distributions.MultivariateDistribution,Distributions.MatrixDistribution}, + dist::Union{ + Distributions.MultivariateDistribution, + Distributions.MatrixDistribution, + Distributions.LKJCholesky, + }; + error_on_incomplete::Bool=false, ) where {sym} # If `vn` is present as-is, then we are good AbstractPPL.hasvalue(vals, vn) && return true @@ -30,13 +119,66 @@ function AbstractPPL.hasvalue( # To do this, we get the size of the distribution and iterate over all # possible indices. If every index can be found in `subsumed_keys`, then we # can return true. - sz = size(dist) - for idx in Iterators.product(map(Base.OneTo, sz)...) - new_optic = Accessors.IndexLens(idx) ∘ AbstractPPL.getoptic(vn) - new_vn = VarName{sym}(new_optic) - AbstractPPL.hasvalue(vals, new_vn) || return false + optics = get_optics(dist) + original_optic = AbstractPPL.getoptic(vn) + expected_vns = map(o -> VarName{sym}(o ∘ original_optic), optics) + if all(sub_vn -> AbstractPPL.hasvalue(vals, sub_vn), expected_vns) + return true + else + if error_on_incomplete && + any(sub_vn -> AbstractPPL.hasvalue(vals, sub_vn), expected_vns) + error("hasvalue: only partial values for `$vn` found in the values provided") + end + return false + end +end + +function AbstractPPL.getvalue( + vals::AbstractDict, vn::VarName, dist::Distributions.Distribution; +) + @warn "`getvalue(vals, vn, dist)` is not implemented for $(typeof(dist)); falling back to `getvalue(vals, vn)`." + return AbstractPPL.getvalue(vals, vn) +end +function AbstractPPL.getvalue( + vals::AbstractDict, vn::VarName, ::Distributions.UnivariateDistribution; +) + # TODO(penelopeysm): We could also implement a check for the type to catch + # invalid values. Unsure if that is worth it. It may be easier to just let + # the user handle it. + return AbstractPPL.getvalue(vals, vn) +end +function AbstractPPL.getvalue( + vals::AbstractDict{<:VarName}, + vn::VarName{sym}, + dist::Union{ + Distributions.MultivariateDistribution, + Distributions.MatrixDistribution, + Distributions.LKJCholesky, + }; +) where {sym} + # If `vn` is present as-is, then we can just return that + AbstractPPL.hasvalue(vals, vn) && return AbstractPPL.getvalue(vals, vn) + # If not, then we need to start looking inside `vals`, in exactly the + # same way we did for `hasvalue`. + optics = get_optics(dist) + original_optic = AbstractPPL.getoptic(vn) + expected_vns = map(o -> VarName{sym}(o ∘ original_optic), optics) + if all(sub_vn -> AbstractPPL.hasvalue(vals, sub_vn), expected_vns) + # Reconstruct the value index by index. + value = make_empty_value(dist) + for (o, sub_vn) in zip(optics, expected_vns) + # Retrieve the value of this given index + sub_value = AbstractPPL.getvalue(vals, sub_vn) + # Set it inside the value we're reconstructing. + # Note: `o` is normally non-mutating. We have to wrap it in `Lens!` + # to make it mutating, because Cholesky distributions are broken + # by https://github.com/JuliaObjects/Accessors.jl/issues/203. + Accessors.set(value, Lens!(o), sub_value) + end + return value + else + error("getvalue: $(vn) was not found in the values provided") end - return true end end diff --git a/src/varname.jl b/src/varname.jl index 65f5e7c..fca1004 100644 --- a/src/varname.jl +++ b/src/varname.jl @@ -1098,7 +1098,9 @@ identity (generic function with 1 method) """ # This one needs normalise because it's going 'against' the direction of the # linked list (otherwise you will end up with identities scattered throughout) -_init(o::ComposedFunction{Outer,Inner}) where {Outer,Inner} = normalise(_init(o.outer) ∘ o.inner) +function _init(o::ComposedFunction{Outer,Inner}) where {Outer,Inner} + return normalise(_init(o.outer) ∘ o.inner) +end _init(::Accessors.PropertyLens) = identity _init(::Accessors.IndexLens) = identity _init(::typeof(identity)) = identity diff --git a/test/hasvalue.jl b/test/hasvalue.jl index 62bdf7d..18d8f2e 100644 --- a/test/hasvalue.jl +++ b/test/hasvalue.jl @@ -126,7 +126,56 @@ end using Distributions using LinearAlgebra - d = Dict(@varname(x[1]) => 1.0, @varname(x[2]) => 2.0) - @test hasvalue(d, @varname(x), MvNormal(zeros(2), I)) - @test !hasvalue(d, @varname(x), MvNormal(zeros(3), I)) + @testset "univariate" begin + d = Dict(@varname(x) => 1.0, @varname(y) => [[2.0]]) + @test hasvalue(d, @varname(x), Normal()) + @test getvalue(d, @varname(x), Normal()) == 1.0 + @test hasvalue(d, @varname(y[1][1]), Normal()) + @test getvalue(d, @varname(y[1][1]), Normal()) == 2.0 + end + + @testset "multivariate + matrix" begin + d = Dict(@varname(x[1]) => 1.0, @varname(x[2]) => 2.0) + @test hasvalue(d, @varname(x), MvNormal(zeros(1), I)) + @test getvalue(d, @varname(x), MvNormal(zeros(1), I)) == [1.0] + @test hasvalue(d, @varname(x), MvNormal(zeros(2), I)) + @test getvalue(d, @varname(x), MvNormal(zeros(2), I)) == [1.0, 2.0] + @test !hasvalue(d, @varname(x), MvNormal(zeros(3), I)) + @test_throws ErrorException hasvalue( + d, @varname(x), MvNormal(zeros(3), I); error_on_incomplete=true + ) + # If none of the varnames match, it should just return false instead of erroring + @test !hasvalue(d, @varname(y), MvNormal(zeros(2), I); error_on_incomplete=true) + end + + @testset "LKJCholesky :upside_down_smile:" begin + # yes, this isn't a valid Cholesky sample, but whatever + d = Dict( + @varname(x.L[1, 1]) => 1.0, + @varname(x.L[2, 1]) => 2.0, + @varname(x.L[2, 2]) => 3.0, + ) + @test hasvalue(d, @varname(x), LKJCholesky(2, 1.0)) + @test getvalue(d, @varname(x), LKJCholesky(2, 1.0)) == + Cholesky(LowerTriangular([1.0 0.0; 2.0 3.0])) + @test !hasvalue(d, @varname(x), LKJCholesky(3, 1.0)) + @test_throws ErrorException hasvalue( + d, @varname(x), LKJCholesky(3, 1.0); error_on_incomplete=true + ) + @test !hasvalue(d, @varname(y), LKJCholesky(3, 1.0); error_on_incomplete=true) + + d = Dict( + @varname(x.U[1, 1]) => 1.0, + @varname(x.U[1, 2]) => 2.0, + @varname(x.U[2, 2]) => 3.0, + ) + @test hasvalue(d, @varname(x), LKJCholesky(2, 1.0, :U)) + @test getvalue(d, @varname(x), LKJCholesky(2, 1.0, :U)) == + Cholesky(UpperTriangular([1.0 2.0; 0.0 3.0])) + @test !hasvalue(d, @varname(x), LKJCholesky(3, 1.0, :U)) + @test_throws ErrorException hasvalue( + d, @varname(x), LKJCholesky(3, 1.0, :U); error_on_incomplete=true + ) + @test !hasvalue(d, @varname(y), LKJCholesky(3, 1.0, :U); error_on_incomplete=true) + end end From 5a902b014c563c4f0d35a37bddcd8f8c3b82f734 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 6 Jul 2025 19:54:50 +0100 Subject: [PATCH 09/13] Document --- docs/Project.toml | 1 + docs/make.jl | 3 +- docs/src/api.md | 7 ++++ ext/AbstractPPLDistributionsExt.jl | 62 ++++++++++++++++++++++++++++++ 4 files changed, 72 insertions(+), 1 deletion(-) diff --git a/docs/Project.toml b/docs/Project.toml index dfa65cd..475187f 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,2 +1,3 @@ [deps] +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" diff --git a/docs/make.jl b/docs/make.jl index 33bf21b..c184002 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,12 +1,13 @@ using Documenter using AbstractPPL +using Distributions # Doctest setup DocMeta.setdocmeta!(AbstractPPL, :DocTestSetup, :(using AbstractPPL); recursive=true) makedocs(; sitename="AbstractPPL", - modules=[AbstractPPL], + modules=[AbstractPPL, Base.get_extension(AbstractPPL, :AbstractPPLDistributionsExt)], pages=["Home" => "index.md", "API" => "api.md"], checkdocs=:exports, doctest=false, diff --git a/docs/src/api.md b/docs/src/api.md index e7705da..9c05c47 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -21,6 +21,13 @@ prefix unprefix ``` +## Extracting values corresponding to a VarName + +```@docs +hasvalue +getvalue +``` + ## VarName serialisation ```@docs diff --git a/ext/AbstractPPLDistributionsExt.jl b/ext/AbstractPPLDistributionsExt.jl index 02682ce..4b06ce2 100644 --- a/ext/AbstractPPLDistributionsExt.jl +++ b/ext/AbstractPPLDistributionsExt.jl @@ -78,6 +78,42 @@ function make_empty_value(dist::Distributions.LKJCholesky) end end +""" + hasvalue( + vals::AbstractDict, + vn::VarName, + dist::Distribution; + error_on_incomplete::Bool=false + ) + +Check if `vals` contains values for `vn` that is compatible with the +distribution `dist`. + +This is a more general version of `hasvalue(vals, vn)`, in that even if +`vn` itself is not inside `vals`, it further checks if `vals` contains +sub-values of `vn` that can be used to reconstruct `vn` given `dist`. + +The `error_on_incomplete` flag can be used to detect cases where _some_ of +the values needed for `vn` are present, but others are not. This may help +to detect invalid cases where the user has provided e.g. data of the wrong +shape. + +For example: + +```jldoctest; setup=:(using Distributions, LinearAlgebra)) +julia> d = Dict(@varname(x[1]) => 1.0, @varname(x[2]) => 2.0); + +julia> hasvalue(d, @varname(x), MvNormal(zeros(2), I)) +true + +julia> hasvalue(d, @varname(x), MvNormal(zeros(3), I)) +false + +julia> hasvalue(d, @varname(x), MvNormal(zeros(3), I); error_on_incomplete=true) +ERROR: hasvalue: only partial values for `x` found in the values provided +[...] +``` +""" # TODO(penelopeysm): Figure out tuple / namedtuple distributions, and LKJCholesky (grr) function AbstractPPL.hasvalue( vals::AbstractDict, @@ -133,6 +169,32 @@ function AbstractPPL.hasvalue( end end +""" + getvalue(vals::AbstractDict, vn::VarName, dist::Distribution) + +Retrieve the value of `vn` from `vals`, using the distribution `dist` to +reconstruct the value if necessary. + +This is a more general version of `getvalue(vals, vn)`, in that even if `vn` +itself is not inside `vals`, it can still reconstruct the value of `vn` +from sub-values of `vn` that are present in `vals`. + +For example: + +```jldoctest; setup=:(using Distributions, LinearAlgebra)) +julia> d = Dict(@varname(x[1]) => 1.0, @varname(x[2]) => 2.0); + +julia> getvalue(d, @varname(x), MvNormal(zeros(2), I)) +2-element Vector{Float64}: + 1.0 + 2.0 + +julia> # Use `hasvalue` to check for this case before calling `getvalue`. + getvalue(d, @varname(x), MvNormal(zeros(3), I)) +ERROR: getvalue: `x` was not found in the values provided +[...] +``` +""" function AbstractPPL.getvalue( vals::AbstractDict, vn::VarName, dist::Distributions.Distribution; ) From 0e8d2561efbb2507cdb8d7d31ef9f482c9aa2eac Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 6 Jul 2025 19:56:58 +0100 Subject: [PATCH 10/13] Fix LinearAlgebra version bound --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 129aa20..d6c5c15 100644 --- a/Project.toml +++ b/Project.toml @@ -25,7 +25,7 @@ AbstractMCMC = "2, 3, 4, 5" Accessors = "0.1" DensityInterface = "0.4" Distributions = "0.25" -LinearAlgebra = "<0.0.1, 1.11" +LinearAlgebra = "<0.0.1, 1.10" JSON = "0.19 - 0.21" Random = "1.6" StatsBase = "0.32, 0.33, 0.34" From 29dc9221abf22f12e53efab67aec2be7b243e505 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 6 Jul 2025 19:58:27 +0100 Subject: [PATCH 11/13] Try to fix documentation for extension (why is this so complicated...) --- docs/Project.toml | 1 + docs/make.jl | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/Project.toml b/docs/Project.toml index 475187f..a39559b 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,3 +1,4 @@ [deps] Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/docs/make.jl b/docs/make.jl index c184002..d3dbe83 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,6 +1,7 @@ using Documenter using AbstractPPL -using Distributions +# trigger DistributionsExt loading +using Distributions, LinearAlgebra # Doctest setup DocMeta.setdocmeta!(AbstractPPL, :DocTestSetup, :(using AbstractPPL); recursive=true) From a998af613a42659eeda3115b75e7e429564fd467 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 6 Jul 2025 20:02:41 +0100 Subject: [PATCH 12/13] Fix extension documentation --- docs/Project.toml | 1 + ext/AbstractPPLDistributionsExt.jl | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/Project.toml b/docs/Project.toml index a39559b..15b2ec4 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,4 +1,5 @@ [deps] +AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/ext/AbstractPPLDistributionsExt.jl b/ext/AbstractPPLDistributionsExt.jl index 4b06ce2..85cde86 100644 --- a/ext/AbstractPPLDistributionsExt.jl +++ b/ext/AbstractPPLDistributionsExt.jl @@ -114,7 +114,6 @@ ERROR: hasvalue: only partial values for `x` found in the values provided [...] ``` """ -# TODO(penelopeysm): Figure out tuple / namedtuple distributions, and LKJCholesky (grr) function AbstractPPL.hasvalue( vals::AbstractDict, vn::VarName, From 6a01588ef87dac4f129407e48d536725ebd173fa Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 6 Jul 2025 20:10:45 +0100 Subject: [PATCH 13/13] Implement fallback {has,get}value methods for NamedTuple + Distribution --- ext/AbstractPPLDistributionsExt.jl | 35 ++++++++++++++++++++++++++++-- 1 file changed, 33 insertions(+), 2 deletions(-) diff --git a/ext/AbstractPPLDistributionsExt.jl b/ext/AbstractPPLDistributionsExt.jl index 85cde86..8eb920a 100644 --- a/ext/AbstractPPLDistributionsExt.jl +++ b/ext/AbstractPPLDistributionsExt.jl @@ -80,7 +80,7 @@ end """ hasvalue( - vals::AbstractDict, + vals::Union{AbstractDict,NamedTuple}, vn::VarName, dist::Distribution; error_on_incomplete::Bool=false @@ -98,6 +98,11 @@ the values needed for `vn` are present, but others are not. This may help to detect invalid cases where the user has provided e.g. data of the wrong shape. +Note that this check is only possible if a Dict is passed, because the key type +of a NamedTuple (i.e., Symbol) is not rich enough to carry indexing +information. If this method is called with a NamedTuple, it will just defer +to `hasvalue(vals, vn)`. + For example: ```jldoctest; setup=:(using Distributions, LinearAlgebra)) @@ -114,6 +119,16 @@ ERROR: hasvalue: only partial values for `x` found in the values provided [...] ``` """ +function AbstractPPL.hasvalue( + vals::NamedTuple, + vn::VarName, + dist::Distributions.Distribution; + error_on_incomplete::Bool=false, +) + # NamedTuples can't have such complicated hierarchies, so it's safe to + # defer to the simpler `hasvalue(vals, vn)`. + return hasvalue(vals, vn) +end function AbstractPPL.hasvalue( vals::AbstractDict, vn::VarName, @@ -169,7 +184,11 @@ function AbstractPPL.hasvalue( end """ - getvalue(vals::AbstractDict, vn::VarName, dist::Distribution) + getvalue( + vals::Union{AbstractDict,NamedTuple}, + vn::VarName, + dist::Distribution + ) Retrieve the value of `vn` from `vals`, using the distribution `dist` to reconstruct the value if necessary. @@ -178,6 +197,11 @@ This is a more general version of `getvalue(vals, vn)`, in that even if `vn` itself is not inside `vals`, it can still reconstruct the value of `vn` from sub-values of `vn` that are present in `vals`. +Note that this reconstruction is only possible if a Dict is passed, because the +key type of a NamedTuple (i.e., Symbol) is not rich enough to carry indexing +information. If this method is called with a NamedTuple, it will just defer +to `getvalue(vals, vn)`. + For example: ```jldoctest; setup=:(using Distributions, LinearAlgebra)) @@ -194,6 +218,13 @@ ERROR: getvalue: `x` was not found in the values provided [...] ``` """ +function AbstractPPL.getvalue( + vals::NamedTuple, vn::VarName, dist::Distributions.Distribution +) + # NamedTuples can't have such complicated hierarchies, so it's safe to + # defer to the simpler `getvalue(vals, vn)`. + return getvalue(vals, vn) +end function AbstractPPL.getvalue( vals::AbstractDict, vn::VarName, dist::Distributions.Distribution; )