diff --git a/HISTORY.md b/HISTORY.md
index d559e6373..96c4465ba 100644
--- a/HISTORY.md
+++ b/HISTORY.md
@@ -8,6 +8,17 @@
 
 The `@submodel` macro is fully removed; please use `to_submodel` instead.
 
+### `DynamicPPL.TestUtils.AD.run_ad`
+
+The three keyword arguments `test`, `reference_adtype`, and `expected_value_and_grad` have been merged into a single `test` keyword argument.
+Please see the API documentation for more details.
+(The old `test=true` and `test=false` values are still valid; you only need to adjust the invocation if you were explicitly passing the `reference_adtype` or `expected_value_and_grad` arguments.)
+
+There is now also an `rng` keyword argument to help seed parameter generation.
+
+Finally, instead of specifying `value_atol` and `grad_atol`, you can now specify `atol` and `rtol`, which are used for both the value and the gradient.
+Their semantics are the same as in Julia's `isapprox`: two values are considered equal if they satisfy either the `atol` or the `rtol` check.
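+
+For example, a call that previously passed a reference backend explicitly (a hypothetical invocation; `model` and `adtype` are placeholders) would now be written as:
+
+```julia
+# before
+run_ad(model, adtype; reference_adtype=AutoReverseDiff())
+# after
+run_ad(model, adtype; test=WithBackend(AutoReverseDiff()))
+```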
""" -const REFERENCE_ADTYPE = AutoForwardDiff() +abstract type AbstractADCorrectnessTestSetting end + +""" + WithBackend(adtype::AbstractADType=AutoForwardDiff()) <: AbstractADCorrectnessTestSetting + +Test correctness by comparing it against the result obtained with `adtype`. + +`adtype` defaults to ForwardDiff.jl, since it's the default AD backend used in +Turing.jl. +""" +struct WithBackend{AD<:AbstractADType} <: AbstractADCorrectnessTestSetting + adtype::AD +end +WithBackend() = WithBackend(AutoForwardDiff()) + +""" + WithExpectedResult( + value::T, + grad::AbstractVector{T} + ) where {T <: AbstractFloat} + <: AbstractADCorrectnessTestSetting + +Test correctness by comparing it against a known result (e.g. one obtained +analytically, or one obtained with a different backend previously). Both the +value of the primal (i.e. the log-density) as well as its gradient must be +supplied. +""" +struct WithExpectedResult{T<:AbstractFloat} <: AbstractADCorrectnessTestSetting + value::T + grad::AbstractVector{T} +end + +""" + NoTest() <: AbstractADCorrectnessTestSetting + +Disable correctness testing. +""" +struct NoTest <: AbstractADCorrectnessTestSetting end """ ADIncorrectException{T<:AbstractFloat} @@ -45,17 +74,18 @@ struct ADIncorrectException{T<:AbstractFloat} <: Exception end """ - ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat} + ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat,Ttol<:AbstractFloat} Data structure to store the results of the AD correctness test. The type parameter `Tparams` is the numeric type of the parameters passed in; -`Tresult` is the type of the value and the gradient. +`Tresult` is the type of the value and the gradient; and `Ttol` is the type of the +absolute and relative tolerances used for correctness testing. 
+
+These types are returned or thrown by the `run_ad` function:
+
+```@docs
 DynamicPPL.TestUtils.AD.ADResult
 DynamicPPL.TestUtils.AD.ADIncorrectException
 ```
diff --git a/src/test_utils/ad.jl b/src/test_utils/ad.jl
index 5285391b1..155f3b68d 100644
--- a/src/test_utils/ad.jl
+++ b/src/test_utils/ad.jl
@@ -4,28 +4,57 @@
 using ADTypes: AbstractADType, AutoForwardDiff
 using Chairmarks: @be
 import DifferentiationInterface as DI
 using DocStringExtensions
-using DynamicPPL:
-    Model,
-    LogDensityFunction,
-    VarInfo,
-    AbstractVarInfo,
-    link,
-    DefaultContext,
-    AbstractContext
+using DynamicPPL: Model, LogDensityFunction, VarInfo, AbstractVarInfo, link
 using LogDensityProblems: logdensity, logdensity_and_gradient
-using Random: Random, Xoshiro
+using Random: AbstractRNG, default_rng
 using Statistics: median
 using Test: @test
 
-export ADResult, run_ad, ADIncorrectException
+export ADResult, run_ad, ADIncorrectException, WithBackend, WithExpectedResult, NoTest
 
 """
-    REFERENCE_ADTYPE
+    AbstractADCorrectnessTestSetting
 
-Reference AD backend to use for comparison. In this case, ForwardDiff.jl, since
-it's the default AD backend used in Turing.jl.
+Different ways of testing the correctness of an AD backend.
 """
-const REFERENCE_ADTYPE = AutoForwardDiff()
+abstract type AbstractADCorrectnessTestSetting end
+
+"""
+    WithBackend(adtype::AbstractADType=AutoForwardDiff()) <: AbstractADCorrectnessTestSetting
+
+Test correctness by comparing the result obtained with the tested backend
+against the result obtained with `adtype`.
+
+`adtype` defaults to ForwardDiff.jl, since it's the default AD backend used in
+Turing.jl.
+"""
+struct WithBackend{AD<:AbstractADType} <: AbstractADCorrectnessTestSetting
+    adtype::AD
+end
+WithBackend() = WithBackend(AutoForwardDiff())
+
+"""
+    WithExpectedResult(
+        value::T,
+        grad::AbstractVector{T}
+    ) where {T <: AbstractFloat}
+    <: AbstractADCorrectnessTestSetting
+
+Test correctness by comparing the result against a known value (e.g. one
+obtained analytically, or one obtained earlier with a different backend).
+Both the value of the primal (i.e. the log-density) and its gradient must be
+supplied.
+"""
+struct WithExpectedResult{T<:AbstractFloat} <: AbstractADCorrectnessTestSetting
+    value::T
+    grad::AbstractVector{T}
+end
+
+"""
+    NoTest() <: AbstractADCorrectnessTestSetting
+
+Disable correctness testing.
+"""
+struct NoTest <: AbstractADCorrectnessTestSetting end
 
 """
     ADIncorrectException{T<:AbstractFloat}
@@ -45,17 +74,18 @@ struct ADIncorrectException{T<:AbstractFloat} <: Exception
 end
 
 """
-    ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat}
+    ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat,Ttol<:AbstractFloat}
 
 Data structure to store the results of the AD correctness test.
 
 The type parameter `Tparams` is the numeric type of the parameters passed in;
-`Tresult` is the type of the value and the gradient.
+`Tresult` is the type of the value and the gradient; and `Ttol` is the type of the
+absolute and relative tolerances used for correctness testing.
 
 # Fields
 $(TYPEDFIELDS)
 """
-struct ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat}
+struct ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat,Ttol<:AbstractFloat}
     "The DynamicPPL model that was tested"
    model::Model
     "The VarInfo that was used"
@@ -64,18 +94,18 @@ struct ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat}
     params::Vector{Tparams}
     "The AD backend that was tested"
     adtype::AbstractADType
-    "The absolute tolerance for the value of logp"
-    value_atol::Tresult
-    "The absolute tolerance for the gradient of logp"
-    grad_atol::Tresult
+    "Absolute tolerance used for correctness test"
+    atol::Ttol
+    "Relative tolerance used for correctness test"
+    rtol::Ttol
     "The expected value of logp"
     value_expected::Union{Nothing,Tresult}
     "The expected gradient of logp"
     grad_expected::Union{Nothing,Vector{Tresult}}
     "The value of logp (calculated using `adtype`)"
-    value_actual::Union{Nothing,Tresult}
+    value_actual::Tresult
     "The gradient of logp (calculated using `adtype`)"
-    grad_actual::Union{Nothing,Vector{Tresult}}
+    grad_actual::Vector{Tresult}
     "If benchmarking was requested, the time taken by the AD backend to calculate the gradient of logp, divided by the time taken to evaluate logp itself"
     time_vs_primal::Union{Nothing,Tresult}
 end
@@ -84,14 +114,12 @@
     run_ad(
         model::Model,
         adtype::ADTypes.AbstractADType;
-        test=true,
+        test::Union{AbstractADCorrectnessTestSetting,Bool}=WithBackend(),
         benchmark=false,
-        value_atol=1e-6,
-        grad_atol=1e-6,
+        atol::AbstractFloat=100 * eps(),
+        rtol::AbstractFloat=sqrt(eps()),
-        varinfo::AbstractVarInfo=link(VarInfo(model), model),
+        rng::AbstractRNG=default_rng(),
+        varinfo::AbstractVarInfo=link(VarInfo(rng, model), model),
         params::Union{Nothing,Vector{<:AbstractFloat}}=nothing,
-        reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE,
-        expected_value_and_grad::Union{Nothing,Tuple{AbstractFloat,Vector{<:AbstractFloat}}}=nothing,
         verbose=true,
     )::ADResult
 
@@ -133,8 +161,8 @@ Everything else is optional, and can be categorised into several groups:
 
    Note that if the VarInfo is not specified (and thus automatically generated)
    the parameters in it will have been sampled from the prior of the model. If
-   you want to seed the parameter generation, the easiest way is to pass a
-   `rng` argument to the VarInfo constructor (i.e. do `VarInfo(rng, model)`).
+   you want to seed the parameter generation for the VarInfo, you can pass the
+   `rng` keyword argument, which will then be used to create the VarInfo.
 
    Finally, note that these only reflect the parameters used for _evaluating_
    the gradient. If you also want to control the parameters used for
@@ -143,25 +171,35 @@ Everything else is optional, and can be categorised into several groups:
    prep_params)`. You could then evaluate the gradient at a different set of
    parameters using the `params` keyword argument.
 
-3. _How to specify the results to compare against._ (Only if `test=true`.)
+3. _How to specify the results to compare against._
 
    Once logp and its gradient has been calculated with the specified `adtype`,
-   it must be tested for correctness.
+   it can optionally be tested for correctness. The exact way this is tested
+   is specified in the `test` keyword argument.
+
+   There are several options for this:
 
-   This can be done either by specifying `reference_adtype`, in which case logp
-   and its gradient will also be calculated with this reference in order to
-   obtain the ground truth; or by using `expected_value_and_grad`, which is a
-   tuple of `(logp, gradient)` that the calculated values must match. The
-   latter is useful if you are testing multiple AD backends and want to avoid
-   recalculating the ground truth multiple times.
+   - You can explicitly specify the correct value using
+     [`WithExpectedResult()`](@ref).
+   - You can compare against the result obtained with a different AD backend
+     using [`WithBackend(adtype)`](@ref).
+   - You can disable testing by passing [`NoTest()`](@ref).
+   - The default is to compare against the result obtained with ForwardDiff,
+     i.e. `WithBackend(AutoForwardDiff())`.
+   - `test=false` and `test=true` are synonyms for
+     `NoTest()` and `WithBackend(AutoForwardDiff())`, respectively.
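+
+   For example (an illustrative sketch; `model` and `adtype` are placeholders,
+   and the expected value/gradient below are made up):
+
+   ```julia
+   run_ad(model, adtype)                                       # WithBackend()
+   run_ad(model, adtype; test=WithBackend(AutoForwardDiff()))
+   run_ad(model, adtype; test=WithExpectedResult(-1.5, [0.5, -0.5]))
+   run_ad(model, adtype; test=NoTest())
+   ```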
 
-   The default reference backend is ForwardDiff. If none of these parameters are
-   specified, ForwardDiff will be used to calculate the ground truth.
+4. _How to specify the tolerances._ (Only if testing is enabled.)
 
-4. _How to specify the tolerances._ (Only if `test=true`.)
+   Both absolute and relative tolerances can be specified using the `atol` and
+   `rtol` keyword arguments respectively. The behaviour of these is similar to
+   `isapprox()`, i.e. the value and gradient are considered correct if either
+   the `atol` or the `rtol` check is satisfied. The default values are
+   `100*eps()` for `atol` and `sqrt(eps())` for `rtol`.
 
-   The tolerances for the value and gradient can be set using `value_atol` and
-   `grad_atol`. These default to 1e-6.
+   For the most part, it is the `rtol` check that is more meaningful, because
+   we cannot know the magnitude of logp and its gradient a priori. The `atol`
+   value is supplied to handle the case where gradients are equal to zero.
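+
+   For instance, with the defaults (an illustrative sketch; the numbers assume
+   `Float64`):
+
+   ```julia
+   atol = 100 * eps()  # ≈ 2.2e-14
+   rtol = sqrt(eps())  # ≈ 1.5e-8
+   # A gradient entry of exactly 0.0 can only pass via atol, whereas a
+   # large-magnitude logp may deviate proportionally to its size via rtol:
+   isapprox(-1234.5, -1234.500012; atol=atol, rtol=rtol)  # true
+   ```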
 
 5. _Whether to output extra logging information._
 
@@ -180,48 +218,58 @@ thrown as-is.
 function run_ad(
     model::Model,
     adtype::AbstractADType;
-    test::Bool=true,
+    test::Union{AbstractADCorrectnessTestSetting,Bool}=WithBackend(),
     benchmark::Bool=false,
-    value_atol::AbstractFloat=1e-6,
-    grad_atol::AbstractFloat=1e-6,
-    varinfo::AbstractVarInfo=link(VarInfo(model), model),
+    atol::AbstractFloat=100 * eps(),
+    rtol::AbstractFloat=sqrt(eps()),
+    rng::AbstractRNG=default_rng(),
+    varinfo::AbstractVarInfo=link(VarInfo(rng, model), model),
     params::Union{Nothing,Vector{<:AbstractFloat}}=nothing,
-    reference_adtype::AbstractADType=REFERENCE_ADTYPE,
-    expected_value_and_grad::Union{Nothing,Tuple{AbstractFloat,Vector{<:AbstractFloat}}}=nothing,
     verbose=true,
 )::ADResult
+    # Convert Boolean `test` to an AbstractADCorrectnessTestSetting
+    if test isa Bool
+        test = test ? WithBackend() : NoTest()
+    end
+
+    # Extract parameters
     if isnothing(params)
         params = varinfo[:]
     end
     params = map(identity, params)  # Concretise
 
+    # Calculate log-density and gradient with the backend of interest
     verbose && @info "Running AD on $(model.f) with $(adtype)\n"
     verbose && println("       params : $(params)")
     ldf = LogDensityFunction(model, varinfo; adtype=adtype)
-
     value, grad = logdensity_and_gradient(ldf, params)
+    # collect(): https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754
     grad = collect(grad)
     verbose && println("       actual : $((value, grad))")
 
-    if test
-        # Calculate ground truth to compare against
-        value_true, grad_true = if expected_value_and_grad === nothing
-            ldf_reference = LogDensityFunction(model, varinfo; adtype=reference_adtype)
-            logdensity_and_gradient(ldf_reference, params)
-        else
-            expected_value_and_grad
+    # Test correctness
+    if test isa NoTest
+        value_true = nothing
+        grad_true = nothing
+    else
+        # Get the correct result
+        if test isa WithExpectedResult
+            value_true = test.value
+            grad_true = test.grad
+        elseif test isa WithBackend
+            ldf_reference = LogDensityFunction(model, varinfo; adtype=test.adtype)
+            value_true, grad_true = logdensity_and_gradient(ldf_reference, params)
+            # collect(): https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754
+            grad_true = collect(grad_true)
         end
+        # Perform testing
         verbose && println("     expected : $((value_true, grad_true))")
-        grad_true = collect(grad_true)
-
         exc() = throw(ADIncorrectException(value, value_true, grad, grad_true))
-        isapprox(value, value_true; atol=value_atol) || exc()
-        isapprox(grad, grad_true; atol=grad_atol) || exc()
-    else
-        value_true = nothing
-        grad_true = nothing
+        isapprox(value, value_true; atol=atol, rtol=rtol) || exc()
+        isapprox(grad, grad_true; atol=atol, rtol=rtol) || exc()
     end
 
+    # Benchmark
     time_vs_primal = if benchmark
         primal_benchmark = @be (ldf, params) logdensity(_[1], _[2])
         grad_benchmark = @be (ldf, params) logdensity_and_gradient(_[1], _[2])
@@ -237,8 +285,8 @@ function run_ad(
         varinfo,
         params,
         adtype,
-        value_atol,
-        grad_atol,
+        atol,
+        rtol,
         value_true,
         grad_true,
         value,
diff --git a/test/ad.jl b/test/ad.jl
index 0947c017a..48dffeadb 100644
--- a/test/ad.jl
+++ b/test/ad.jl
@@ -1,4 +1,5 @@
 using DynamicPPL: LogDensityFunction
+using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest
 
 @testset "Automatic differentiation" begin
     # Used as the ground truth that others are compared against.
@@ -31,9 +32,10 @@ using DynamicPPL: LogDensityFunction
         linked_varinfo = DynamicPPL.link(varinfo, m)
         f = LogDensityFunction(m, linked_varinfo)
         x = DynamicPPL.getparams(f)
+        # Calculate reference logp + gradient of logp using ForwardDiff
-        ref_ldf = LogDensityFunction(m, linked_varinfo; adtype=ref_adtype)
-        ref_logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ref_ldf, x)
+        ref_ad_result = run_ad(m, ref_adtype; varinfo=linked_varinfo, test=NoTest())
+        ref_logp, ref_grad = ref_ad_result.value_actual, ref_ad_result.grad_actual
 
         @testset "$adtype" for adtype in test_adtypes
             @info "Testing AD on: $(m.f) - $(short_varinfo_name(linked_varinfo)) - $adtype"
 
@@ -50,24 +52,24 @@ using DynamicPPL: LogDensityFunction
             if is_mooncake && is_1_11 && is_svi_vnv
                 # https://github.com/compintell/Mooncake.jl/issues/470
                 @test_throws ArgumentError DynamicPPL.LogDensityFunction(
-                    ref_ldf, adtype
+                    m, linked_varinfo; adtype=adtype
                 )
             elseif is_mooncake && is_1_10 && is_svi_vnv
                 # TODO: report upstream
                 @test_throws UndefRefError DynamicPPL.LogDensityFunction(
-                    ref_ldf, adtype
+                    m, linked_varinfo; adtype=adtype
                 )
             elseif is_mooncake && is_1_10 && is_svi_od
                 # TODO: report upstream
                 @test_throws Mooncake.MooncakeRuleCompilationError DynamicPPL.LogDensityFunction(
-                    ref_ldf, adtype
+                    m, linked_varinfo; adtype=adtype
                 )
             else
-                @test DynamicPPL.TestUtils.AD.run_ad(
+                @test run_ad(
                     m,
                     adtype;
                     varinfo=linked_varinfo,
-                    expected_value_and_grad=(ref_logp, ref_grad),
+                    test=WithExpectedResult(ref_logp, ref_grad),
                 ) isa Any
             end
         end