Skip to content

Commit ba490bf

Browse files
committed
Revert "Move src/test_utils and test/test_util to DynamicPPLTestExt"
This reverts commit dcd24e7.
1 parent dcd24e7 commit ba490bf

21 files changed

+273
-287
lines changed

Project.toml

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
3131
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
3232
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
3333
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
34-
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3534
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
3635

3736
[extensions]
@@ -40,7 +39,6 @@ DynamicPPLEnzymeCoreExt = ["EnzymeCore"]
4039
DynamicPPLForwardDiffExt = ["ForwardDiff"]
4140
DynamicPPLMCMCChainsExt = ["MCMCChains"]
4241
DynamicPPLReverseDiffExt = ["ReverseDiff"]
43-
DynamicPPLTestExt = ["Test"]
4442
DynamicPPLZygoteRulesExt = ["ZygoteRules"]
4543

4644
[compat]
@@ -69,3 +67,11 @@ ReverseDiff = "1"
6967
Test = "1.6"
7068
ZygoteRules = "0.2"
7169
julia = "1.10"
70+
71+
[extras]
72+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
73+
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
74+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
75+
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
76+
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
77+
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

ext/DynamicPPLTestExt.jl

Lines changed: 0 additions & 8 deletions
This file was deleted.

ext/DynamicPPLTestExt/utils.jl renamed to src/test_utils.jl

Lines changed: 1 addition & 124 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,4 @@
1-
module TestExtUtils
2-
3-
###################################################
4-
# These used to be in DPPL/src/test_utils.jl ######
5-
###################################################
1+
module TestUtils
62

