@@ -4,7 +4,14 @@ using ADTypes: AbstractADType, AutoForwardDiff
4
4
using Chairmarks: @be
5
5
import DifferentiationInterface as DI
6
6
using DocStringExtensions
7
- using DynamicPPL: Model, LogDensityFunction, VarInfo, AbstractVarInfo, link
7
+ using DynamicPPL:
8
+ Model,
9
+ LogDensityFunction,
10
+ VarInfo,
11
+ AbstractVarInfo,
12
+ link,
13
+ DefaultContext,
14
+ AbstractContext
8
15
using LogDensityProblems: logdensity, logdensity_and_gradient
9
16
using Random: Random, Xoshiro
10
17
using Statistics: median
@@ -53,6 +60,8 @@ struct ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat}
53
60
model:: Model
54
61
" The VarInfo that was used"
55
62
varinfo:: AbstractVarInfo
63
+ " The evaluation context that was used"
64
+ context:: AbstractContext
56
65
" The values at which the model was evaluated"
57
66
params:: Vector{Tparams}
58
67
" The AD backend that was tested"
83
92
grad_atol=1e-6,
84
93
varinfo::AbstractVarInfo=link(VarInfo(model), model),
85
94
params::Union{Nothing,Vector{<:AbstractFloat}}=nothing,
95
+ context::AbstractContext=DefaultContext(),
86
96
reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE,
87
97
expected_value_and_grad::Union{Nothing,Tuple{AbstractFloat,Vector{<:AbstractFloat}}}=nothing,
88
98
verbose=true,
@@ -136,7 +146,13 @@ Everything else is optional, and can be categorised into several groups:
136
146
prep_params)`. You could then evaluate the gradient at a different set of
137
147
parameters using the `params` keyword argument.
138
148
139
- 3. _How to specify the results to compare against._ (Only if `test=true`.)
149
+ 3. _How to specify the evaluation context._
150
+
151
+ A `DynamicPPL.AbstractContext` can be passed as the `context` keyword
152
+ argument to control the evaluation context. This defaults to
153
+ `DefaultContext()`.
154
+
155
+ 4. _How to specify the results to compare against._ (Only if `test=true`.)
140
156
141
157
Once logp and its gradient has been calculated with the specified `adtype`,
142
158
it must be tested for correctness.
@@ -151,12 +167,12 @@ Everything else is optional, and can be categorised into several groups:
151
167
The default reference backend is ForwardDiff. If none of these parameters are
152
168
specified, ForwardDiff will be used to calculate the ground truth.
153
169
154
- 4 . _How to specify the tolerances._ (Only if `test=true`.)
170
+ 5 . _How to specify the tolerances._ (Only if `test=true`.)
155
171
156
172
The tolerances for the value and gradient can be set using `value_atol` and
157
173
`grad_atol`. These default to 1e-6.
158
174
159
- 5 . _Whether to output extra logging information._
175
+ 6 . _Whether to output extra logging information._
160
176
161
177
By default, this function prints messages when it runs. To silence it, set
162
178
`verbose=false`.
@@ -179,6 +195,7 @@ function run_ad(
179
195
grad_atol:: AbstractFloat = 1e-6 ,
180
196
varinfo:: AbstractVarInfo = link (VarInfo (model), model),
181
197
params:: Union{Nothing,Vector{<:AbstractFloat}} = nothing ,
198
+ context:: AbstractContext = DefaultContext (),
182
199
reference_adtype:: AbstractADType = REFERENCE_ADTYPE,
183
200
expected_value_and_grad:: Union{Nothing,Tuple{AbstractFloat,Vector{<:AbstractFloat}}} = nothing ,
184
201
verbose= true ,
@@ -190,7 +207,7 @@ function run_ad(
190
207
191
208
verbose && @info " Running AD on $(model. f) with $(adtype) \n "
192
209
verbose && println (" params : $(params) " )
193
- ldf = LogDensityFunction (model, varinfo; adtype= adtype)
210
+ ldf = LogDensityFunction (model, varinfo, context ; adtype= adtype)
194
211
195
212
value, grad = logdensity_and_gradient (ldf, params)
196
213
grad = collect (grad)
@@ -199,7 +216,7 @@ function run_ad(
199
216
if test
200
217
# Calculate ground truth to compare against
201
218
value_true, grad_true = if expected_value_and_grad === nothing
202
- ldf_reference = LogDensityFunction (model, varinfo; adtype= reference_adtype)
219
+ ldf_reference = LogDensityFunction (model, varinfo, context ; adtype= reference_adtype)
203
220
logdensity_and_gradient (ldf_reference, params)
204
221
else
205
222
expected_value_and_grad
@@ -228,6 +245,7 @@ function run_ad(
228
245
return ADResult (
229
246
model,
230
247
varinfo,
248
+ context,
231
249
params,
232
250
adtype,
233
251
value_atol,
0 commit comments