|
| 1 | +""" |
| 2 | + canview(optic, container) |
| 3 | +
|
| 4 | +Return `true` if `optic` can be used to view `container`, and `false` otherwise. |
| 5 | +
|
| 6 | +# Examples |
| 7 | +```jldoctest; setup=:(using Accessors) |
| 8 | +julia> AbstractPPL.canview(@o(_.a), (a = 1.0, )) |
| 9 | +true |
| 10 | +
|
| 11 | +julia> AbstractPPL.canview(@o(_.a), (b = 1.0, )) # property `a` does not exist |
| 12 | +false |
| 13 | +
|
| 14 | +julia> AbstractPPL.canview(@o(_.a[1]), (a = [1.0, 2.0], )) |
| 15 | +true |
| 16 | +
|
| 17 | +julia> AbstractPPL.canview(@o(_.a[3]), (a = [1.0, 2.0], )) # out of bounds |
| 18 | +false |
| 19 | +``` |
| 20 | +""" |
| 21 | +canview(optic, container) = false |
| 22 | +canview(::typeof(identity), _) = true |
| 23 | +function canview(::Accessors.PropertyLens{field}, x) where {field} |
| 24 | + return hasproperty(x, field) |
| 25 | +end |
| 26 | + |
| 27 | +# `IndexLens`: only relevant if `x` supports indexing. |
| 28 | +canview(optic::Accessors.IndexLens, x) = false |
| 29 | +function canview(optic::Accessors.IndexLens, x::AbstractArray) |
| 30 | + return checkbounds(Bool, x, optic.indices...) |
| 31 | +end |
| 32 | + |
| 33 | +# `ComposedFunction`: check that we can view `.inner` and `.outer`, but using |
| 34 | +# value extracted using `.inner`. |
| 35 | +function canview(optic::ComposedFunction, x) |
| 36 | + return canview(optic.inner, x) && canview(optic.outer, optic.inner(x)) |
| 37 | +end |
| 38 | + |
| 39 | +""" |
| 40 | + getvalue(vals::NamedTuple, vn::VarName) |
| 41 | + getvalue(vals::AbstractDict{<:VarName}, vn::VarName) |
| 42 | +
|
| 43 | +Return the value(s) in `vals` represented by `vn`. |
| 44 | +
|
| 45 | +# Examples |
| 46 | +
|
| 47 | +For `NamedTuple`: |
| 48 | +
|
| 49 | +```jldoctest |
| 50 | +julia> vals = (x = [1.0],); |
| 51 | +
|
| 52 | +julia> getvalue(vals, @varname(x)) # same as `getindex` |
| 53 | +1-element Vector{Float64}: |
| 54 | + 1.0 |
| 55 | +
|
| 56 | +julia> getvalue(vals, @varname(x[1])) # different from `getindex` |
| 57 | +1.0 |
| 58 | +
|
| 59 | +julia> getvalue(vals, @varname(x[2])) |
| 60 | +ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [2] |
| 61 | +[...] |
| 62 | +``` |
| 63 | +
|
| 64 | +For `AbstractDict`: |
| 65 | +
|
| 66 | +```jldoctest |
| 67 | +julia> vals = Dict(@varname(x) => [1.0]); |
| 68 | +
|
| 69 | +julia> getvalue(vals, @varname(x)) # same as `getindex` |
| 70 | +1-element Vector{Float64}: |
| 71 | + 1.0 |
| 72 | +
|
| 73 | +julia> getvalue(vals, @varname(x[1])) # different from `getindex` |
| 74 | +1.0 |
| 75 | +
|
| 76 | +julia> getvalue(vals, @varname(x[2])) |
| 77 | +ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [2] |
| 78 | +[...] |
| 79 | +``` |
| 80 | +
|
| 81 | +In the `AbstractDict` case we can also have keys such as `v[1]`: |
| 82 | +
|
| 83 | +```jldoctest |
| 84 | +julia> vals = Dict(@varname(x[1]) => [1.0,]); |
| 85 | +
|
| 86 | +julia> getvalue(vals, @varname(x[1])) # same as `getindex` |
| 87 | +1-element Vector{Float64}: |
| 88 | + 1.0 |
| 89 | +
|
| 90 | +julia> getvalue(vals, @varname(x[1][1])) # different from `getindex` |
| 91 | +1.0 |
| 92 | +
|
| 93 | +julia> getvalue(vals, @varname(x[1][2])) |
| 94 | +ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [2] |
| 95 | +[...] |
| 96 | +
|
| 97 | +julia> getvalue(vals, @varname(x[2][1])) |
| 98 | +ERROR: KeyError: key x[2][1] not found |
| 99 | +[...] |
| 100 | +``` |
| 101 | +""" |
| 102 | +getvalue(vals::NamedTuple, vn::VarName) = get(vals, vn) |
| 103 | +getvalue(vals::AbstractDict, vn::VarName) = nested_getindex(vals, vn) |
| 104 | + |
| 105 | +""" |
| 106 | + hasvalue(vals::NamedTuple, vn::VarName) |
| 107 | + hasvalue(vals::AbstractDict{<:VarName}, vn::VarName) |
| 108 | +
|
| 109 | +Determine whether `vals` contains a value for a given `vn`. |
| 110 | +
|
| 111 | +# Examples |
| 112 | +With `x` as a `NamedTuple`: |
| 113 | +
|
| 114 | +```jldoctest |
| 115 | +julia> hasvalue((x = 1.0, ), @varname(x)) |
| 116 | +true |
| 117 | +
|
| 118 | +julia> hasvalue((x = 1.0, ), @varname(x[1])) |
| 119 | +false |
| 120 | +
|
| 121 | +julia> hasvalue((x = [1.0],), @varname(x)) |
| 122 | +true |
| 123 | +
|
| 124 | +julia> hasvalue((x = [1.0],), @varname(x[1])) |
| 125 | +true |
| 126 | +
|
| 127 | +julia> hasvalue((x = [1.0],), @varname(x[2])) |
| 128 | +false |
| 129 | +``` |
| 130 | +
|
| 131 | +With `x` as a `AbstractDict`: |
| 132 | +
|
| 133 | +```jldoctest |
| 134 | +julia> hasvalue(Dict(@varname(x) => 1.0, ), @varname(x)) |
| 135 | +true |
| 136 | +
|
| 137 | +julia> hasvalue(Dict(@varname(x) => 1.0, ), @varname(x[1])) |
| 138 | +false |
| 139 | +
|
| 140 | +julia> hasvalue(Dict(@varname(x) => [1.0]), @varname(x)) |
| 141 | +true |
| 142 | +
|
| 143 | +julia> hasvalue(Dict(@varname(x) => [1.0]), @varname(x[1])) |
| 144 | +true |
| 145 | +
|
| 146 | +julia> hasvalue(Dict(@varname(x) => [1.0]), @varname(x[2])) |
| 147 | +false |
| 148 | +``` |
| 149 | +
|
| 150 | +In the `AbstractDict` case we can also have keys such as `v[1]`: |
| 151 | +
|
| 152 | +```jldoctest |
| 153 | +julia> vals = Dict(@varname(x[1]) => [1.0,]); |
| 154 | +
|
| 155 | +julia> hasvalue(vals, @varname(x[1])) # same as `haskey` |
| 156 | +true |
| 157 | +
|
| 158 | +julia> hasvalue(vals, @varname(x[1][1])) # different from `haskey` |
| 159 | +true |
| 160 | +
|
| 161 | +julia> hasvalue(vals, @varname(x[1][2])) |
| 162 | +false |
| 163 | +
|
| 164 | +julia> hasvalue(vals, @varname(x[2][1])) |
| 165 | +false |
| 166 | +``` |
| 167 | +""" |
| 168 | +function hasvalue(vals::NamedTuple, vn::VarName{sym}) where {sym} |
| 169 | + return haskey(vals, sym) && canview(getoptic(vn), getproperty(vals, sym)) |
| 170 | +end |
| 171 | + |
| 172 | +# For the Dict case, it is more complicated. There are two cases: |
| 173 | +# 1. `vn` itself is already a key of `vals` (the easy case) |
| 174 | +# 2. `vn` is not a key of `vals`, but some parent of `vn` is a key of `vals` |
| 175 | +# (the harder case). For example, if `vn` is `x[1][2]`, then we need to |
| 176 | +# check if either `x` or `x[1]` is a key of `vals`, and if so, whether |
| 177 | +# we can index into the corresponding value. |
| 178 | +function hasvalue(vals::AbstractDict{<:VarName}, vn::VarName{sym}) where {sym} |
| 179 | + # First we check if `vn` is present as is. |
| 180 | + haskey(vals, vn) && return true |
| 181 | + |
| 182 | + # Otherwise, we start by testing the bare `vn` (e.g., if `vn` is `x[1][2]`, |
| 183 | + # we start by checking if `x` is present). We will then keep adding optics |
| 184 | + # to `test_optic`, either until we find a key that is present, or until we |
| 185 | + # run out of optics to test (which is determined by _inner(test_optic) == |
| 186 | + # identity). |
| 187 | + test_vn = VarName{sym}() |
| 188 | + test_optic = getoptic(vn) |
| 189 | + |
| 190 | + while _inner(test_optic) != identity |
| 191 | + @show test_vn, test_optic |
| 192 | + if haskey(vals, test_vn) |
| 193 | + @show canview(test_optic, vals[test_vn]) |
| 194 | + end |
| 195 | + if haskey(vals, test_vn) && canview(test_optic, vals[test_vn]) |
| 196 | + return true |
| 197 | + else |
| 198 | + # Move the innermost optic into test_vn |
| 199 | + test_optic_outer = _outer(test_optic) |
| 200 | + test_optic_inner = _inner(test_optic) |
| 201 | + test_vn = VarName{sym}(test_optic_inner ∘ getoptic(test_vn)) |
| 202 | + test_optic = test_optic_outer |
| 203 | + end |
| 204 | + end |
| 205 | + return false |
| 206 | +end |
| 207 | +# TODO(penelopeysm): Figure out tuple / namedtuple distributions, and LKJCholesky (grr) |
| 208 | +# function hasvalue(vals::AbstractDict, vn::VarName, dist::Distribution) |
| 209 | +# @warn "`hasvalue(vals, vn, dist)` is not implemented for $(typeof(dist)); falling back to `hasvalue(vals, vn)`." |
| 210 | +# return hasvalue(vals, vn) |
| 211 | +# end |
| 212 | +# hasvalue(vals::AbstractDict, vn::VarName, ::UnivariateDistribution) = hasvalue(vals, vn) |
| 213 | +# function hasvalue( |
| 214 | +# vals::AbstractDict{<:VarName}, |
| 215 | +# vn::VarName{sym}, |
| 216 | +# dist::Union{MultivariateDistribution,MatrixDistribution}, |
| 217 | +# ) where {sym} |
| 218 | +# # If `vn` is present as-is, then we are good |
| 219 | +# hasvalue(vals, vn) && return true |
| 220 | +# # If not, then we need to check inside `vals` to see if a subset of |
| 221 | +# # `vals` is enough to reconstruct `vn`. For example, if `vals` contains |
| 222 | +# # `x[1]` and `x[2]`, and `dist` is `MvNormal(zeros(2), I)`, then we |
| 223 | +# # can reconstruct `x`. If `dist` is `MvNormal(zeros(3), I)`, then we |
| 224 | +# # can't. |
| 225 | +# # To do this, we get the size of the distribution and iterate over all |
| 226 | +# # possible indices. If every index can be found in `subsumed_keys`, then we |
| 227 | +# # can return true. |
| 228 | +# sz = size(dist) |
| 229 | +# for idx in Iterators.product(map(Base.OneTo, sz)...) |
| 230 | +# new_optic = if getoptic(vn) === identity |
| 231 | +# Accessors.IndexLens(idx) |
| 232 | +# else |
| 233 | +# Accessors.IndexLens(idx) ∘ getoptic(vn) |
| 234 | +# end |
| 235 | +# new_vn = VarName{sym}(new_optic) |
| 236 | +# hasvalue(vals, new_vn) || return false |
| 237 | +# end |
| 238 | +# return true |
| 239 | +# end |
| 240 | + |
| 241 | +# """ |
| 242 | +# nested_getindex(values::AbstractDict, vn::VarName) |
| 243 | +# |
| 244 | +# Return value corresponding to `vn` in `values` by also looking |
| 245 | +# in the the actual values of the dict. |
| 246 | +# """ |
| 247 | +# function nested_getindex(values::AbstractDict, vn::VarName) |
| 248 | +# maybeval = get(values, vn, nothing) |
| 249 | +# if maybeval !== nothing |
| 250 | +# return maybeval |
| 251 | +# end |
| 252 | +# |
| 253 | +# # Split the optic into the key / `parent` and the extraction optic / `child`. |
| 254 | +# parent, child, issuccess = splitoptic(getoptic(vn)) do optic |
| 255 | +# o = optic === nothing ? identity : optic |
| 256 | +# haskey(values, VarName(vn, o)) |
| 257 | +# end |
| 258 | +# # When combined with `VarInfo`, `nothing` is equivalent to `identity`. |
| 259 | +# keyoptic = parent === nothing ? identity : parent |
| 260 | +# |
| 261 | +# # If we found a valid split, then we can extract the value. |
| 262 | +# if !issuccess |
| 263 | +# # At this point we just throw an error since the key could not be found. |
| 264 | +# throw(KeyError(vn)) |
| 265 | +# end |
| 266 | +# |
| 267 | +# # TODO: Should we also check that we `canview` the extracted `value` |
| 268 | +# # rather than just let it fail upon `get` call? |
| 269 | +# value = values[VarName(vn, keyoptic)] |
| 270 | +# return child(value) |
| 271 | +# end |
0 commit comments