73
using AbstractMCMC
84
using DynamicPPL
@@ -1101,123 +1097,4 @@ function DynamicPPL.dot_tilde_observe(
11011097
return logp * context.mod, vi
11021098
end
11031099

1104-
1105-
1106-
###################################################
1107-
# These used to be in DPPL/test/test_util.jl ######
1108-
###################################################
1109-
1110-
# default model
1111-
@model function gdemo_d()
1112-
s ~ InverseGamma(2, 3)
1113-
m ~ Normal(0, sqrt(s))
1114-
1.5 ~ Normal(m, sqrt(s))
1115-
2.0 ~ Normal(m, sqrt(s))
1116-
return s, m
1117-
end
1118-
const gdemo_default = gdemo_d()
1119-
1120-
function test_model_ad(model, logp_manual)
1121-
vi = VarInfo(model)
1122-
x = DynamicPPL.getall(vi)
1123-
1124-
# Log probabilities using the model.
1125-
= DynamicPPL.LogDensityFunction(model, vi)
1126-
logp_model = Base.Fix1(LogDensityProblems.logdensity, ℓ)
1127-
1128-
# Check that both functions return the same values.
1129-
lp = logp_manual(x)
1130-
@test logp_model(x) lp
1131-
1132-
# Gradients based on the manual implementation.
1133-
grad = ForwardDiff.gradient(logp_manual, x)
1134-
1135-
y, back = Tracker.forward(logp_manual, x)
1136-
@test Tracker.data(y) lp
1137-
@test Tracker.data(back(1)[1]) grad
1138-
1139-
y, back = Zygote.pullback(logp_manual, x)
1140-
@test y lp
1141-
@test back(1)[1] grad
1142-
1143-
# Gradients based on the model.
1144-
@test ForwardDiff.gradient(logp_model, x) grad
1145-
1146-
y, back = Tracker.forward(logp_model, x)
1147-
@test Tracker.data(y) lp
1148-
@test Tracker.data(back(1)[1]) grad
1149-
1150-
y, back = Zygote.pullback(logp_model, x)
1151-
@test y lp
1152-
@test back(1)[1] grad
1153-
end
1154-
1155-
"""
1156-
test_setval!(model, chain; sample_idx = 1, chain_idx = 1)
1157-
1158-
Test `setval!` on `model` and `chain`.
1159-
1160-
Worth noting that this only supports models containing symbols of the forms
1161-
`m`, `m[1]`, `m[1, 2]`, not `m[1][1]`, etc.
1162-
"""
1163-
function test_setval!(model, chain; sample_idx=1, chain_idx=1)
1164-
var_info = VarInfo(model)
1165-
spl = SampleFromPrior()
1166-
θ_old = var_info[spl]
1167-
DynamicPPL.setval!(var_info, chain, sample_idx, chain_idx)
1168-
θ_new = var_info[spl]
1169-
@test θ_old != θ_new
1170-
vals = DynamicPPL.values_as(var_info, OrderedDict)
1171-
iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
1172-
for (n, v) in mapreduce(collect, vcat, iters)
1173-
n = string(n)
1174-
if Symbol(n) keys(chain)
1175-
# Assume it's a group
1176-
chain_val = vec(
1177-
MCMCChains.group(chain, Symbol(n)).value[sample_idx, :, chain_idx]
1178-
)
1179-
v_true = vec(v)
1180-
else
1181-
chain_val = chain[sample_idx, n, chain_idx]
1182-
v_true = v
1183-
end
1184-
1185-
@test v_true == chain_val
1186-
end
1187-
end
1188-
1189-
"""
1190-
short_varinfo_name(vi::AbstractVarInfo)
1191-
1192-
Return string representing a short description of `vi`.
1193-
"""
1194-
short_varinfo_name(vi::DynamicPPL.ThreadSafeVarInfo) =
1195-
"threadsafe($(short_varinfo_name(vi.varinfo)))"
1196-
function short_varinfo_name(vi::TypedVarInfo)
1197-
DynamicPPL.has_varnamedvector(vi) && return "TypedVarInfo with VarNamedVector"
1198-
return "TypedVarInfo"
1199-
end
1200-
short_varinfo_name(::UntypedVarInfo) = "UntypedVarInfo"
1201-
short_varinfo_name(::DynamicPPL.VectorVarInfo) = "VectorVarInfo"
1202-
short_varinfo_name(::SimpleVarInfo{<:NamedTuple}) = "SimpleVarInfo{<:NamedTuple}"
1203-
short_varinfo_name(::SimpleVarInfo{<:OrderedDict}) = "SimpleVarInfo{<:OrderedDict}"
1204-
function short_varinfo_name(::SimpleVarInfo{<:DynamicPPL.VarNamedVector})
1205-
return "SimpleVarInfo{<:VarNamedVector}"
1206-
end
1207-
1208-
# convenient functions for testing model.jl
1209-
# function to modify the representation of values based on their length
1210-
function modify_value_representation(nt::NamedTuple)
1211-
modified_nt = NamedTuple()
1212-
for (key, value) in zip(keys(nt), values(nt))
1213-
if length(value) == 1 # Scalar value
1214-
modified_value = value[1]
1215-
else # Non-scalar value
1216-
modified_value = value
1217-
end
1218-
modified_nt = merge(modified_nt, (key => modified_value,))
1219-
end
1220-
return modified_nt
12211100
end
1222-
1223-
end # module TestExtUtils

test/ad.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
@testset "AD: ForwardDiff and ReverseDiff" begin
2-
@testset "$(m.f)" for m in TU.DEMO_MODELS
2+
@testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS
33
f = DynamicPPL.LogDensityFunction(m)
4-
rand_param_values = TU.rand_prior_true(m)
5-
vns = TU.varnames(m)
6-
varinfos = TU.setup_varinfos(m, rand_param_values, vns)
4+
rand_param_values = DynamicPPL.TestUtils.rand_prior_true(m)
5+
vns = DynamicPPL.TestUtils.varnames(m)
6+
varinfos = DynamicPPL.TestUtils.setup_varinfos(m, rand_param_values, vns)
77

8-
@testset "$(TU.short_varinfo_name(varinfo))" for varinfo in varinfos
8+
@testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
99
f = DynamicPPL.LogDensityFunction(m, varinfo)
1010

1111
# use ForwardDiff result as reference

test/compat/ad.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
logpdf(dist, 2.0)
1313
end
1414

15-
TU.test_model_ad(TU.gdemo_default, logp_gdemo_default)
15+
test_model_ad(gdemo_default, logp_gdemo_default)
1616

1717
@model function wishart_ad()
1818
return v ~ Wishart(7, [1 0.5; 0.5 1])
@@ -24,7 +24,7 @@
2424
return logpdf(dist, reshape(x, 2, 2))
2525
end
2626

27-
TU.test_model_ad(wishart_ad(), logp_wishart_ad)
27+
test_model_ad(wishart_ad(), logp_wishart_ad)
2828
end
2929

3030
# https://github.com/TuringLang/Turing.jl/issues/1595

test/contexts.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ end
168168

169169
# Let's check elementwise.
170170
for vn_child in
171-
TU.varname_leaves(vn_without_prefix, val)
171+
DynamicPPL.TestUtils.varname_leaves(vn_without_prefix, val)
172172
if getoptic(vn_child)(val) === missing
173173
@test contextual_isassumption(context, vn_child)
174174
else
@@ -201,7 +201,7 @@ end
201201
vn_without_prefix = remove_prefix(vn)
202202

203203
for vn_child in
204-
TU.varname_leaves(vn_without_prefix, val)
204+
DynamicPPL.TestUtils.varname_leaves(vn_without_prefix, val)
205205
# `vn_child` should be in `context`.
206206
@test hasconditioned_nested(context, vn_child)
207207
# Value should be the same as extracted above.
@@ -216,7 +216,7 @@ end
216216
@testset "Evaluation" begin
217217
@testset "$context" for context in contexts
218218
# Just making sure that we can actually sample with each of the contexts.
219-
@test (TU.gdemo_default(SamplingContext(context)); true)
219+
@test (gdemo_default(SamplingContext(context)); true)
220220
end
221221
end
222222

@@ -258,7 +258,7 @@ end
258258
end
259259

260260
@testset "FixedContext" begin
261-
@testset "$(model.f)" for model in TU.DEMO_MODELS
261+
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
262262
retval = model()
263263
s, m = retval.s, retval.m
264264

test/debug_utils.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
11
@testset "check_model" begin
22
@testset "context interface" begin
33
# HACK: Require a model to instantiate it, so let's just grab one.
4-
model = first(TU.DEMO_MODELS)
4+
model = first(DynamicPPL.TestUtils.DEMO_MODELS)
55
context = DynamicPPL.DebugUtils.DebugContext(model)
6-
TU.test_context_interface(context)
6+
DynamicPPL.TestUtils.test_context_interface(context)
77
end
88

9-
@testset "$(model.f)" for model in TU.DEMO_MODELS
9+
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
1010
issuccess, trace = check_model_and_trace(model)
1111
# These models should all work.
1212
@test issuccess
1313

1414
# Check that the trace contains all the variables in the model.
1515
varnames_in_trace = DynamicPPL.DebugUtils.varnames_in_trace(trace)
16-
for vn in TU.varnames(model)
16+
for vn in DynamicPPL.TestUtils.varnames(model)
1717
@test vn in varnames_in_trace
1818
end
1919

@@ -156,7 +156,7 @@
156156
end
157157

158158
@testset "comparing multiple traces" begin
159-
model = TU.demo_dynamic_constraint()
159+
model = DynamicPPL.TestUtils.demo_dynamic_constraint()
160160
issuccess_1, trace_1 = check_model_and_trace(model)
161161
issuccess_2, trace_2 = check_model_and_trace(model)
162162
@test issuccess_1 && issuccess_2

test/linking.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,8 @@ end
7575
model = demo()
7676

7777
example_values = rand(NamedTuple, model)
78-
vis = TU.setup_varinfos(model, example_values, (@varname(m),))
79-
@testset "$(TU.short_varinfo_name(vi))" for vi in vis
78+
vis = DynamicPPL.TestUtils.setup_varinfos(model, example_values, (@varname(m),))
79+
@testset "$(short_varinfo_name(vi))" for vi in vis
8080
# Evaluate once to ensure we have `logp` value.
8181
vi = last(DynamicPPL.evaluate!!(model, vi, DefaultContext()))
8282
vi_linked = if mutable
@@ -109,10 +109,10 @@ end
109109
model = demo_lkj(d)
110110
dist = LKJCholesky(d, 1.0, uplo)
111111
values_original = rand(NamedTuple, model)
112-
vis = TU.setup_varinfos(
112+
vis = DynamicPPL.TestUtils.setup_varinfos(
113113
model, values_original, (@varname(x),)
114114
)
115-
@testset "$(TU.short_varinfo_name(vi))" for vi in vis
115+
@testset "$(short_varinfo_name(vi))" for vi in vis
116116
val = vi[@varname(x), dist]
117117
# Ensure that `reconstruct` works as intended.
118118
@test val isa Cholesky
@@ -150,8 +150,8 @@ end
150150
@testset "d=$d" for d in [2, 3, 5]
151151
model = demo_dirichlet(d)
152152
example_values = rand(NamedTuple, model)
153-
vis = TU.setup_varinfos(model, example_values, (@varname(x),))
154-
@testset "$(TU.short_varinfo_name(vi))" for vi in vis
153+
vis = DynamicPPL.TestUtils.setup_varinfos(model, example_values, (@varname(x),))
154+
@testset "$(short_varinfo_name(vi))" for vi in vis
155155
lp = logpdf(Dirichlet(d, 1.0), vi[:])
156156
@test length(vi[:]) == d
157157
lp_model = logjoint(model, vi)
@@ -189,8 +189,8 @@ end
189189
]
190190
model = demo_highdim_dirichlet(ns...)
191191
example_values = rand(NamedTuple, model)
192-
vis = TU.setup_varinfos(model, example_values, (@varname(x),))
193-
@testset "$(TU.short_varinfo_name(vi))" for vi in vis
192+
vis = DynamicPPL.TestUtils.setup_varinfos(model, example_values, (@varname(x),))
193+
@testset "$(short_varinfo_name(vi))" for vi in vis
194194
# Linked.
195195
vi_linked = if mutable
196196
DynamicPPL.link!!(deepcopy(vi), model)

test/logdensityfunction.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
using Test, DynamicPPL, ADTypes, LogDensityProblems, LogDensityProblemsAD, ReverseDiff
22

33
@testset "`getmodel` and `setmodel`" begin
4-
@testset "$(nameof(model))" for model in TU.DEMO_MODELS
5-
model = TU.DEMO_MODELS[1]
4+
@testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS
5+
model = DynamicPPL.TestUtils.DEMO_MODELS[1]
66
= DynamicPPL.LogDensityFunction(model)
77
@test DynamicPPL.getmodel(ℓ) == model
88
@test DynamicPPL.setmodel(ℓ, model).model == model
@@ -21,10 +21,10 @@ using Test, DynamicPPL, ADTypes, LogDensityProblems, LogDensityProblemsAD, Rever
2121
end
2222

2323
@testset "LogDensityFunction" begin
24-
@testset "$(nameof(model))" for model in TU.DEMO_MODELS
25-
example_values = TU.rand_prior_true(model)
26-
vns = TU.varnames(model)
27-
varinfos = TU.setup_varinfos(model, example_values, vns)
24+
@testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS
25+
example_values = DynamicPPL.TestUtils.rand_prior_true(model)
26+
vns = DynamicPPL.TestUtils.varnames(model)
27+
varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns)
2828

2929
@testset "$(varinfo)" for varinfo in varinfos
3030
logdensity = DynamicPPL.LogDensityFunction(model, varinfo)

0 commit comments

Comments
 (0)