Skip to content

Commit 8c3bff4

Browse files
committed
Clean new tests up a bit
1 parent a3bc52e commit 8c3bff4

File tree

1 file changed

+20
-9
lines changed

1 file changed

+20
-9
lines changed

test/submodels.jl

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ using Test
2929
@test Set(keys(VarInfo(outer()))) == Set([@varname(a.x), @varname(a.y)])
3030

3131
# With conditioning/fixing
32-
@testset "$model" for model in [with_inner_op, with_outer_op]
32+
models = [("inner", with_inner_op), ("outer", with_outer_op)]
33+
@testset "$name" for (name, model) in models
3334
# Test that the value was correctly set
3435
@test model()[1] == x_val
3536
# Test that the logp was correctly set
@@ -44,7 +45,7 @@ using Test
4445
@model function inner()
4546
x ~ Normal()
4647
y ~ Normal()
47-
return x
48+
return (x, y)
4849
end
4950
@model function outer()
5051
return a ~ to_submodel(inner(), false)
@@ -60,7 +61,8 @@ using Test
6061
@test Set(keys(VarInfo(outer()))) == Set([@varname(x), @varname(y)])
6162

6263
# With conditioning/fixing
63-
@testset "$model" for model in [with_inner_op, with_outer_op]
64+
models = [("inner", with_inner_op), ("outer", with_outer_op)]
65+
@testset "$name" for (name, model) in models
6466
# Test that the value was correctly set
6567
@test model()[1] == x_val
6668
# Test that the logp was correctly set
@@ -75,7 +77,7 @@ using Test
7577
@model function inner()
7678
x ~ Normal()
7779
y ~ Normal()
78-
return x
80+
return (x, y)
7981
end
8082
@model function outer()
8183
return a ~ to_submodel(prefix(inner(), :b), false)
@@ -91,7 +93,8 @@ using Test
9193
@test Set(keys(VarInfo(outer()))) == Set([@varname(b.x), @varname(b.y)])
9294

9395
# With conditioning/fixing
94-
@testset "$model" for model in [with_inner_op, with_outer_op]
96+
models = [("inner", with_inner_op), ("outer", with_outer_op)]
97+
@testset "$name" for (name, model) in models
9598
# Test that the value was correctly set
9699
@test model()[1] == x_val
97100
# Test that the logp was correctly set
@@ -115,18 +118,20 @@ using Test
115118
end
116119

117120
# No conditioning
118-
@test Set(keys(VarInfo(h()))) == Set([@varname(a.b.x), @varname(a.b.y)])
121+
vi = VarInfo(h())
122+
@test Set(keys(vi)) == Set([@varname(a.b.x), @varname(a.b.y)])
123+
@test getlogp(vi) ==
124+
logpdf(Normal(), vi[@varname(a.b.x)]) +
125+
logpdf(Normal(), vi[@varname(a.b.y)])
119126

120127
# Conditioning/fixing at the top level
121128
op_h = op(h(), (@varname(a.b.x) => x_val))
122-
@test Set(keys(VarInfo(op_h))) == Set([@varname(a.b.y)])
123129

124130
# Conditioning/fixing at the second level
125131
op_g = op(g(), (@varname(b.x) => x_val))
126132
@model function h2()
127133
return a ~ to_submodel(op_g)
128134
end
129-
@test Set(keys(VarInfo(h2()))) == Set([@varname(a.b.y)])
130135

131136
# Conditioning/fixing at the very bottom
132137
op_f = op(f(), (@varname(x) => x_val))
@@ -136,7 +141,13 @@ using Test
136141
@model function h3()
137142
return a ~ to_submodel(g2())
138143
end
139-
@test Set(keys(VarInfo(h3()))) == Set([@varname(a.b.y)])
144+
145+
models = [("top", op_h), ("middle", h2()), ("bottom", h3())]
146+
@testset "$name" for (name, model) in models
147+
vi = VarInfo(model)
148+
@test Set(keys(vi)) == Set([@varname(a.b.y)])
149+
@test getlogp(vi) == x_logp + logpdf(Normal(), vi[@varname(a.b.y)])
150+
end
140151
end
141152
end
142153

0 commit comments

Comments
 (0)