Skip to content

Commit dcd24e7

Browse files
committed
Move src/test_utils and test/test_util to DynamicPPLTestExt
1 parent f5890a1 commit dcd24e7

21 files changed

+287
-273
lines changed

Project.toml

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
3131
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
3232
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
3333
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
34+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3435
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
3536

3637
[extensions]
@@ -39,6 +40,7 @@ DynamicPPLEnzymeCoreExt = ["EnzymeCore"]
3940
DynamicPPLForwardDiffExt = ["ForwardDiff"]
4041
DynamicPPLMCMCChainsExt = ["MCMCChains"]
4142
DynamicPPLReverseDiffExt = ["ReverseDiff"]
43+
DynamicPPLTestExt = ["Test"]
4244
DynamicPPLZygoteRulesExt = ["ZygoteRules"]
4345

4446
[compat]
@@ -67,11 +69,3 @@ ReverseDiff = "1"
6769
Test = "1.6"
6870
ZygoteRules = "0.2"
6971
julia = "1.10"
70-
71-
[extras]
72-
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
73-
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
74-
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
75-
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
76-
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
77-
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

ext/DynamicPPLTestExt.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
module DynamicPPLTestExt
2+
3+
using DynamicPPL: DynamicPPL
4+
using Test: @test, @testset, @test_throws, @test_broken
5+
6+
include("DynamicPPLTestExt/utils.jl")
7+
8+
end

src/test_utils.jl renamed to ext/DynamicPPLTestExt/utils.jl

Lines changed: 124 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1-
module TestUtils
1+
module TestExtUtils
2+
3+
###################################################
4+
# These used to be in DPPL/src/test_utils.jl ######
5+
###################################################
26

