-
Notifications
You must be signed in to change notification settings - Fork 35
Improve API for AD testing #964
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: breaking
Are you sure you want to change the base?
Conversation
Benchmark Report for Commit be36626Computer Information
Benchmark Results
|
bda0ea4
to
4ce84c2
Compare
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## breaking #964 +/- ##
============================================
- Coverage 82.85% 82.67% -0.18%
============================================
Files 38 38
Lines 4031 4018 -13
============================================
- Hits 3340 3322 -18
- Misses 691 696 +5 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
DynamicPPL.jl documentation for PR #964 is available at: |
20f4d56
to
3587ce5
Compare
Let me know if this is ready for review. |
Oops, I forgot about this one. Yeah I think it should be |
@@ -211,6 +211,21 @@ To test and/or benchmark the performance of an AD backend on a model, DynamicPPL | |||
|
|||
```@docs | |||
DynamicPPL.TestUtils.AD.run_ad | |||
``` | |||
|
|||
THe default test setting is to compare against ForwardDiff. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
THe default test setting is to compare against ForwardDiff. | |
The default test setting is to compare against ForwardDiff. |
`adtype` defaults to ForwardDiff.jl, since it's the default AD backend used in | ||
Turing.jl. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Have we considered using FiniteDifferences.jl instead?
src/test_utils/ad.jl
Outdated
|
||
The tolerances for the value and gradient can be set using `value_atol` and | ||
`grad_atol`. These default to 1e-6. | ||
Note that gradients are always compared elementwise. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As discussed, this should probably change.
if test isa NoTest | ||
value_true = nothing | ||
grad_true = nothing |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do these values actually ever get used?
This PR makes some fairly overdue improvements to the API of
DynamicPPL.TestUtils.AD.run_ad
.rng
argumentA
rng
keyword argument is provided to make it easier to seed the parameters used for running AD. Previously, if you wanted to make sure that two calls torun_ad
used the same parameters (but you didn't care what parameters they were, just that they were the same!), you had to do:Now you can do:
I made
rng
a keyword argument rather than a positional argument because I considerrng
-as-first-argument to be a multiple dispatch abuse anti-pattern, i.e., it serves no purpose except to force someone to declare a new method.Closes #962
Correctness testing
Previously the
test
,reference_backend
, andexpected_value_and_grad
keyword arguments all served the same purpose and it was not clear when one would supersede the other (e.g. if you puttest=false
,reference_backend=AutoForwardDiff()
, andexpected_value_and_grad=(value, grad)
it was unclear whether it would skip testing, compare against ForwardDiff, or compare against the explicitly specified values).I originally made this design choice to avoid having to make my own types (which would be some boilerplate and annoying for downstream users to import), but over the course of using this function (especially in ADTests) I have found the annoyance of not having a clear API to be bigger than the annoyance of adding some more imports.
This PR fixes it so that you can't specify multiple of these at the same time.
Tolerances
Previously, it was only possible to specify the
atol
used for testing;rtol
could not be configured (and would default to zero, because that's whatisapprox
does when given a non-zero atol). This caused problems such as TuringLang/ADTests#33. This PR fixes it such that testing happens with nonzero atol and rtol (which can both be configured).Closes #963