|
| 1 | + |
| 2 | +""" |
| 3 | + bijector(model::Model[, sym2ranges = Val(false)]) |
| 4 | +
|
| 5 | +Returns a `Stacked <: Bijector` which maps from the support of the posterior to ℝᵈ with `d` |
| 6 | +denoting the dimensionality of the latent variables. |
| 7 | +""" |
| 8 | +function Bijectors.bijector( |
| 9 | + model::DynamicPPL.Model, |
| 10 | + (::Val{sym2ranges})=Val(false); |
| 11 | + varinfo=DynamicPPL.VarInfo(model), |
| 12 | +) where {sym2ranges} |
| 13 | + dists = vcat([varinfo.metadata[sym].dists for sym in keys(varinfo.metadata)]...) |
| 14 | + |
| 15 | + num_ranges = sum([ |
| 16 | + length(varinfo.metadata[sym].ranges) for sym in keys(varinfo.metadata) |
| 17 | + ]) |
| 18 | + ranges = Vector{UnitRange{Int}}(undef, num_ranges) |
| 19 | + idx = 0 |
| 20 | + range_idx = 1 |
| 21 | + |
| 22 | + # ranges might be discontinuous => values are vectors of ranges rather than just ranges |
| 23 | + sym_lookup = Dict{Symbol,Vector{UnitRange{Int}}}() |
| 24 | + for sym in keys(varinfo.metadata) |
| 25 | + sym_lookup[sym] = Vector{UnitRange{Int}}() |
| 26 | + for r in varinfo.metadata[sym].ranges |
| 27 | + ranges[range_idx] = idx .+ r |
| 28 | + push!(sym_lookup[sym], ranges[range_idx]) |
| 29 | + range_idx += 1 |
| 30 | + end |
| 31 | + |
| 32 | + idx += varinfo.metadata[sym].ranges[end][end] |
| 33 | + end |
| 34 | + |
| 35 | + bs = map(tuple(dists...)) do d |
| 36 | + b = Bijectors.bijector(d) |
| 37 | + if d isa Distributions.UnivariateDistribution |
| 38 | + b |
| 39 | + else |
| 40 | + # Wrap a bijector `f` such that it operates on vectors of length `prod(in_size)` |
| 41 | + # and produces a vector of length `prod(Bijectors.output(f, in_size))`. |
| 42 | + in_size = size(d) |
| 43 | + vec_in_length = prod(in_size) |
| 44 | + reshape_inner = Bijectors.Reshape((vec_in_length,), in_size) |
| 45 | + out_size = Bijectors.output_size(b, in_size) |
| 46 | + vec_out_length = prod(out_size) |
| 47 | + reshape_outer = Bijectors.Reshape(out_size, (vec_out_length,)) |
| 48 | + reshape_outer ∘ b ∘ reshape_inner |
| 49 | + end |
| 50 | + end |
| 51 | + |
| 52 | + if sym2ranges |
| 53 | + return ( |
| 54 | + Bijectors.Stacked(bs, ranges), |
| 55 | + (; collect(zip(keys(sym_lookup), values(sym_lookup)))...), |
| 56 | + ) |
| 57 | + else |
| 58 | + return Bijectors.Stacked(bs, ranges) |
| 59 | + end |
| 60 | +end |
0 commit comments