Skip to content

Commit b0ff259

Browse files
authored
Move bijector(dynamicpplmodel) from Turing.jl (#920)
* move `bijector(dynamicpplmodel)` from Turing.jl * add missing test file * add missing file * increment patch version * add changelog
1 parent cdeb657 commit b0ff259

File tree

6 files changed

+95
-1
lines changed

6 files changed

+95
-1
lines changed

HISTORY.md

+4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# DynamicPPL Changelog
22

3+
## 0.36.3
4+
5+
Moved the `bijector(model)`, where `model` is a `DynamicPPL.Model`, function from the Turing main repo.
6+
37
## 0.36.2
48

59
Improved docstrings for AD testing utilities.

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.36.2"
3+
version = "0.36.3"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/DynamicPPL.jl

+1
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ include("logdensityfunction.jl")
179179
include("model_utils.jl")
180180
include("extract_priors.jl")
181181
include("values_as_in_model.jl")
182+
include("bijector.jl")
182183

183184
include("debug_utils.jl")
184185
using .DebugUtils

src/bijector.jl

+60
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
2+
"""
3+
bijector(model::Model[, sym2ranges = Val(false)])
4+
5+
Returns a `Stacked <: Bijector` which maps from the support of the posterior to ℝᵈ with `d`
6+
denoting the dimensionality of the latent variables.
7+
"""
8+
function Bijectors.bijector(
9+
model::DynamicPPL.Model,
10+
(::Val{sym2ranges})=Val(false);
11+
varinfo=DynamicPPL.VarInfo(model),
12+
) where {sym2ranges}
13+
dists = vcat([varinfo.metadata[sym].dists for sym in keys(varinfo.metadata)]...)
14+
15+
num_ranges = sum([
16+
length(varinfo.metadata[sym].ranges) for sym in keys(varinfo.metadata)
17+
])
18+
ranges = Vector{UnitRange{Int}}(undef, num_ranges)
19+
idx = 0
20+
range_idx = 1
21+
22+
# ranges might be discontinuous => values are vectors of ranges rather than just ranges
23+
sym_lookup = Dict{Symbol,Vector{UnitRange{Int}}}()
24+
for sym in keys(varinfo.metadata)
25+
sym_lookup[sym] = Vector{UnitRange{Int}}()
26+
for r in varinfo.metadata[sym].ranges
27+
ranges[range_idx] = idx .+ r
28+
push!(sym_lookup[sym], ranges[range_idx])
29+
range_idx += 1
30+
end
31+
32+
idx += varinfo.metadata[sym].ranges[end][end]
33+
end
34+
35+
bs = map(tuple(dists...)) do d
36+
b = Bijectors.bijector(d)
37+
if d isa Distributions.UnivariateDistribution
38+
b
39+
else
40+
# Wrap a bijector `f` such that it operates on vectors of length `prod(in_size)`
41+
# and produces a vector of length `prod(Bijectors.output(f, in_size))`.
42+
in_size = size(d)
43+
vec_in_length = prod(in_size)
44+
reshape_inner = Bijectors.Reshape((vec_in_length,), in_size)
45+
out_size = Bijectors.output_size(b, in_size)
46+
vec_out_length = prod(out_size)
47+
reshape_outer = Bijectors.Reshape(out_size, (vec_out_length,))
48+
reshape_outer b reshape_inner
49+
end
50+
end
51+
52+
if sym2ranges
53+
return (
54+
Bijectors.Stacked(bs, ranges),
55+
(; collect(zip(keys(sym_lookup), values(sym_lookup)))...),
56+
)
57+
else
58+
return Bijectors.Stacked(bs, ranges)
59+
end
60+
end

test/bijector.jl

+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
2+
@testset "bijector.jl" begin
3+
@testset "bijector" begin
4+
@model function test()
5+
m ~ Normal()
6+
s ~ InverseGamma(3, 3)
7+
return c ~ Dirichlet([1.0, 1.0])
8+
end
9+
10+
m = test()
11+
b = bijector(m)
12+
13+
# m ∈ ℝ, s ∈ ℝ+, c ∈ 2-simplex
14+
# check dimensionalities and ranges
15+
@test b.length_in == 4
16+
@test b.length_out == 3
17+
@test b.ranges_in == [1:1, 2:2, 3:4]
18+
@test b.ranges_out == [1:1, 2:2, 3:3]
19+
@test b.ranges_out == [1:1, 2:2, 3:3]
20+
21+
# check support of mapped variables
22+
binv = inverse(b)
23+
zs = mapslices(binv, randn(b.length_out, 10000); dims=1)
24+
25+
@test all(zs[2, :] .≥ 0)
26+
@test all(sum(zs[3:4, :]; dims=1) .≈ 1.0)
27+
end
28+
end

test/runtests.jl

+1
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ include("test_util.jl")
6868
include("debug_utils.jl")
6969
include("deprecated.jl")
7070
include("submodels.jl")
71+
include("bijector.jl")
7172
end
7273

7374
if GROUP == "All" || GROUP == "Group2"

0 commit comments

Comments
 (0)