Link varinfo by default in AD testing utilities; make test suite run on linked varinfos #890

Merged 9 commits on Apr 16, 2025
12 changes: 12 additions & 0 deletions HISTORY.md
@@ -4,6 +4,18 @@

**Breaking changes**

### AD testing utilities

`DynamicPPL.TestUtils.AD.run_ad` now links the VarInfo by default.
To disable this, pass the `linked=false` keyword argument.
If the calculated value or gradient is incorrect, it now throws a `DynamicPPL.TestUtils.AD.ADIncorrectException` rather than registering a test failure.
The exception contains the expected and actual values and gradients so you can inspect them if needed; see the documentation for more information.
In practice, this means that when adding `run_ad` to a test suite, you should write `@test run_ad(...) isa Any` rather than a bare `run_ad(...)`.
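A minimal sketch of the new test-suite usage (the model definition and backend choice below are illustrative, not taken from this PR):

```julia
using DynamicPPL, Distributions, ADTypes, Test
import ForwardDiff

@model function demo()
    x ~ Normal()
end

# `run_ad` now links the VarInfo by default; pass `linked=false` to opt out.
# Wrapping the call in `@test ... isa Any` means a thrown
# ADIncorrectException surfaces as a test failure rather than an error.
@test DynamicPPL.TestUtils.AD.run_ad(demo(), AutoForwardDiff()) isa Any
```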

### SimpleVarInfo linking / invlinking

Linking an already-linked SimpleVarInfo, or invlinking an unlinked SimpleVarInfo, now emits a warning instead of throwing an error.
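A hedged sketch of the behaviour change (the model is illustrative; previously the second `link!!` call raised an `AssertionError`):

```julia
using DynamicPPL, Distributions

@model demo() = x ~ LogNormal()
model = demo()

svi = SimpleVarInfo(model)
svi_linked = DynamicPPL.link!!(svi, model)

# Linking twice now only logs a warning such as
# "Trying to link an already transformed variable" and acts as a no-op.
DynamicPPL.link!!(svi_linked, model)
```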

### VarInfo constructors

`VarInfo(vi::VarInfo, values)` has been removed. Use `unflatten(vi, values)` instead.
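The migration can be sketched as follows (the model is illustrative):

```julia
using DynamicPPL, Distributions

@model demo() = x ~ Normal()

vi = VarInfo(demo())
new_values = vi[:]  # flat parameter vector

# Before: vi2 = VarInfo(vi, new_values)
# After:
vi2 = DynamicPPL.unflatten(vi, new_values)
```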
1 change: 1 addition & 0 deletions docs/src/api.md
@@ -212,6 +212,7 @@ To test and/or benchmark the performance of an AD backend on a model, DynamicPPL
```@docs
DynamicPPL.TestUtils.AD.run_ad
DynamicPPL.TestUtils.AD.ADResult
DynamicPPL.TestUtils.AD.ADIncorrectException
```

## Demo models
104 changes: 66 additions & 38 deletions src/test_utils/ad.jl
@@ -4,19 +4,13 @@
using Chairmarks: @be
import DifferentiationInterface as DI
using DocStringExtensions
using DynamicPPL: Model, LogDensityFunction, VarInfo, AbstractVarInfo
using DynamicPPL: Model, LogDensityFunction, VarInfo, AbstractVarInfo, link
using LogDensityProblems: logdensity, logdensity_and_gradient
using Random: Random, Xoshiro
using Statistics: median
using Test: @test

export ADResult, run_ad

# This function needed to work around the fact that different backends can
# return different AbstractArrays for the gradient. See
# https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754 for more
# context.
_to_vec_f64(x::AbstractArray) = x isa Vector{Float64} ? x : collect(Float64, x)
export ADResult, run_ad, ADIncorrectException

"""
REFERENCE_ADTYPE
@@ -27,33 +21,50 @@
const REFERENCE_ADTYPE = AutoForwardDiff()

"""
ADResult
ADIncorrectException{T<:AbstractFloat}

Exception thrown when an AD backend returns an incorrect value or gradient.

The type parameter `T` is the numeric type of the value and gradient.
"""
struct ADIncorrectException{T<:AbstractFloat} <: Exception
value_expected::T
value_actual::T
grad_expected::Vector{T}
grad_actual::Vector{T}
end

"""
ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat}

Data structure to store the results of the AD correctness test.

The type parameter `Tparams` is the numeric type of the parameters passed in;
`Tresult` is the type of the value and the gradient.
"""
struct ADResult
struct ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat}
"The DynamicPPL model that was tested"
model::Model
"The VarInfo that was used"
varinfo::AbstractVarInfo
"The values at which the model was evaluated"
params::Vector{<:Real}
params::Vector{Tparams}
"The AD backend that was tested"
adtype::AbstractADType
"The absolute tolerance for the value of logp"
value_atol::Real
value_atol::Tresult
"The absolute tolerance for the gradient of logp"
grad_atol::Real
grad_atol::Tresult
"The expected value of logp"
value_expected::Union{Nothing,Float64}
value_expected::Union{Nothing,Tresult}
"The expected gradient of logp"
grad_expected::Union{Nothing,Vector{Float64}}
grad_expected::Union{Nothing,Vector{Tresult}}
"The value of logp (calculated using `adtype`)"
value_actual::Union{Nothing,Real}
value_actual::Union{Nothing,Tresult}
"The gradient of logp (calculated using `adtype`)"
grad_actual::Union{Nothing,Vector{Float64}}
grad_actual::Union{Nothing,Vector{Tresult}}
"If benchmarking was requested, the time taken by the AD backend to calculate the gradient of logp, divided by the time taken to evaluate logp itself"
time_vs_primal::Union{Nothing,Float64}
time_vs_primal::Union{Nothing,Tresult}
end

"""
@@ -64,26 +75,27 @@
benchmark=false,
value_atol=1e-6,
grad_atol=1e-6,
varinfo::AbstractVarInfo=VarInfo(model),
params::Vector{<:Real}=varinfo[:],
varinfo::AbstractVarInfo=link(VarInfo(model), model),
params::Union{Nothing,Vector{<:AbstractFloat}}=nothing,
reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE,
expected_value_and_grad::Union{Nothing,Tuple{Real,Vector{<:Real}}}=nothing,
expected_value_and_grad::Union{Nothing,Tuple{AbstractFloat,Vector{<:AbstractFloat}}}=nothing,
verbose=true,
)::ADResult

### Description

Test the correctness and/or benchmark the AD backend `adtype` for the model
`model`.

Whether to test and benchmark is controlled by the `test` and `benchmark`
keyword arguments. By default, `test` is `true` and `benchmark` is `false`.

Returns an [`ADResult`](@ref) object, which contains the results of the
test and/or benchmark.

Note that to run AD successfully you will need to import the AD backend itself.
For example, to test with `AutoReverseDiff()` you will need to run `import
ReverseDiff`.

### Arguments

There are two positional arguments, which must always be provided:

1. `model` - The model being tested.
@@ -96,7 +108,9 @@
DynamicPPL contains several different types of VarInfo objects which change
the way model evaluation occurs. If you want to use a specific type of
VarInfo, pass it as the `varinfo` argument. Otherwise, it will default to
using a `TypedVarInfo` generated from the model.
using a linked `TypedVarInfo` generated from the model. Here, _linked_
means that the parameters in the VarInfo have been transformed to
unconstrained Euclidean space if they aren't already in that space.

2. _How to specify the parameters._

@@ -140,27 +154,40 @@

By default, this function prints messages when it runs. To silence it, set
`verbose=false`.

### Returns / Throws

Returns an [`ADResult`](@ref) object, which contains the results of the
test and/or benchmark.

If `test` is `true` and the AD backend returns an incorrect value or gradient, an
`ADIncorrectException` is thrown. If a different error occurs, it will be
thrown as-is.
"""
function run_ad(
model::Model,
adtype::AbstractADType;
test=true,
benchmark=false,
value_atol=1e-6,
grad_atol=1e-6,
varinfo::AbstractVarInfo=VarInfo(model),
params::Vector{<:Real}=varinfo[:],
test::Bool=true,
benchmark::Bool=false,
value_atol::AbstractFloat=1e-6,
grad_atol::AbstractFloat=1e-6,
varinfo::AbstractVarInfo=link(VarInfo(model), model),
params::Union{Nothing,Vector{<:AbstractFloat}}=nothing,
reference_adtype::AbstractADType=REFERENCE_ADTYPE,
expected_value_and_grad::Union{Nothing,Tuple{Real,Vector{<:Real}}}=nothing,
expected_value_and_grad::Union{Nothing,Tuple{AbstractFloat,Vector{<:AbstractFloat}}}=nothing,
verbose=true,
)::ADResult
if isnothing(params)
params = varinfo[:]

end
params = map(identity, params) # Concretise

verbose && @info "Running AD on $(model.f) with $(adtype)\n"
params = map(identity, params)
verbose && println(" params : $(params)")
ldf = LogDensityFunction(model, varinfo; adtype=adtype)

value, grad = logdensity_and_gradient(ldf, params)
grad = _to_vec_f64(grad)
grad = collect(grad)

verbose && println(" actual : $((value, grad))")

if test
@@ -172,10 +199,11 @@
expected_value_and_grad
end
verbose && println(" expected : $((value_true, grad_true))")
grad_true = _to_vec_f64(grad_true)
# Then compare
@test isapprox(value, value_true; atol=value_atol)
@test isapprox(grad, grad_true; atol=grad_atol)
grad_true = collect(grad_true)

exc() = throw(ADIncorrectException(value, value_true, grad, grad_true))
isapprox(value, value_true; atol=value_atol) || exc()
isapprox(grad, grad_true; atol=grad_atol) || exc()

else
value_true = nothing
grad_true = nothing
4 changes: 2 additions & 2 deletions src/transforming.jl
@@ -19,9 +19,9 @@ function tilde_assume(
lp = Bijectors.logpdf_with_trans(right, r, !isinverse)

if istrans(vi, vn)
@assert isinverse "Trying to link already transformed variables"
isinverse || @warn "Trying to link an already transformed variable ($vn)"
else
@assert !isinverse "Trying to invlink non-transformed variables"
isinverse && @warn "Trying to invlink a non-transformed variable ($vn)"
end

# Only transform if `!isinverse` since `vi[vn, right]`
18 changes: 10 additions & 8 deletions test/ad.jl
@@ -23,21 +23,23 @@ using DynamicPPL: LogDensityFunction
varinfos = DynamicPPL.TestUtils.setup_varinfos(m, rand_param_values, vns)

@testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
f = LogDensityFunction(m, varinfo)
linked_varinfo = DynamicPPL.link(varinfo, m)
f = LogDensityFunction(m, linked_varinfo)
x = DynamicPPL.getparams(f)
# Calculate reference logp + gradient of logp using ForwardDiff
ref_ldf = LogDensityFunction(m, varinfo; adtype=ref_adtype)
ref_ldf = LogDensityFunction(m, linked_varinfo; adtype=ref_adtype)
ref_logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ref_ldf, x)

@testset "$adtype" for adtype in test_adtypes
@info "Testing AD on: $(m.f) - $(short_varinfo_name(varinfo)) - $adtype"
@info "Testing AD on: $(m.f) - $(short_varinfo_name(linked_varinfo)) - $adtype"

# Put predicates here to avoid long lines
is_mooncake = adtype isa AutoMooncake
is_1_10 = v"1.10" <= VERSION < v"1.11"
is_1_11 = v"1.11" <= VERSION < v"1.12"
is_svi_vnv = varinfo isa SimpleVarInfo{<:DynamicPPL.VarNamedVector}
is_svi_od = varinfo isa SimpleVarInfo{<:OrderedDict}
is_svi_vnv =
linked_varinfo isa SimpleVarInfo{<:DynamicPPL.VarNamedVector}
is_svi_od = linked_varinfo isa SimpleVarInfo{<:OrderedDict}

# Mooncake doesn't work with several combinations of SimpleVarInfo.
if is_mooncake && is_1_11 && is_svi_vnv
@@ -56,12 +58,12 @@ using DynamicPPL: LogDensityFunction
ref_ldf, adtype
)
else
DynamicPPL.TestUtils.AD.run_ad(
@test DynamicPPL.TestUtils.AD.run_ad(
m,
adtype;
varinfo=varinfo,
varinfo=linked_varinfo,
expected_value_and_grad=(ref_logp, ref_grad),
)
) isa Any
end
end
end
6 changes: 0 additions & 6 deletions test/simple_varinfo.jl
@@ -111,12 +111,6 @@
# Should be approx. the same as the "lazy" transformation.
@test logjoint(model, vi_linked) ≈ lp_linked

# TODO: Should not `VarInfo` also error here? The current implementation
# only warns and acts as a no-op.
if vi isa SimpleVarInfo
@test_throws AssertionError link!!(vi_linked, model)
end

# `invlink!!`
vi_invlinked = invlink!!(deepcopy(vi_linked), model)
lp_invlinked = getlogp(vi_invlinked)