Skip to content

Commit 4606bdd

Browse files
committed
Add unit tests for HMC InferenceAlgorithm interface
1 parent 6923a72 commit 4606bdd

File tree

4 files changed

+253
-8
lines changed

4 files changed

+253
-8
lines changed

src/mcmc/hmc.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,11 @@ function update_sample_kwargs(alg::AdaptiveHamiltonian, N::Integer, kwargs)
             _discard_initial = discard_initial
         end
 
-        (nadapts=_nadapts, discard_initial=_discard_initial, kwargs...)
+        # Have to put kwargs first so that the later keyword arguments
+        # override anything that's already inside it.
+        (kwargs..., nadapts=_nadapts, discard_initial=_discard_initial)
     else
-        (nadapts=0, discard_adapt=false, discard_initial=0, kwargs...)
+        (kwargs..., nadapts=0, discard_adapt=false, discard_initial=0)
     end
 end

test/ext/dynamichmc.jl

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,37 @@ using StableRNGs: StableRNG
 using Turing
 
 @testset "TuringDynamicHMCExt" begin
-    @test DynamicPPL.alg_str(Sampler(externalsampler(DynamicHMC.NUTS()))) == "DynamicNUTS"
-
-    rng = StableRNG(468)
     spl = externalsampler(DynamicHMC.NUTS())
-    chn = sample(rng, gdemo_default, spl, 10_000)
-    check_gdemo(chn)
+
+    @testset "alg_str" begin
+        @test DynamicPPL.alg_str(Sampler(spl)) == "DynamicNUTS"
+    end
+
+    @testset "sample() interface" begin
+        @model function demo_normal(x)
+            a ~ Normal()
+            return x ~ Normal(a)
+        end
+        model = demo_normal(2.0)
+        # note: passing LDF to a Hamiltonian sampler requires explicit adtype
+        ldf = LogDensityFunction(model; adtype=AutoForwardDiff())
+        sampling_objects = Dict("DynamicPPL.Model" => model, "LogDensityFunction" => ldf)
+        seed = 468
+        @testset "sampling with $name" for (name, model_or_ldf) in sampling_objects
+            # check sampling works without rng
+            @test sample(model_or_ldf, spl, 5) isa Chains
+            # check reproducibility with rng
+            chn1 = sample(Random.Xoshiro(seed), model_or_ldf, spl, 5)
+            chn2 = sample(Random.Xoshiro(seed), model_or_ldf, spl, 5)
+            @test mean(chn1[:a]) == mean(chn2[:a])
+        end
+    end
+
+    @testset "numerical accuracy" begin
+        rng = StableRNG(468)
+        chn = sample(rng, gdemo_default, spl, 10_000)
+        check_gdemo(chn)
+    end
 end
 
 end

test/mcmc/hmc.jl

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,162 @@ using Turing
 @info "Starting HMC tests"
 seed = 123
 
