Skip to content

Simplify arithmetic of NotImplemented and treat NoTangent like ZeroTangent #477

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Oct 3, 2021

Conversation

devmotion
Copy link
Member

This PR simplifies the arithmetic of NotImplemented. Mainly, it allows more operations that currently throw errors and it treats NoTangent like ZeroTangent. The basic rules in this PR are:

  • NotImplemented always wins +
  • NotImplemented loses * and dot against NoTangent and ZeroTangent (allows to ignore non-implemented partial derivatives) and wins * and dot against other types
  • +, *, and dot definitions for NotImplemented and a value of a different type are commutative

The main motivation for this PR is to be able to test rules such as JuliaMath/SpecialFunctions.jl#350 (comment) more easily with something like

                test_frule(besselix, nu, x) # derivative is `NotImplemented`
                test_frule(besselix, nu  NoTangent(), x) # derivative is a number

(I checked that it works in my local branch with this PR). Without this PR, one has to specify a ZeroTangent to test the non-NotImplemented derivative (without this PR, in all other cases a NotImplemented is returned) - but test_frule(besselix, nu ⊢ ZeroTangent(), x) will include nu in the finite differencing which causes test errors due to incorrect values and mismatching dimensions (if x is a complex number). Possibly (some of) these issues could be fixed in some other way, in ChainRulesTestutils or FiniteDifferences. However, I think it makes sense to treat NoTangent() and ZeroTangent() in the same way when multiplying them with a NotImplemented and hence the finite differencing problems can be avoided easily by marking nu as non-differentiable with test_frule(besselix, nu ⊢ NoTangent(), x).

@codecov-commenter
Copy link

codecov-commenter commented Oct 2, 2021

Codecov Report

Merging #477 (ea5c889) into master (d58e420) will increase coverage by 0.02%.
The diff coverage is 100.00%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master     #477      +/-   ##
==========================================
+ Coverage   92.89%   92.92%   +0.02%     
==========================================
  Files          15       15              
  Lines         816      791      -25     
==========================================
- Hits          758      735      -23     
+ Misses         58       56       -2     
Impacted Files Coverage Δ
src/tangent_arithmetic.jl 96.42% <100.00%> (+0.09%) ⬆️
src/tangent_types/abstract_tangent.jl 100.00% <0.00%> (+50.00%) ⬆️

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update d58e420...ea5c889. Read the comment docs.

@test_throws E @thunk(x^2) - ni
for a in (rand(), NoTangent(), ZeroTangent(), true, @thunk(x^2))
@test_throws E ni - a
@test_throws E a - ni
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems funky that adding is fine but subtracting throws.

Like I guess it's fine -- adding tangents is what happens in AD, subtracting vis not.

Subtracting happen in gradient descent.

Ok I am convinced this is the behaviour we want.
Do we have a comment saying as such?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll add a comment.

@devmotion
Copy link
Member Author

I think once this PR is merged, we should try to make use of the functionality in ChainRulesTestutils. Currently, test_frule is pretty useless if one of the partial derivatives is a NotImplemented since the derivative will always be a NotImplemented and hence tests always pass (even though marked as broken), regardless of the other derivatives. With this PR it is possible to test the other partial derivatives by setting the tangents of the arguments for which the partial derivative is a NotImplemented to NoTangent() to test these partial derivatives properly. However, currently one has to do this manually, as in https://github.com/JuliaMath/SpecialFunctions.jl/blob/0af956882245e3b07340002c7c95c319e51af52a/test/chainrules.jl#L56-L57. It would be nice if these proper tests would be performed automatically.

Since it can't be inferred automatically which argument caused the NotImplemented derivative one approach might be to force users to always specify a NoTangent() tangent in the frule tests if the partial derivative is NotImplemented (if this is not already the default rand_tangent, in which case the partial derivative probably should just be NoTangent()), and basically not allow NotImplemented derivatives in the tests. A problem might be that this hides the fact that the implementation is broken and should be fixed in the frule tests. Maybe it would be better to let users specify a tangent of type NotImplemented (e.g. with @not_implemented()) and then use a NoTangent() internally but mark the tests as broken?

@devmotion devmotion merged commit 834901e into master Oct 3, 2021
@devmotion devmotion deleted the dw/notimplemented_arithmetic branch October 3, 2021 10:06
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants