Improve API for AD testing #964
Open
penelopeysm wants to merge 7 commits into `breaking` from `py/improve-ad-api`
Commits (7):
- c9347b2 Rework API for AD testing (penelopeysm)
- 2f8574e Fix test (penelopeysm)
- 48464f3 Add `rng` keyword argument (penelopeysm)
- 6da8d57 Use atol and rtol (penelopeysm)
- 3587ce5 remove unbound type parameter (?) (penelopeysm)
- e1043ae Don't need to do elementwise check (penelopeysm)
- be36626 Update changelog (penelopeysm)
@@ -4,28 +4,57 @@
 using Chairmarks: @be
 import DifferentiationInterface as DI
 using DocStringExtensions
-using DynamicPPL:
-    Model,
-    LogDensityFunction,
-    VarInfo,
-    AbstractVarInfo,
-    link,
-    DefaultContext,
-    AbstractContext
+using DynamicPPL: Model, LogDensityFunction, VarInfo, AbstractVarInfo, link
 using LogDensityProblems: logdensity, logdensity_and_gradient
-using Random: Random, Xoshiro
+using Random: AbstractRNG, default_rng
 using Statistics: median
 using Test: @test

-export ADResult, run_ad, ADIncorrectException
+export ADResult, run_ad, ADIncorrectException, WithBackend, WithExpectedResult, NoTest

 """
-    REFERENCE_ADTYPE
+    AbstractADCorrectnessTestSetting

-Reference AD backend to use for comparison. In this case, ForwardDiff.jl, since
-it's the default AD backend used in Turing.jl.
+Different ways of testing the correctness of an AD backend.
 """
-const REFERENCE_ADTYPE = AutoForwardDiff()
+abstract type AbstractADCorrectnessTestSetting end

+"""
+    WithBackend(adtype::AbstractADType=AutoForwardDiff()) <: AbstractADCorrectnessTestSetting
+
+Test correctness by comparing it against the result obtained with `adtype`.
+
+`adtype` defaults to ForwardDiff.jl, since it's the default AD backend used in
+Turing.jl.
+"""
+struct WithBackend{AD<:AbstractADType} <: AbstractADCorrectnessTestSetting
+    adtype::AD
+end
+WithBackend() = WithBackend(AutoForwardDiff())

Review comment (on lines +27 to +28): Have we considered using FiniteDifferences.jl instead?

+"""
+    WithExpectedResult(
+        value::T,
+        grad::AbstractVector{T}
+    ) where {T<:AbstractFloat} <: AbstractADCorrectnessTestSetting
+
+Test correctness by comparing it against a known result (e.g. one obtained
+analytically, or one obtained with a different backend previously). Both the
+value of the primal (i.e. the log-density) as well as its gradient must be
+supplied.
+"""
+struct WithExpectedResult{T<:AbstractFloat} <: AbstractADCorrectnessTestSetting
+    value::T
+    grad::AbstractVector{T}
+end
+
+"""
+    NoTest() <: AbstractADCorrectnessTestSetting
+
+Disable correctness testing.
+"""
+struct NoTest <: AbstractADCorrectnessTestSetting end

 """
     ADIncorrectException{T<:AbstractFloat}
@@ -45,17 +74,18 @@
 end

 """
-    ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat}
+    ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat,Ttol<:AbstractFloat}

 Data structure to store the results of the AD correctness test.

 The type parameter `Tparams` is the numeric type of the parameters passed in;
-`Tresult` is the type of the value and the gradient.
+`Tresult` is the type of the value and the gradient; and `Ttol` is the type of the
+absolute and relative tolerances used for correctness testing.

 # Fields
 $(TYPEDFIELDS)
 """
-struct ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat}
+struct ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat,Ttol<:AbstractFloat}
     "The DynamicPPL model that was tested"
     model::Model
     "The VarInfo that was used"
@@ -64,18 +94,18 @@
     params::Vector{Tparams}
     "The AD backend that was tested"
    adtype::AbstractADType
-    "The absolute tolerance for the value of logp"
-    value_atol::Tresult
-    "The absolute tolerance for the gradient of logp"
-    grad_atol::Tresult
+    "Absolute tolerance used for correctness test"
+    atol::Ttol
+    "Relative tolerance used for correctness test"
+    rtol::Ttol
     "The expected value of logp"
     value_expected::Union{Nothing,Tresult}
     "The expected gradient of logp"
     grad_expected::Union{Nothing,Vector{Tresult}}
     "The value of logp (calculated using `adtype`)"
-    value_actual::Union{Nothing,Tresult}
+    value_actual::Tresult
     "The gradient of logp (calculated using `adtype`)"
-    grad_actual::Union{Nothing,Vector{Tresult}}
+    grad_actual::Vector{Tresult}
     "If benchmarking was requested, the time taken by the AD backend to calculate the gradient of logp, divided by the time taken to evaluate logp itself"
     time_vs_primal::Union{Nothing,Tresult}
 end
@@ -84,14 +114,12 @@
     run_ad(
         model::Model,
         adtype::ADTypes.AbstractADType;
-        test=true,
+        test::Union{AbstractADCorrectnessTestSetting,Bool}=WithBackend(),
         benchmark=false,
-        value_atol=1e-6,
-        grad_atol=1e-6,
+        atol::AbstractFloat=1e-8,
+        rtol::AbstractFloat=sqrt(eps()),
         varinfo::AbstractVarInfo=link(VarInfo(model), model),
         params::Union{Nothing,Vector{<:AbstractFloat}}=nothing,
-        reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE,
-        expected_value_and_grad::Union{Nothing,Tuple{AbstractFloat,Vector{<:AbstractFloat}}}=nothing,
         verbose=true,
     )::ADResult

@@ -133,8 +161,8 @@

 Note that if the VarInfo is not specified (and thus automatically generated)
 the parameters in it will have been sampled from the prior of the model. If
-you want to seed the parameter generation, the easiest way is to pass a
-`rng` argument to the VarInfo constructor (i.e. do `VarInfo(rng, model)`).
+you want to seed the parameter generation for the VarInfo, you can pass the
+`rng` keyword argument, which will then be used to create the VarInfo.

 Finally, note that these only reflect the parameters used for _evaluating_
 the gradient. If you also want to control the parameters used for
@@ -143,25 +171,35 @@
    prep_params)`. You could then evaluate the gradient at a different set of
    parameters using the `params` keyword argument.

-3. _How to specify the results to compare against._ (Only if `test=true`.)
+3. _How to specify the results to compare against._

    Once logp and its gradient has been calculated with the specified `adtype`,
-   it must be tested for correctness.
+   it can optionally be tested for correctness. The exact way this is tested
+   is specified in the `test` parameter.

+   There are several options for this:

-   This can be done either by specifying `reference_adtype`, in which case logp
-   and its gradient will also be calculated with this reference in order to
-   obtain the ground truth; or by using `expected_value_and_grad`, which is a
-   tuple of `(logp, gradient)` that the calculated values must match. The
-   latter is useful if you are testing multiple AD backends and want to avoid
-   recalculating the ground truth multiple times.
+   - You can explicitly specify the correct value using
+     [`WithExpectedResult()`](@ref).
+   - You can compare against the result obtained with a different AD backend
+     using [`WithBackend(adtype)`](@ref).
+   - You can disable testing by passing [`NoTest()`](@ref).
+   - The default is to compare against the result obtained with ForwardDiff,
+     i.e. `WithBackend(AutoForwardDiff())`.
+   - `test=false` and `test=true` are synonyms for
+     `NoTest()` and `WithBackend(AutoForwardDiff())`, respectively.

-   The default reference backend is ForwardDiff. If none of these parameters are
-   specified, ForwardDiff will be used to calculate the ground truth.
+4. _How to specify the tolerances._ (Only if testing is enabled.)

-4. _How to specify the tolerances._ (Only if `test=true`.)
+   Both absolute and relative tolerances can be specified using the `atol` and
+   `rtol` keyword arguments respectively. The behaviour of these is similar to
+   `isapprox()`, i.e. the value and gradient are considered correct if either
+   atol or rtol is satisfied. The default values are `100*eps()` for `atol` and
+   `sqrt(eps())` for `rtol`.

-   The tolerances for the value and gradient can be set using `value_atol` and
-   `grad_atol`. These default to 1e-6.
+   For the most part, it is the `rtol` check that is more meaningful, because
+   we cannot know the magnitude of logp and its gradient a priori. The `atol`
+   value is supplied to handle the case where gradients are equal to zero.

 5. _Whether to output extra logging information._

@@ -180,48 +218,58 @@
 function run_ad(
     model::Model,
     adtype::AbstractADType;
-    test::Bool=true,
+    test::Union{AbstractADCorrectnessTestSetting,Bool}=WithBackend(),
     benchmark::Bool=false,
-    value_atol::AbstractFloat=1e-6,
-    grad_atol::AbstractFloat=1e-6,
-    varinfo::AbstractVarInfo=link(VarInfo(model), model),
+    atol::AbstractFloat=100 * eps(),
+    rtol::AbstractFloat=sqrt(eps()),
+    rng::AbstractRNG=default_rng(),
+    varinfo::AbstractVarInfo=link(VarInfo(rng, model), model),
     params::Union{Nothing,Vector{<:AbstractFloat}}=nothing,
-    reference_adtype::AbstractADType=REFERENCE_ADTYPE,
-    expected_value_and_grad::Union{Nothing,Tuple{AbstractFloat,Vector{<:AbstractFloat}}}=nothing,
     verbose=true,
 )::ADResult
+    # Convert Boolean `test` to an AbstractADCorrectnessTestSetting
+    if test isa Bool
+        test = test ? WithBackend() : NoTest()
+    end
+
     # Extract parameters
     if isnothing(params)
         params = varinfo[:]
     end
     params = map(identity, params) # Concretise

     # Calculate log-density and gradient with the backend of interest
     verbose && @info "Running AD on $(model.f) with $(adtype)\n"
     verbose && println(" params : $(params)")
     ldf = LogDensityFunction(model, varinfo; adtype=adtype)

     value, grad = logdensity_and_gradient(ldf, params)
     # collect(): https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754
     grad = collect(grad)
     verbose && println(" actual : $((value, grad))")

-    if test
-        # Calculate ground truth to compare against
-        value_true, grad_true = if expected_value_and_grad === nothing
-            ldf_reference = LogDensityFunction(model, varinfo; adtype=reference_adtype)
-            logdensity_and_gradient(ldf_reference, params)
-        else
-            expected_value_and_grad
-        end
+    # Test correctness
+    if test isa NoTest
+        value_true = nothing
+        grad_true = nothing
+    else
+        # Get the correct result
+        if test isa WithExpectedResult
+            value_true = test.value
+            grad_true = test.grad
+        elseif test isa WithBackend
+            ldf_reference = LogDensityFunction(model, varinfo; adtype=test.adtype)
+            value_true, grad_true = logdensity_and_gradient(ldf_reference, params)
+            # collect(): https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754
+            grad_true = collect(grad_true)
+        end
+        # Perform testing
         verbose && println(" expected : $((value_true, grad_true))")
-        grad_true = collect(grad_true)

         exc() = throw(ADIncorrectException(value, value_true, grad, grad_true))
-        isapprox(value, value_true; atol=value_atol) || exc()
-        isapprox(grad, grad_true; atol=grad_atol) || exc()
-    else
-        value_true = nothing
-        grad_true = nothing
+        isapprox(value, value_true; atol=atol, rtol=rtol) || exc()
+        isapprox(grad, grad_true; atol=atol, rtol=rtol) || exc()
     end

     # Benchmark
     time_vs_primal = if benchmark
         primal_benchmark = @be (ldf, params) logdensity(_[1], _[2])
         grad_benchmark = @be (ldf, params) logdensity_and_gradient(_[1], _[2])

Review comment (on lines +251 to +253, the `NoTest` branch): Do these values actually ever get used?
@@ -237,8 +285,8 @@
         varinfo,
         params,
         adtype,
-        value_atol,
-        grad_atol,
+        atol,
+        rtol,
         value_true,
         grad_true,
         value,
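Taken together, the diff above changes how callers request correctness testing. The sketch below shows the new surface area; it is illustrative only, not part of this PR: the `DynamicPPL.TestUtils.AD` module path and the `demo` model are my assumptions, and ReverseDiff.jl must be loaded for `AutoReverseDiff()` to be usable as the backend under test.

```julia
using ADTypes: AutoReverseDiff
using Distributions: Normal
using DynamicPPL
using DynamicPPL.TestUtils.AD: run_ad, WithBackend, WithExpectedResult, NoTest
using Random: Xoshiro
import ReverseDiff

# A hypothetical model to test AD against.
@model function demo()
    x ~ Normal()
end
model = demo()

# Default: compare against ForwardDiff, i.e. test=WithBackend(AutoForwardDiff()).
# The new `rng` keyword seeds the parameter generation for the VarInfo.
result = run_ad(model, AutoReverseDiff(); rng=Xoshiro(468))

# Reuse a known-good result as the ground truth (avoids recomputing it when
# testing several backends), and set explicit tolerances.
run_ad(
    model,
    AutoReverseDiff();
    test=WithExpectedResult(result.value_actual, result.grad_actual),
    atol=1e-6,
    rtol=1e-6,
)

# Disable correctness testing entirely and only benchmark.
run_ad(model, AutoReverseDiff(); test=NoTest(), benchmark=true)
```

An `ADIncorrectException` is thrown if either the value or the gradient fails the `isapprox` check against the ground truth.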