+@testset "InferenceAlgorithm interface" begin
+    # Check that the various Hamiltonian samplers implement the
+    # Turing.Inference.InferenceAlgorithm interface correctly.
+    algs = [HMC(0.1, 3), HMCDA(0.8, 0.75), NUTS(0.5), NUTS(0, 0.5)]
+
+    @testset "get_adtype" begin
+        # Default
+        for alg in algs
+            @test Turing.Inference.get_adtype(alg) == Turing.DEFAULT_ADTYPE
+        end
+        # Manual
+        for adtype in (AutoReverseDiff(), AutoMooncake(; config=nothing))
+            alg1 = HMC(0.1, 3; adtype=adtype)
+            alg2 = HMCDA(0.8, 0.75; adtype=adtype)
+            alg3 = NUTS(0.5; adtype=adtype)
+            @test Turing.Inference.get_adtype(alg1) == adtype
+            @test Turing.Inference.get_adtype(alg2) == adtype
+            @test Turing.Inference.get_adtype(alg3) == adtype
+        end
+    end
+
+    @testset "requires_unconstrained_space" begin
+        # Hamiltonian samplers always need it
+        for alg in algs
+            @test Turing.Inference.requires_unconstrained_space(alg)
+        end
+    end
+
+    @testset "update_sample_kwargs" begin
+        # Static Hamiltonian
+        static_alg = HMC(0.1, 3)
+        # Adaptive Hamiltonian, where the number of adaptations is
+        # explicitly specified (here 200)
+        adaptive_alg_explicit_nadapts = HMCDA(200, 0.8, 0.75)
+        # Adaptive Hamiltonian, where the number of adaptations is
+        # implicit
+        adaptive_alg_implicit_nadapts = NUTS(0.5)
+
+        # chain length
+        N = 1000
+
+        # convenience function to check NamedTuple equality up to ordering, i.e.
+        # we want (a=1, b=2) to be equal to (b=2, a=1)
+        nt_eq(nt1, nt2) = Dict(pairs(nt1)) == Dict(pairs(nt2))
+
+        # We don't test every single possibility of keyword arguments here,
+        # just some typical cases that reflect common usage.
+
+        # Case 1: no relevant kwargs. The adaptive algorithms need to add
+        # in the number of adaptations and set discard_initial equal to
+        # that. The static algorithm does not need to do anything.
+        kwargs = (; _foo="bar")
+        @test nt_eq(
+            Turing.Inference.update_sample_kwargs(static_alg, N, kwargs), kwargs
+        )
+        @test nt_eq(
+            Turing.Inference.update_sample_kwargs(
+                adaptive_alg_explicit_nadapts, N, kwargs
+            ),
+            (nadapts=200, discard_initial=200, _foo="bar"),
+        )
+        @test nt_eq(
+            Turing.Inference.update_sample_kwargs(
+                adaptive_alg_implicit_nadapts, N, kwargs
+            ),
+            # by default the adaptive algorithm takes N / 2 adaptations, or
+            # 1000, whichever is smaller. In this case since N = 1000, we
+            # expect the number of adaptations to be 500.
+            (nadapts=500, discard_initial=500, _foo="bar"),
+        )
+
+        # Case 2: When resuming from an earlier chain. In this case, no
+        # adaptation is needed.
+        chn = Chains([1.0], [:a])
+        kwargs = (; resume_from=chn)
+        kwargs_without_adapts = (
+            nadapts=0, discard_initial=0, discard_adapt=false, resume_from=chn
+        )
+        @test nt_eq(
+            Turing.Inference.update_sample_kwargs(static_alg, N, kwargs), kwargs
+        )
+        @test nt_eq(
+            Turing.Inference.update_sample_kwargs(
+                adaptive_alg_explicit_nadapts, N, kwargs
+            ),
+            kwargs_without_adapts,
+        )
+        @test nt_eq(
+            Turing.Inference.update_sample_kwargs(
+                adaptive_alg_implicit_nadapts, N, kwargs
+            ),
+            kwargs_without_adapts,
+        )
+
+        # Case 3: user manually specifies number of adaptations.
+        kwargs = (; nadapts=500)
+        kwargs_with_adapts = (nadapts=500, discard_initial=500)
+        @test nt_eq(
+            Turing.Inference.update_sample_kwargs(static_alg, N, kwargs), kwargs
+        )
+        @test nt_eq(
+            Turing.Inference.update_sample_kwargs(
+                adaptive_alg_explicit_nadapts, N, kwargs
+            ),
+            kwargs_with_adapts,
+        )
+        @test nt_eq(
+            Turing.Inference.update_sample_kwargs(
+                adaptive_alg_implicit_nadapts, N, kwargs
+            ),
+            kwargs_with_adapts,
+        )
+
+        # Case 4: user wants to keep the adaptations
+        kwargs = (; discard_adapt=false)
+        @test nt_eq(
+            Turing.Inference.update_sample_kwargs(static_alg, N, kwargs), kwargs
+        )
+        @test nt_eq(
+            Turing.Inference.update_sample_kwargs(
+                adaptive_alg_explicit_nadapts, N, kwargs
+            ),
+            (nadapts=200, discard_initial=0, discard_adapt=false),
+        )
+        @test nt_eq(
+            Turing.Inference.update_sample_kwargs(
+                adaptive_alg_implicit_nadapts, N, kwargs
+            ),
+            (nadapts=500, discard_initial=0, discard_adapt=false),
+        )
+    end
+end
+
+@testset "sample() interface" begin
+    @model function demo_normal(x)
+        a ~ Normal()
+        return x ~ Normal(a)
+    end
+    model = demo_normal(2.0)
+    # note: passing LDF to a Hamiltonian sampler requires explicit adtype
+    ldf = LogDensityFunction(model; adtype=AutoForwardDiff())
+    sampling_objects = Dict("DynamicPPL.Model" => model, "LogDensityFunction" => ldf)
+    algs = [HMC(0.1, 3), HMCDA(0.8, 0.75), NUTS(0.5)]
+    seed = 468
+    @testset "sampling with $name" for (name, model_or_ldf) in sampling_objects
+        @testset "$alg" for alg in algs
+            # check sampling works without rng
+            @test sample(model_or_ldf, alg, 5) isa Chains
+            # check reproducibility with rng
+            chn1 = sample(Random.Xoshiro(seed), model_or_ldf, alg, 5)
+            chn2 = sample(Random.Xoshiro(seed), model_or_ldf, alg, 5)
+            @test mean(chn1[:a]) == mean(chn2[:a])
+        end
+    end
+end
+
 @testset "constrained bounded" begin
     obs = [0, 1, 0, 1, 1, 1, 1, 1, 1, 1]
 
