Simplify arithmetic of `NotImplemented` and treat `NoTangent` like `ZeroTangent` #477
Codecov Report
@@ Coverage Diff @@
## master #477 +/- ##
==========================================
+ Coverage 92.89% 92.92% +0.02%
==========================================
Files 15 15
Lines 816 791 -25
==========================================
- Hits 758 735 -23
+ Misses 58 56 -2
@test_throws E @thunk(x^2) - ni
for a in (rand(), NoTangent(), ZeroTangent(), true, @thunk(x^2))
    @test_throws E ni - a
    @test_throws E a - ni
Seems funky that adding is fine but subtracting throws.
Like I guess it's fine: adding tangents is what happens in AD, subtracting is not.
Subtracting happens in gradient descent.
OK, I am convinced this is the behaviour we want.
Do we have a comment saying as much?
I'll add a comment.
I think once this PR is merged, we should try to make use of the functionality in ChainRulesTestUtils. Currently, since it can't be inferred automatically which argument caused the ...
This PR simplifies the arithmetic of `NotImplemented`. Mainly, it allows more operations that currently throw errors, and it treats `NoTangent` like `ZeroTangent`. The basic rules in this PR are:

- `NotImplemented` always wins `+`
- `NotImplemented` loses `*` and `dot` against `NoTangent` and `ZeroTangent` (which allows ignoring non-implemented partial derivatives) and wins `*` and `dot` against other types
- the `+`, `*`, and `dot` definitions for `NotImplemented` and a value of a different type are commutative

The main motivation for this PR is to be able to test rules such as JuliaMath/SpecialFunctions.jl#350 (comment) more easily with something like
(I checked that it works in my local branch with this PR). Without this PR, one has to specify a `ZeroTangent` to test the non-`NotImplemented` derivative (without this PR, a `NotImplemented` is returned in all other cases), but `test_frule(besselix, nu ⊢ ZeroTangent(), x)` will include `nu` in the finite differencing, which causes test errors due to incorrect values and mismatching dimensions (if `x` is a complex number). Possibly (some of) these issues could be fixed in some other way, in ChainRulesTestUtils or FiniteDifferences. However, I think it makes sense to treat `NoTangent()` and `ZeroTangent()` in the same way when multiplying them with a `NotImplemented`, and hence the finite differencing problems can be avoided easily by marking `nu` as non-differentiable with `test_frule(besselix, nu ⊢ NoTangent(), x)`.
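
For readers who want to see the three arithmetic rules from the description in isolation, here is a minimal sketch. It uses hypothetical stand-in types (`ToyNotImplemented`, `ToyZeroTangent`, `ToyNoTangent`) instead of the real ChainRulesCore types, so it is not the PR's implementation. Restricting "other types" to `Number`, returning the zero-like operand from `*` and `dot`, and throwing a plain error on subtraction are all assumptions made for illustration.

```julia
import LinearAlgebra: dot

# Hypothetical stand-ins for illustration only; NOT the ChainRulesCore types.
struct ToyNotImplemented
    info::String
end
struct ToyZeroTangent end
struct ToyNoTangent end

# Assumption: both "zero-like" tangents are treated identically.
const ZeroLike = Union{ToyZeroTangent,ToyNoTangent}
# Assumption: "other types" are modelled as plain numbers.
const Other = Union{Number,ZeroLike}

# Rule 1: the not-implemented marker always wins `+`, in either argument order.
Base.:+(x::ToyNotImplemented, ::ToyNotImplemented) = x
Base.:+(x::ToyNotImplemented, ::Other) = x
Base.:+(::Other, x::ToyNotImplemented) = x

# Rule 2a: it loses `*` and `dot` against zero-like tangents, so a missing
# partial derivative is ignored when it is multiplied by a zero-like input tangent.
Base.:*(::ToyNotImplemented, z::ZeroLike) = z
Base.:*(z::ZeroLike, ::ToyNotImplemented) = z
dot(::ToyNotImplemented, z::ZeroLike) = z
dot(z::ZeroLike, ::ToyNotImplemented) = z

# Rule 2b: it wins `*` and `dot` against other types (here: numbers).
Base.:*(x::ToyNotImplemented, ::Number) = x
Base.:*(::Number, x::ToyNotImplemented) = x
dot(x::ToyNotImplemented, ::Number) = x
dot(::Number, x::ToyNotImplemented) = x

# Subtraction still throws, mirroring the `@test_throws` lines in the review.
# Assumption: a plain `error` stands in for whatever exception the package uses.
Base.:-(x::ToyNotImplemented, ::Other) = error("not implemented: $(x.info)")
Base.:-(::Other, x::ToyNotImplemented) = error("not implemented: $(x.info)")

ni = ToyNotImplemented("derivative with respect to nu")

ni + 1.0                 # the marker propagates through addition
2.0 * ni                 # and through multiplication by a number
ni * ToyNoTangent()      # but is dropped against a zero-like tangent
dot(ToyZeroTangent(), ni)
```

Rule 3 (commutativity) shows up as the paired definitions with the arguments swapped. In this sketch, marking an input with `ToyNoTangent()` has the same effect on the product as `ToyZeroTangent()`, which is the behaviour the description argues for.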