37
using AbstractMCMC
48
using DynamicPPL
@@ -1097,4 +1101,123 @@ function DynamicPPL.dot_tilde_observe(
10971101
return logp * context.mod, vi
10981102
end
10991103

1104+
1105+
1106+
###################################################
1107+
# These used to be in DPPL/test/test_util.jl ######
1108+
###################################################
1109+
1110+
# default model
1111+
@model function gdemo_d()
1112+
s ~ InverseGamma(2, 3)
1113+
m ~ Normal(0, sqrt(s))
1114+
1.5 ~ Normal(m, sqrt(s))
1115+
2.0 ~ Normal(m, sqrt(s))
1116+
return s, m
1117+
end
1118+
const gdemo_default = gdemo_d()
1119+
1120+
function test_model_ad(model, logp_manual)
1121+
vi = VarInfo(model)
1122+
x = DynamicPPL.getall(vi)
1123+
1124+
# Log probabilities using the model.
1125+
= DynamicPPL.LogDensityFunction(model, vi)
1126+
logp_model = Base.Fix1(LogDensityProblems.logdensity, ℓ)
1127+
1128+
# Check that both functions return the same values.
1129+
lp = logp_manual(x)
1130+
@test logp_model(x) lp
1131+
1132+
# Gradients based on the manual implementation.
1133+
grad = ForwardDiff.gradient(logp_manual, x)
1134+
1135+
y, back = Tracker.forward(logp_manual, x)
1136+
@test Tracker.data(y) lp
1137+
@test Tracker.data(back(1)[1]) grad
1138+
1139+
y, back = Zygote.pullback(logp_manual, x)
1140+
@test y lp
1141+
@test back(1)[1] grad
1142+
1143+
# Gradients based on the model.
1144+
@test ForwardDiff.gradient(logp_model, x) grad
1145+
1146+
y, back = Tracker.forward(logp_model, x)
1147+
@test Tracker.data(y) lp
1148+
@test Tracker.data(back(1)[1]) grad
1149+
1150+
y, back = Zygote.pullback(logp_model, x)
1151+
@test y lp
1152+
@test back(1)[1] grad
1153+
end
1154+
1155+
"""
1156+
test_setval!(model, chain; sample_idx = 1, chain_idx = 1)
1157+
1158+
Test `setval!` on `model` and `chain`.
1159+
1160+
Worth noting that this only supports models containing symbols of the forms
1161+
`m`, `m[1]`, `m[1, 2]`, not `m[1][1]`, etc.
1162+
"""
1163+
function test_setval!(model, chain; sample_idx=1, chain_idx=1)
1164+
var_info = VarInfo(model)
1165+
spl = SampleFromPrior()
1166+
θ_old = var_info[spl]
1167+
DynamicPPL.setval!(var_info, chain, sample_idx, chain_idx)
1168+
θ_new = var_info[spl]
1169+
@test θ_old != θ_new
1170+
vals = DynamicPPL.values_as(var_info, OrderedDict)
1171+
iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
1172+
for (n, v) in mapreduce(collect, vcat, iters)
1173+
n = string(n)
1174+
if Symbol(n) keys(chain)
1175+
# Assume it's a group
1176+
chain_val = vec(
1177+
MCMCChains.group(chain, Symbol(n)).value[sample_idx, :, chain_idx]
1178+
)
1179+
v_true = vec(v)
1180+
else
1181+
chain_val = chain[sample_idx, n, chain_idx]
1182+
v_true = v
1183+
end
1184+
1185+
@test v_true == chain_val
1186+
end
1187+
end
1188+
1189+
"""
1190+
short_varinfo_name(vi::AbstractVarInfo)
1191+
1192+
Return string representing a short description of `vi`.
1193+
"""
1194+
short_varinfo_name(vi::DynamicPPL.ThreadSafeVarInfo) =
1195+
"threadsafe($(short_varinfo_name(vi.varinfo)))"
1196+
function short_varinfo_name(vi::TypedVarInfo)
1197+
DynamicPPL.has_varnamedvector(vi) && return "TypedVarInfo with VarNamedVector"
1198+
return "TypedVarInfo"
1199+
end
1200+
short_varinfo_name(::UntypedVarInfo) = "UntypedVarInfo"
1201+
short_varinfo_name(::DynamicPPL.VectorVarInfo) = "VectorVarInfo"
1202+
short_varinfo_name(::SimpleVarInfo{<:NamedTuple}) = "SimpleVarInfo{<:NamedTuple}"
1203+
short_varinfo_name(::SimpleVarInfo{<:OrderedDict}) = "SimpleVarInfo{<:OrderedDict}"
1204+
function short_varinfo_name(::SimpleVarInfo{<:DynamicPPL.VarNamedVector})
1205+
return "SimpleVarInfo{<:VarNamedVector}"
1206+
end
1207+
1208+
# convenient functions for testing model.jl
1209+
# function to modify the representation of values based on their length
1210+
function modify_value_representation(nt::NamedTuple)
1211+
modified_nt = NamedTuple()
1212+
for (key, value) in zip(keys(nt), values(nt))
1213+
if length(value) == 1 # Scalar value
1214+
modified_value = value[1]
1215+
else # Non-scalar value
1216+
modified_value = value
1217+
end
1218+
modified_nt = merge(modified_nt, (key => modified_value,))
1219+
end
1220+
return modified_nt
11001221
end
1222+
1223+
end # module TestExtUtils

test/ad.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
@testset "AD: ForwardDiff and ReverseDiff" begin
2-
@testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS
2+
@testset "$(m.f)" for m in TU.DEMO_MODELS
33
f = DynamicPPL.LogDensityFunction(m)
4-
rand_param_values = DynamicPPL.TestUtils.rand_prior_true(m)
5-
vns = DynamicPPL.TestUtils.varnames(m)
6-
varinfos = DynamicPPL.TestUtils.setup_varinfos(m, rand_param_values, vns)
4+
rand_param_values = TU.rand_prior_true(m)
5+
vns = TU.varnames(m)
6+
varinfos = TU.setup_varinfos(m, rand_param_values, vns)
77

8-
@testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
8+
@testset "$(TU.short_varinfo_name(varinfo))" for varinfo in varinfos
99
f = DynamicPPL.LogDensityFunction(m, varinfo)
1010

1111
# use ForwardDiff result as reference

test/compat/ad.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
logpdf(dist, 2.0)
1313
end
1414

15-
test_model_ad(gdemo_default, logp_gdemo_default)
15+
TU.test_model_ad(TU.gdemo_default, logp_gdemo_default)
1616

1717
@model function wishart_ad()
1818
return v ~ Wishart(7, [1 0.5; 0.5 1])
@@ -24,7 +24,7 @@
2424
return logpdf(dist, reshape(x, 2, 2))
2525
end
2626

27-
test_model_ad(wishart_ad(), logp_wishart_ad)
27+
TU.test_model_ad(wishart_ad(), logp_wishart_ad)
2828
end
2929

3030
# https://github.com/TuringLang/Turing.jl/issues/1595

test/contexts.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ end
168168

169169
# Let's check elementwise.
170170
for vn_child in
171-
DynamicPPL.TestUtils.varname_leaves(vn_without_prefix, val)
171+
TU.varname_leaves(vn_without_prefix, val)
172172
if getoptic(vn_child)(val) === missing
173173
@test contextual_isassumption(context, vn_child)
174174
else
@@ -201,7 +201,7 @@ end
201201
vn_without_prefix = remove_prefix(vn)
202202

203203
for vn_child in
204-
DynamicPPL.TestUtils.varname_leaves(vn_without_prefix, val)
204+
TU.varname_leaves(vn_without_prefix, val)
205205
# `vn_child` should be in `context`.
206206
@test hasconditioned_nested(context, vn_child)
207207
# Value should be the same as extracted above.
@@ -216,7 +216,7 @@ end
216216
@testset "Evaluation" begin
217217
@testset "$context" for context in contexts
218218
# Just making sure that we can actually sample with each of the contexts.
219-
@test (gdemo_default(SamplingContext(context)); true)
219+
@test (TU.gdemo_default(SamplingContext(context)); true)
220220
end
221221
end
222222

@@ -258,7 +258,7 @@ end
258258
end
259259

260260
@testset "FixedContext" begin
261-
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
261+
@testset "$(model.f)" for model in TU.DEMO_MODELS
262262
retval = model()
263263
s, m = retval.s, retval.m
264264

test/debug_utils.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
11
@testset "check_model" begin
22
@testset "context interface" begin
33
# HACK: Require a model to instantiate it, so let's just grab one.
4-
model = first(DynamicPPL.TestUtils.DEMO_MODELS)
4+
model = first(TU.DEMO_MODELS)
55
context = DynamicPPL.DebugUtils.DebugContext(model)
6-
DynamicPPL.TestUtils.test_context_interface(context)
6+
TU.test_context_interface(context)
77
end
88

9-
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
9+
@testset "$(model.f)" for model in TU.DEMO_MODELS
1010
issuccess, trace = check_model_and_trace(model)
1111
# These models should all work.
1212
@test issuccess
1313

1414
# Check that the trace contains all the variables in the model.
1515
varnames_in_trace = DynamicPPL.DebugUtils.varnames_in_trace(trace)
16-
for vn in DynamicPPL.TestUtils.varnames(model)
16+
for vn in TU.varnames(model)
1717
@test vn in varnames_in_trace
1818
end
1919

@@ -156,7 +156,7 @@
156156
end
157157

158158
@testset "comparing multiple traces" begin
159-
model = DynamicPPL.TestUtils.demo_dynamic_constraint()
159+
model = TU.demo_dynamic_constraint()
160160
issuccess_1, trace_1 = check_model_and_trace(model)
161161
issuccess_2, trace_2 = check_model_and_trace(model)
162162
@test issuccess_1 && issuccess_2

test/linking.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,8 @@ end
7575
model = demo()
7676

7777
example_values = rand(NamedTuple, model)
78-
vis = DynamicPPL.TestUtils.setup_varinfos(model, example_values, (@varname(m),))
79-
@testset "$(short_varinfo_name(vi))" for vi in vis
78+
vis = TU.setup_varinfos(model, example_values, (@varname(m),))
79+
@testset "$(TU.short_varinfo_name(vi))" for vi in vis
8080
# Evaluate once to ensure we have `logp` value.
8181
vi = last(DynamicPPL.evaluate!!(model, vi, DefaultContext()))
8282
vi_linked = if mutable
@@ -109,10 +109,10 @@ end
109109
model = demo_lkj(d)
110110
dist = LKJCholesky(d, 1.0, uplo)
111111
values_original = rand(NamedTuple, model)
112-
vis = DynamicPPL.TestUtils.setup_varinfos(
112+
vis = TU.setup_varinfos(
113113
model, values_original, (@varname(x),)
114114
)
115-
@testset "$(short_varinfo_name(vi))" for vi in vis
115+
@testset "$(TU.short_varinfo_name(vi))" for vi in vis
116116
val = vi[@varname(x), dist]
117117
# Ensure that `reconstruct` works as intended.
118118
@test val isa Cholesky
@@ -150,8 +150,8 @@ end
150150
@testset "d=$d" for d in [2, 3, 5]
151151
model = demo_dirichlet(d)
152152
example_values = rand(NamedTuple, model)
153-
vis = DynamicPPL.TestUtils.setup_varinfos(model, example_values, (@varname(x),))
154-
@testset "$(short_varinfo_name(vi))" for vi in vis
153+
vis = TU.setup_varinfos(model, example_values, (@varname(x),))
154+
@testset "$(TU.short_varinfo_name(vi))" for vi in vis
155155
lp = logpdf(Dirichlet(d, 1.0), vi[:])
156156
@test length(vi[:]) == d
157157
lp_model = logjoint(model, vi)
@@ -189,8 +189,8 @@ end
189189
]
190190
model = demo_highdim_dirichlet(ns...)
191191
example_values = rand(NamedTuple, model)
192-
vis = DynamicPPL.TestUtils.setup_varinfos(model, example_values, (@varname(x),))
193-
@testset "$(short_varinfo_name(vi))" for vi in vis
192+
vis = TU.setup_varinfos(model, example_values, (@varname(x),))
193+
@testset "$(TU.short_varinfo_name(vi))" for vi in vis
194194
# Linked.
195195
vi_linked = if mutable
196196
DynamicPPL.link!!(deepcopy(vi), model)

test/logdensityfunction.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
using Test, DynamicPPL, ADTypes, LogDensityProblems, LogDensityProblemsAD, ReverseDiff
22

33
@testset "`getmodel` and `setmodel`" begin
4-
@testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS
5-
model = DynamicPPL.TestUtils.DEMO_MODELS[1]
4+
@testset "$(nameof(model))" for model in TU.DEMO_MODELS
5+
model = TU.DEMO_MODELS[1]
66
= DynamicPPL.LogDensityFunction(model)
77
@test DynamicPPL.getmodel(ℓ) == model
88
@test DynamicPPL.setmodel(ℓ, model).model == model
@@ -21,10 +21,10 @@ using Test, DynamicPPL, ADTypes, LogDensityProblems, LogDensityProblemsAD, Rever
2121
end
2222

2323
@testset "LogDensityFunction" begin
24-
@testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS
25-
example_values = DynamicPPL.TestUtils.rand_prior_true(model)
26-
vns = DynamicPPL.TestUtils.varnames(model)
27-
varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns)
24+
@testset "$(nameof(model))" for model in TU.DEMO_MODELS
25+
example_values = TU.rand_prior_true(model)
26+
vns = TU.varnames(model)
27+
varinfos = TU.setup_varinfos(model, example_values, vns)
2828

2929
@testset "$(varinfo)" for varinfo in varinfos
3030
logdensity = DynamicPPL.LogDensityFunction(model, varinfo)

0 commit comments

Comments
 (0)