test/mcmc/sghmc.jl

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,73 @@ using DynamicPPL: DynamicPPL
 using Distributions: sample
 import ForwardDiff
 using LinearAlgebra: dot
-import ReverseDiff
+using Random: Xoshiro
 using StableRNGs: StableRNG
 using Test: @test, @testset
 using Turing
 
+@testset "SGHMC + SGLD: InferenceAlgorithm interface" begin
+    algs = [
+        SGHMC(; learning_rate=0.01, momentum_decay=0.1),
+        SGLD(; stepsize=PolynomialStepsize(0.25)),
+    ]
+
+    @testset "get_adtype" begin
+        # Default
+        for alg in algs
+            @test Turing.Inference.get_adtype(alg) == Turing.DEFAULT_ADTYPE
+        end
+        # Manual
+        for adtype in (AutoReverseDiff(), AutoMooncake(; config=nothing))
+            alg1 = SGHMC(; learning_rate=0.01, momentum_decay=0.1, adtype=adtype)
+            alg2 = SGLD(; stepsize=PolynomialStepsize(0.25), adtype=adtype)
+            @test Turing.Inference.get_adtype(alg1) == adtype
+            @test Turing.Inference.get_adtype(alg2) == adtype
+        end
+    end
+
+    @testset "requires_unconstrained_space" begin
+        # Hamiltonian samplers always need it
+        for alg in algs
+            @test Turing.Inference.requires_unconstrained_space(alg)
+        end
+    end
+
+    @testset "update_sample_kwargs" begin
+        # These don't update kwargs
+        for alg in algs
+            kwargs = (a=1, b=2)
+            @test Turing.Inference.update_sample_kwargs(alg, 1000, kwargs) == kwargs
+        end
+    end
+end
+
+@testset verbose = true "SGHMC + SGLD: sample() interface" begin
+    @model function demo_normal(x)
+        a ~ Normal()
+        return x ~ Normal(a)
+    end
+    model = demo_normal(2.0)
+    # note: passing LDF to a Hamiltonian sampler requires explicit adtype
+    ldf = LogDensityFunction(model; adtype=AutoForwardDiff())
+    sampling_objects = Dict("DynamicPPL.Model" => model, "LogDensityFunction" => ldf)
+    algs = [
+        SGHMC(; learning_rate=0.01, momentum_decay=0.1),
+        SGLD(; stepsize=PolynomialStepsize(0.25)),
+    ]
+    seed = 468
+    @testset "sampling with $name" for (name, model_or_ldf) in sampling_objects
+        @testset "$alg" for alg in algs
+            # check sampling works without rng
+            @test sample(model_or_ldf, alg, 5) isa Chains
+            # check reproducibility with rng
+            chn1 = sample(Xoshiro(seed), model_or_ldf, alg, 5)
+            chn2 = sample(Xoshiro(seed), model_or_ldf, alg, 5)
+            @test mean(chn1[:a]) == mean(chn2[:a])
+        end
+    end
+end
+
 @testset verbose = true "Testing sghmc.jl" begin
     @testset "sghmc constructor" begin
         alg = SGHMC(; learning_rate=0.01, momentum_decay=0.1)

0 commit comments

Comments
 (0)