Skip to content

Commit 2fd1897

Browse files
authored
Allow specifying context in AD testing (#935)
* Allow specifying context in AD testing * Fix ADResult constructor
1 parent 0b6e364 commit 2fd1897

File tree

3 files changed

+29
-7
lines changed

3 files changed

+29
-7
lines changed

HISTORY.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# DynamicPPL Changelog
22

3+
## 0.36.6
4+
5+
`DynamicPPL.TestUtils.run_ad` now takes an extra `context` keyword argument, which is passed to the `LogDensityFunction` constructor.
6+
37
## 0.36.5
48

59
`varinfo[:]` now returns an empty vector if `varinfo::DynamicPPL.NTVarInfo` is empty, rather than erroring.

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.36.5"
3+
version = "0.36.6"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/test_utils/ad.jl

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,14 @@ using ADTypes: AbstractADType, AutoForwardDiff
44
using Chairmarks: @be
55
import DifferentiationInterface as DI
66
using DocStringExtensions
7-
using DynamicPPL: Model, LogDensityFunction, VarInfo, AbstractVarInfo, link
7+
using DynamicPPL:
8+
Model,
9+
LogDensityFunction,
10+
VarInfo,
11+
AbstractVarInfo,
12+
link,
13+
DefaultContext,
14+
AbstractContext
815
using LogDensityProblems: logdensity, logdensity_and_gradient
916
using Random: Random, Xoshiro
1017
using Statistics: median
@@ -53,6 +60,8 @@ struct ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat}
5360
model::Model
5461
"The VarInfo that was used"
5562
varinfo::AbstractVarInfo
63+
"The evaluation context that was used"
64+
context::AbstractContext
5665
"The values at which the model was evaluated"
5766
params::Vector{Tparams}
5867
"The AD backend that was tested"
@@ -83,6 +92,7 @@ end
8392
grad_atol=1e-6,
8493
varinfo::AbstractVarInfo=link(VarInfo(model), model),
8594
params::Union{Nothing,Vector{<:AbstractFloat}}=nothing,
95+
context::AbstractContext=DefaultContext(),
8696
reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE,
8797
expected_value_and_grad::Union{Nothing,Tuple{AbstractFloat,Vector{<:AbstractFloat}}}=nothing,
8898
verbose=true,
@@ -136,7 +146,13 @@ Everything else is optional, and can be categorised into several groups:
136146
prep_params)`. You could then evaluate the gradient at a different set of
137147
parameters using the `params` keyword argument.
138148
139-
3. _How to specify the results to compare against._ (Only if `test=true`.)
149+
3. _How to specify the evaluation context._
150+
151+
A `DynamicPPL.AbstractContext` can be passed as the `context` keyword
152+
argument to control the evaluation context. This defaults to
153+
`DefaultContext()`.
154+
155+
4. _How to specify the results to compare against._ (Only if `test=true`.)
140156
141157
Once logp and its gradient has been calculated with the specified `adtype`,
142158
it must be tested for correctness.
@@ -151,12 +167,12 @@ Everything else is optional, and can be categorised into several groups:
151167
The default reference backend is ForwardDiff. If none of these parameters are
152168
specified, ForwardDiff will be used to calculate the ground truth.
153169
154-
4. _How to specify the tolerances._ (Only if `test=true`.)
170+
5. _How to specify the tolerances._ (Only if `test=true`.)
155171
156172
The tolerances for the value and gradient can be set using `value_atol` and
157173
`grad_atol`. These default to 1e-6.
158174
159-
5. _Whether to output extra logging information._
175+
6. _Whether to output extra logging information._
160176
161177
By default, this function prints messages when it runs. To silence it, set
162178
`verbose=false`.
@@ -179,6 +195,7 @@ function run_ad(
179195
grad_atol::AbstractFloat=1e-6,
180196
varinfo::AbstractVarInfo=link(VarInfo(model), model),
181197
params::Union{Nothing,Vector{<:AbstractFloat}}=nothing,
198+
context::AbstractContext=DefaultContext(),
182199
reference_adtype::AbstractADType=REFERENCE_ADTYPE,
183200
expected_value_and_grad::Union{Nothing,Tuple{AbstractFloat,Vector{<:AbstractFloat}}}=nothing,
184201
verbose=true,
@@ -190,7 +207,7 @@ function run_ad(
190207

191208
verbose && @info "Running AD on $(model.f) with $(adtype)\n"
192209
verbose && println(" params : $(params)")
193-
ldf = LogDensityFunction(model, varinfo; adtype=adtype)
210+
ldf = LogDensityFunction(model, varinfo, context; adtype=adtype)
194211

195212
value, grad = logdensity_and_gradient(ldf, params)
196213
grad = collect(grad)
@@ -199,7 +216,7 @@ function run_ad(
199216
if test
200217
# Calculate ground truth to compare against
201218
value_true, grad_true = if expected_value_and_grad === nothing
202-
ldf_reference = LogDensityFunction(model, varinfo; adtype=reference_adtype)
219+
ldf_reference = LogDensityFunction(model, varinfo, context; adtype=reference_adtype)
203220
logdensity_and_gradient(ldf_reference, params)
204221
else
205222
expected_value_and_grad
@@ -228,6 +245,7 @@ function run_ad(
228245
return ADResult(
229246
model,
230247
varinfo,
248+
context,
231249
params,
232250
adtype,
233251
value_atol,

0 commit comments

Comments
 (0)