Improve API for AD testing #964

Open · wants to merge 7 commits into base: breaking
11 changes: 11 additions & 0 deletions HISTORY.md
@@ -8,6 +8,17 @@

The `@submodel` macro is fully removed; please use `to_submodel` instead.

### `DynamicPPL.TestUtils.AD.run_ad`

The three keyword arguments `test`, `reference_adtype`, and `expected_value_and_grad` have been merged into a single `test` keyword argument.
Please see the API documentation for more details.
(The old `test=true` and `test=false` values are still valid; you only need to adjust the invocation if you were explicitly passing the `reference_adtype` or `expected_value_and_grad` arguments.)

There is now also an `rng` keyword argument to help seed parameter generation.

Finally, instead of specifying `value_atol` and `grad_atol`, you can now specify `atol` and `rtol`, which are used for both the value and the gradient.
Their semantics are the same as in Julia's `isapprox`: two values are considered equal if they satisfy either the `atol` or the `rtol` criterion.
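
A rough sketch of the migration, assuming a `model` plus reference values `logp` and `grad` computed elsewhere (the ReverseDiff backend here is only illustrative):

```julia
using ADTypes: AutoForwardDiff, AutoReverseDiff
using DynamicPPL.TestUtils.AD: run_ad, WithBackend, WithExpectedResult, NoTest

# Old: run_ad(model, AutoReverseDiff(); reference_adtype=AutoForwardDiff())
run_ad(model, AutoReverseDiff(); test=WithBackend(AutoForwardDiff()))

# Old: run_ad(model, AutoReverseDiff(); expected_value_and_grad=(logp, grad))
run_ad(model, AutoReverseDiff(); test=WithExpectedResult(logp, grad))

# Old: run_ad(model, AutoReverseDiff(); test=false) -- still accepted
run_ad(model, AutoReverseDiff(); test=NoTest())

# atol/rtol replace value_atol/grad_atol; a result passes if it satisfies
# either tolerance, as with `isapprox`
run_ad(model, AutoReverseDiff(); atol=1e-10, rtol=1e-6)
```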

### Accumulators

This release overhauls how VarInfo objects track variables such as the log joint probability. The new approach is to use what we call accumulators: objects carried by the VarInfo that may update their state at each `tilde_assume!!` and `tilde_observe!!` call, based on the value of the variable in question. They replace both variables that were previously hard-coded in the `VarInfo` object (`logp` and `num_produce`) and some contexts. This brings with it a number of breaking changes:
15 changes: 15 additions & 0 deletions docs/src/api.md
@@ -211,6 +211,21 @@ To test and/or benchmark the performance of an AD backend on a model, DynamicPPL

```@docs
DynamicPPL.TestUtils.AD.run_ad
```

The default test setting is to compare against ForwardDiff.

You can have more fine-grained control over how to test the AD backend using the following types:

```@docs
DynamicPPL.TestUtils.AD.AbstractADCorrectnessTestSetting
DynamicPPL.TestUtils.AD.WithBackend
DynamicPPL.TestUtils.AD.WithExpectedResult
DynamicPPL.TestUtils.AD.NoTest
```

These are returned or thrown by the `run_ad` function:

```@docs
DynamicPPL.TestUtils.AD.ADResult
DynamicPPL.TestUtils.AD.ADIncorrectException
```
180 changes: 114 additions & 66 deletions src/test_utils/ad.jl
@@ -4,28 +4,57 @@
using Chairmarks: @be
import DifferentiationInterface as DI
using DocStringExtensions
using DynamicPPL:
Model,
LogDensityFunction,
VarInfo,
AbstractVarInfo,
link,
DefaultContext,
AbstractContext
using DynamicPPL: Model, LogDensityFunction, VarInfo, AbstractVarInfo, link
using LogDensityProblems: logdensity, logdensity_and_gradient
using Random: Random, Xoshiro
using Random: AbstractRNG, default_rng
using Statistics: median
using Test: @test

export ADResult, run_ad, ADIncorrectException
export ADResult, run_ad, ADIncorrectException, WithBackend, WithExpectedResult, NoTest

"""
REFERENCE_ADTYPE
AbstractADCorrectnessTestSetting

Reference AD backend to use for comparison. In this case, ForwardDiff.jl, since
it's the default AD backend used in Turing.jl.
Different ways of testing the correctness of an AD backend.
"""
const REFERENCE_ADTYPE = AutoForwardDiff()
abstract type AbstractADCorrectnessTestSetting end

"""
WithBackend(adtype::AbstractADType=AutoForwardDiff()) <: AbstractADCorrectnessTestSetting

Test correctness by comparing against the result obtained with `adtype`.

`adtype` defaults to ForwardDiff.jl, since it's the default AD backend used in
Turing.jl.
Comment on lines +27 to +28 (Member): Have we considered using FiniteDifferences.jl instead?
"""
struct WithBackend{AD<:AbstractADType} <: AbstractADCorrectnessTestSetting
adtype::AD
end
WithBackend() = WithBackend(AutoForwardDiff())

"""
WithExpectedResult(
value::T,
grad::AbstractVector{T}
) where {T <: AbstractFloat}
<: AbstractADCorrectnessTestSetting

Test correctness by comparing against a known result (e.g. one obtained
analytically, or one obtained previously with a different backend). Both the
value of the primal (i.e. the log-density) and its gradient must be
supplied.
"""
struct WithExpectedResult{T<:AbstractFloat} <: AbstractADCorrectnessTestSetting
value::T
grad::AbstractVector{T}
end

"""
NoTest() <: AbstractADCorrectnessTestSetting

Disable correctness testing.
"""
struct NoTest <: AbstractADCorrectnessTestSetting end

"""
ADIncorrectException{T<:AbstractFloat}
@@ -45,17 +74,18 @@
end

"""
ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat}
ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat,Ttol<:AbstractFloat}

Data structure to store the results of the AD correctness test.

The type parameter `Tparams` is the numeric type of the parameters passed in;
`Tresult` is the type of the value and the gradient.
`Tresult` is the type of the value and the gradient; and `Ttol` is the type of the
absolute and relative tolerances used for correctness testing.

# Fields
$(TYPEDFIELDS)
"""
struct ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat}
struct ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat,Ttol<:AbstractFloat}
"The DynamicPPL model that was tested"
model::Model
"The VarInfo that was used"
Expand All @@ -64,18 +94,18 @@
params::Vector{Tparams}
"The AD backend that was tested"
adtype::AbstractADType
"The absolute tolerance for the value of logp"
value_atol::Tresult
"The absolute tolerance for the gradient of logp"
grad_atol::Tresult
"Absolute tolerance used for correctness test"
atol::Ttol
"Relative tolerance used for correctness test"
rtol::Ttol
"The expected value of logp"
value_expected::Union{Nothing,Tresult}
"The expected gradient of logp"
grad_expected::Union{Nothing,Vector{Tresult}}
"The value of logp (calculated using `adtype`)"
value_actual::Union{Nothing,Tresult}
value_actual::Tresult
"The gradient of logp (calculated using `adtype`)"
grad_actual::Union{Nothing,Vector{Tresult}}
grad_actual::Vector{Tresult}
"If benchmarking was requested, the time taken by the AD backend to calculate the gradient of logp, divided by the time taken to evaluate logp itself"
time_vs_primal::Union{Nothing,Tresult}
end
@@ -84,14 +114,12 @@
run_ad(
model::Model,
adtype::ADTypes.AbstractADType;
test=true,
test::Union{AbstractADCorrectnessTestSetting,Bool}=WithBackend(),
benchmark=false,
value_atol=1e-6,
grad_atol=1e-6,
atol::AbstractFloat=100 * eps(),
rtol::AbstractFloat=sqrt(eps()),
varinfo::AbstractVarInfo=link(VarInfo(model), model),
params::Union{Nothing,Vector{<:AbstractFloat}}=nothing,
reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE,
expected_value_and_grad::Union{Nothing,Tuple{AbstractFloat,Vector{<:AbstractFloat}}}=nothing,
verbose=true,
)::ADResult

@@ -133,8 +161,8 @@

Note that if the VarInfo is not specified (and thus automatically generated)
the parameters in it will have been sampled from the prior of the model. If
you want to seed the parameter generation, the easiest way is to pass a
`rng` argument to the VarInfo constructor (i.e. do `VarInfo(rng, model)`).
you want to seed the parameter generation for the VarInfo, you can pass the
`rng` keyword argument, which will then be used to create the VarInfo.

Finally, note that these only reflect the parameters used for _evaluating_
the gradient. If you also want to control the parameters used for
@@ -143,25 +171,35 @@
prep_params)`. You could then evaluate the gradient at a different set of
parameters using the `params` keyword argument.

3. _How to specify the results to compare against._ (Only if `test=true`.)
3. _How to specify the results to compare against._

Once logp and its gradient have been calculated with the specified `adtype`,
it must be tested for correctness.
it can optionally be tested for correctness. The exact way this is tested
is specified in the `test` parameter.

There are several options for this:

This can be done either by specifying `reference_adtype`, in which case logp
and its gradient will also be calculated with this reference in order to
obtain the ground truth; or by using `expected_value_and_grad`, which is a
tuple of `(logp, gradient)` that the calculated values must match. The
latter is useful if you are testing multiple AD backends and want to avoid
recalculating the ground truth multiple times.
- You can explicitly specify the correct value using
[`WithExpectedResult()`](@ref).
- You can compare against the result obtained with a different AD backend
using [`WithBackend(adtype)`](@ref).
- You can disable testing by passing [`NoTest()`](@ref).
- The default is to compare against the result obtained with ForwardDiff,
i.e. `WithBackend(AutoForwardDiff())`.
- `test=false` and `test=true` are synonyms for
`NoTest()` and `WithBackend(AutoForwardDiff())`, respectively.

The default reference backend is ForwardDiff. If none of these parameters are
specified, ForwardDiff will be used to calculate the ground truth.
4. _How to specify the tolerances._ (Only if testing is enabled.)

4. _How to specify the tolerances._ (Only if `test=true`.)
Both absolute and relative tolerances can be specified using the `atol` and
`rtol` keyword arguments respectively. The behaviour of these is similar to
`isapprox()`, i.e. the value and gradient are considered correct if either
the `atol` or the `rtol` criterion is satisfied. The default values are `100*eps()` for `atol` and
`sqrt(eps())` for `rtol`.

The tolerances for the value and gradient can be set using `value_atol` and
`grad_atol`. These default to 1e-6.
For the most part, it is the `rtol` check that is more meaningful, because
we cannot know the magnitude of logp and its gradient a priori. The `atol`
value is supplied to handle the case where gradients are equal to zero.

5. _Whether to output extra logging information._

Expand All @@ -180,48 +218,58 @@
function run_ad(
model::Model,
adtype::AbstractADType;
test::Bool=true,
test::Union{AbstractADCorrectnessTestSetting,Bool}=WithBackend(),
benchmark::Bool=false,
value_atol::AbstractFloat=1e-6,
grad_atol::AbstractFloat=1e-6,
varinfo::AbstractVarInfo=link(VarInfo(model), model),
atol::AbstractFloat=100 * eps(),
rtol::AbstractFloat=sqrt(eps()),
rng::AbstractRNG=default_rng(),
varinfo::AbstractVarInfo=link(VarInfo(rng, model), model),
params::Union{Nothing,Vector{<:AbstractFloat}}=nothing,
reference_adtype::AbstractADType=REFERENCE_ADTYPE,
expected_value_and_grad::Union{Nothing,Tuple{AbstractFloat,Vector{<:AbstractFloat}}}=nothing,
verbose=true,
)::ADResult
# Convert Boolean `test` to an AbstractADCorrectnessTestSetting
if test isa Bool
test = test ? WithBackend() : NoTest()
end

# Extract parameters
if isnothing(params)
params = varinfo[:]
end
params = map(identity, params) # Concretise

# Calculate log-density and gradient with the backend of interest
verbose && @info "Running AD on $(model.f) with $(adtype)\n"
verbose && println(" params : $(params)")
ldf = LogDensityFunction(model, varinfo; adtype=adtype)

value, grad = logdensity_and_gradient(ldf, params)
# collect(): https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754
grad = collect(grad)
verbose && println(" actual : $((value, grad))")

if test
# Calculate ground truth to compare against
value_true, grad_true = if expected_value_and_grad === nothing
ldf_reference = LogDensityFunction(model, varinfo; adtype=reference_adtype)
logdensity_and_gradient(ldf_reference, params)
else
expected_value_and_grad
# Test correctness
if test isa NoTest
value_true = nothing
grad_true = nothing
Comment on lines +251 to +253 (Member): Do these values actually ever get used?
else
# Get the correct result
if test isa WithExpectedResult
value_true = test.value
grad_true = test.grad
elseif test isa WithBackend
ldf_reference = LogDensityFunction(model, varinfo; adtype=test.adtype)
value_true, grad_true = logdensity_and_gradient(ldf_reference, params)
# collect(): https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754
grad_true = collect(grad_true)
end
# Perform testing
verbose && println(" expected : $((value_true, grad_true))")
grad_true = collect(grad_true)

exc() = throw(ADIncorrectException(value, value_true, grad, grad_true))
isapprox(value, value_true; atol=value_atol) || exc()
isapprox(grad, grad_true; atol=grad_atol) || exc()
else
value_true = nothing
grad_true = nothing
isapprox(value, value_true; atol=atol, rtol=rtol) || exc()
isapprox(grad, grad_true; atol=atol, rtol=rtol) || exc()
end

# Benchmark
time_vs_primal = if benchmark
primal_benchmark = @be (ldf, params) logdensity(_[1], _[2])
grad_benchmark = @be (ldf, params) logdensity_and_gradient(_[1], _[2])
@@ -237,8 +285,8 @@
varinfo,
params,
adtype,
value_atol,
grad_atol,
atol,
rtol,
value_true,
grad_true,
value,
16 changes: 9 additions & 7 deletions test/ad.jl
@@ -1,4 +1,5 @@
using DynamicPPL: LogDensityFunction
using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest

@testset "Automatic differentiation" begin
# Used as the ground truth that others are compared against.
@@ -31,9 +32,10 @@ using DynamicPPL: LogDensityFunction
linked_varinfo = DynamicPPL.link(varinfo, m)
f = LogDensityFunction(m, linked_varinfo)
x = DynamicPPL.getparams(f)

# Calculate reference logp + gradient of logp using ForwardDiff
ref_ldf = LogDensityFunction(m, linked_varinfo; adtype=ref_adtype)
ref_logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ref_ldf, x)
ref_ad_result = run_ad(m, ref_adtype; varinfo=linked_varinfo, test=NoTest())
ref_logp, ref_grad = ref_ad_result.value_actual, ref_ad_result.grad_actual

@testset "$adtype" for adtype in test_adtypes
@info "Testing AD on: $(m.f) - $(short_varinfo_name(linked_varinfo)) - $adtype"
@@ -50,24 +52,24 @@
if is_mooncake && is_1_11 && is_svi_vnv
# https://github.com/compintell/Mooncake.jl/issues/470
@test_throws ArgumentError DynamicPPL.LogDensityFunction(
ref_ldf, adtype
m, linked_varinfo; adtype=adtype
)
elseif is_mooncake && is_1_10 && is_svi_vnv
# TODO: report upstream
@test_throws UndefRefError DynamicPPL.LogDensityFunction(
ref_ldf, adtype
m, linked_varinfo; adtype=adtype
)
elseif is_mooncake && is_1_10 && is_svi_od
# TODO: report upstream
@test_throws Mooncake.MooncakeRuleCompilationError DynamicPPL.LogDensityFunction(
ref_ldf, adtype
m, linked_varinfo; adtype=adtype
)
else
@test DynamicPPL.TestUtils.AD.run_ad(
@test run_ad(
m,
adtype;
varinfo=linked_varinfo,
expected_value_and_grad=(ref_logp, ref_grad),
test=WithExpectedResult(ref_logp, ref_grad),
) isa Any
end
end