# Announce the start of this test group and fix the RNG seed used below
# so the HMC tests are reproducible across runs.
@info "Starting HMC tests"
seed = 123
@testset "InferenceAlgorithm interface" begin
    # Verify that the Hamiltonian samplers implement the
    # Turing.Inference.InferenceAlgorithm interface correctly.
    algs = [HMC(0.1, 3), HMCDA(0.8, 0.75), NUTS(0.5), NUTS(0, 0.5)]

    @testset "get_adtype" begin
        # With no explicit adtype, every sampler falls back to the default.
        for alg in algs
            @test Turing.Inference.get_adtype(alg) == Turing.DEFAULT_ADTYPE
        end
        # An explicitly supplied adtype must be passed through unchanged.
        for adtype in (AutoReverseDiff(), AutoMooncake(; config=nothing))
            alg1 = HMC(0.1, 3; adtype=adtype)
            alg2 = HMCDA(0.8, 0.75; adtype=adtype)
            alg3 = NUTS(0.5; adtype=adtype)
            @test Turing.Inference.get_adtype(alg1) == adtype
            @test Turing.Inference.get_adtype(alg2) == adtype
            @test Turing.Inference.get_adtype(alg3) == adtype
        end
    end

    @testset "requires_unconstrained_space" begin
        # Hamiltonian samplers always require an unconstrained space.
        for alg in algs
            @test Turing.Inference.requires_unconstrained_space(alg)
        end
    end

    @testset "update_sample_kwargs" begin
        # Static Hamiltonian (no adaptation phase).
        static_alg = HMC(0.1, 3)
        # Adaptive Hamiltonian where the number of adaptation steps is
        # explicitly specified (here 200).
        adaptive_alg_explicit_nadapts = HMCDA(200, 0.8, 0.75)
        # Adaptive Hamiltonian where the number of adaptation steps is
        # left implicit.
        adaptive_alg_implicit_nadapts = NUTS(0.5)

        # Chain length used throughout.
        N = 1000

        # Convenience predicate: NamedTuple equality up to key ordering,
        # i.e. (a=1, b=2) should compare equal to (b=2, a=1).
        nt_eq(nt1, nt2) = Dict(pairs(nt1)) == Dict(pairs(nt2))

        # We don't test every single possibility of keyword arguments here,
        # just some typical cases that reflect common usage.

        # Case 1: no relevant kwargs. The adaptive algorithms need to add
        # in the number of adaptations and set discard_initial equal to
        # that. The static algorithm does not need to do anything.
        kwargs = (; _foo="bar")
        @test nt_eq(
            Turing.Inference.update_sample_kwargs(static_alg, N, kwargs), kwargs
        )
        @test nt_eq(
            Turing.Inference.update_sample_kwargs(
                adaptive_alg_explicit_nadapts, N, kwargs
            ),
            (nadapts=200, discard_initial=200, _foo="bar"),
        )
        @test nt_eq(
            Turing.Inference.update_sample_kwargs(
                adaptive_alg_implicit_nadapts, N, kwargs
            ),
            # By default the adaptive algorithm takes N / 2 adaptations, or
            # 1000, whichever is smaller. Since N = 1000 here, we expect the
            # number of adaptations to be 500.
            (nadapts=500, discard_initial=500, _foo="bar"),
        )

        # Case 2: resuming from an earlier chain. In this case no
        # adaptation is needed.
        chn = Chains([1.0], [:a])
        kwargs = (; resume_from=chn)
        kwargs_without_adapts = (
            nadapts=0, discard_initial=0, discard_adapt=false, resume_from=chn
        )
        @test nt_eq(
            Turing.Inference.update_sample_kwargs(static_alg, N, kwargs), kwargs
        )
        @test nt_eq(
            Turing.Inference.update_sample_kwargs(
                adaptive_alg_explicit_nadapts, N, kwargs
            ),
            kwargs_without_adapts,
        )
        @test nt_eq(
            Turing.Inference.update_sample_kwargs(
                adaptive_alg_implicit_nadapts, N, kwargs
            ),
            kwargs_without_adapts,
        )

        # Case 3: the user manually specifies the number of adaptations.
        kwargs = (; nadapts=500)
        kwargs_with_adapts = (nadapts=500, discard_initial=500)
        @test nt_eq(
            Turing.Inference.update_sample_kwargs(static_alg, N, kwargs), kwargs
        )
        @test nt_eq(
            Turing.Inference.update_sample_kwargs(
                adaptive_alg_explicit_nadapts, N, kwargs
            ),
            kwargs_with_adapts,
        )
        @test nt_eq(
            Turing.Inference.update_sample_kwargs(
                adaptive_alg_implicit_nadapts, N, kwargs
            ),
            kwargs_with_adapts,
        )

        # Case 4: the user wants to keep the adaptation samples.
        kwargs = (; discard_adapt=false)
        @test nt_eq(
            Turing.Inference.update_sample_kwargs(static_alg, N, kwargs), kwargs
        )
        @test nt_eq(
            Turing.Inference.update_sample_kwargs(
                adaptive_alg_explicit_nadapts, N, kwargs
            ),
            (nadapts=200, discard_initial=0, discard_adapt=false),
        )
        @test nt_eq(
            Turing.Inference.update_sample_kwargs(
                adaptive_alg_implicit_nadapts, N, kwargs
            ),
            (nadapts=500, discard_initial=0, discard_adapt=false),
        )
    end
end
156
@testset "sample() interface" begin
    # Minimal conjugate-style model with a single latent variable `a`.
    @model function demo_normal(x)
        a ~ Normal()
        return x ~ Normal(a)
    end
    model = demo_normal(2.0)
    # Note: passing an LDF to a Hamiltonian sampler requires an explicit adtype.
    ldf = LogDensityFunction(model; adtype=AutoForwardDiff())
    # Sampling should work both on the Model and on the LogDensityFunction.
    sampling_objects = Dict("DynamicPPL.Model" => model, "LogDensityFunction" => ldf)
    algs = [HMC(0.1, 3), HMCDA(0.8, 0.75), NUTS(0.5)]
    seed = 468
    @testset "sampling with $name" for (name, model_or_ldf) in sampling_objects
        @testset "$alg" for alg in algs
            # Sampling must work without an explicit rng.
            @test sample(model_or_ldf, alg, 5) isa Chains
            # Identical rng seeds must give reproducible results.
            chn1 = sample(Random.Xoshiro(seed), model_or_ldf, alg, 5)
            chn2 = sample(Random.Xoshiro(seed), model_or_ldf, alg, 5)
            @test mean(chn1[:a]) == mean(chn2[:a])
        end
    end
end
23
179
@testset " constrained bounded" begin
24
180
obs = [0 , 1 , 0 , 1 , 1 , 1 , 1 , 1 , 1 , 1 ]
25
181
0 commit comments