diff --git a/HISTORY.md b/HISTORY.md index d1f8c2ba5..c6a8c1bec 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,9 @@ # DynamicPPL Changelog +## 0.36.6 + +`DynamicPPL.TestUtils.run_ad` now takes an extra `context` keyword argument, which is passed to the `LogDensityFunction` constructor. + ## 0.36.5 `varinfo[:]` now returns an empty vector if `varinfo::DynamicPPL.NTVarInfo` is empty, rather than erroring. diff --git a/Project.toml b/Project.toml index b9253d2a5..e199637e6 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.36.5" +version = "0.36.6" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/test_utils/ad.jl b/src/test_utils/ad.jl index 8c926723a..0c267c1c5 100644 --- a/src/test_utils/ad.jl +++ b/src/test_utils/ad.jl @@ -4,7 +4,14 @@ using ADTypes: AbstractADType, AutoForwardDiff using Chairmarks: @be import DifferentiationInterface as DI using DocStringExtensions -using DynamicPPL: Model, LogDensityFunction, VarInfo, AbstractVarInfo, link +using DynamicPPL: + Model, + LogDensityFunction, + VarInfo, + AbstractVarInfo, + link, + DefaultContext, + AbstractContext using LogDensityProblems: logdensity, logdensity_and_gradient using Random: Random, Xoshiro using Statistics: median @@ -53,6 +60,8 @@ struct ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat} model::Model "The VarInfo that was used" varinfo::AbstractVarInfo + "The evaluation context that was used" + context::AbstractContext "The values at which the model was evaluated" params::Vector{Tparams} "The AD backend that was tested" @@ -83,6 +92,7 @@ end grad_atol=1e-6, varinfo::AbstractVarInfo=link(VarInfo(model), model), params::Union{Nothing,Vector{<:AbstractFloat}}=nothing, + context::AbstractContext=DefaultContext(), reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE, expected_value_and_grad::Union{Nothing,Tuple{AbstractFloat,Vector{<:AbstractFloat}}}=nothing, verbose=true, @@ -136,7 +146,13 @@ Everything else is optional, and can be categorised into several groups: prep_params)`. You could then evaluate the gradient at a different set of parameters using the `params` keyword argument. -3. _How to specify the results to compare against._ (Only if `test=true`.) +3. _How to specify the evaluation context._ + + A `DynamicPPL.AbstractContext` can be passed as the `context` keyword + argument to control the evaluation context. This defaults to + `DefaultContext()`. + +4. _How to specify the results to compare against._ (Only if `test=true`.) Once logp and its gradient has been calculated with the specified `adtype`, it must be tested for correctness. @@ -151,12 +167,12 @@ Everything else is optional, and can be categorised into several groups: The default reference backend is ForwardDiff. If none of these parameters are specified, ForwardDiff will be used to calculate the ground truth. -4. _How to specify the tolerances._ (Only if `test=true`.) +5. _How to specify the tolerances._ (Only if `test=true`.) The tolerances for the value and gradient can be set using `value_atol` and `grad_atol`. These default to 1e-6. -5. _Whether to output extra logging information._ +6. _Whether to output extra logging information._ By default, this function prints messages when it runs. To silence it, set `verbose=false`. @@ -179,6 +195,7 @@ function run_ad( grad_atol::AbstractFloat=1e-6, varinfo::AbstractVarInfo=link(VarInfo(model), model), params::Union{Nothing,Vector{<:AbstractFloat}}=nothing, + context::AbstractContext=DefaultContext(), reference_adtype::AbstractADType=REFERENCE_ADTYPE, expected_value_and_grad::Union{Nothing,Tuple{AbstractFloat,Vector{<:AbstractFloat}}}=nothing, verbose=true, @@ -190,7 +207,7 @@ function run_ad( verbose && @info "Running AD on $(model.f) with $(adtype)\n" verbose && println(" params : $(params)") - ldf = LogDensityFunction(model, varinfo; adtype=adtype) + ldf = LogDensityFunction(model, varinfo, context; adtype=adtype) value, grad = logdensity_and_gradient(ldf, params) grad = collect(grad) @@ -199,7 +216,7 @@ function run_ad( if test # Calculate ground truth to compare against value_true, grad_true = if expected_value_and_grad === nothing - ldf_reference = LogDensityFunction(model, varinfo; adtype=reference_adtype) + ldf_reference = LogDensityFunction(model, varinfo, context; adtype=reference_adtype) logdensity_and_gradient(ldf_reference, params) else expected_value_and_grad @@ -228,6 +245,7 @@ function run_ad( return ADResult( model, varinfo, + context, params, adtype, value_atol,