The conclusion of the discussion spread across #343, JuliaDiff/ChainRules.jl#337, #347, and JuliaDiff/ChainRules.jl#232 (and possibly elsewhere) has been to prefer writing abstractly typed rules.
Two shortcomings of this approach are that
- In some cases, such as when the primal function is specialised for a subtype (e.g. `Diagonal<:AbstractArray`), the fallback rrule returning a `Matrix` is not correct. Solving that is the topic of Projecting Cotangents #286
- In other cases, using AD to differentiate through the specialised method is faster than falling back on the generic rule. For these, we need a way to opt out of using the generic fallback rule, and let AD compute the backward pass. Solving that is the topic of this issue.
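To make the first shortcoming concrete, here is a minimal sketch. `toy_rrule` is a hypothetical stand-in written only for illustration; it is not the actual ChainRules.jl rule for `sum`, and the real rule may behave differently.

```julia
using LinearAlgebra

# Hypothetical abstractly typed rule for `sum` (illustration only).
# The pullback builds a dense array regardless of the input's structure.
function toy_rrule(::typeof(sum), x::AbstractArray)
    sum_pullback(dy) = (nothing, fill(dy, size(x)))
    return sum(x), sum_pullback
end

D = Diagonal([1.0, 2.0, 3.0])
y, pb = toy_rrule(sum, D)
_, dD = pb(1.0)

dD isa Matrix   # true: the `Diagonal` structure is lost
dD[1, 2]        # 1.0, even though `D[1, 2]` is structurally zero
```

The off-diagonal entries of `dD` are nonzero, so the returned cotangent does not respect the structure of the `Diagonal` input.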
There are a number of possible ways we could facilitate opting-out:
- Option A) Return `nothing`. This has been tried in Take nothing seriously FluxML/Zygote.jl#967 but advised against in Use ChainRules RuleConfig FluxML/Zygote.jl#990 (comment) for Zygote. According to @oxinabox this would work in Diffractor, but not in Nabla? (please just edit this if that's wrong)
- Option B) Have a `no_rrule` function that would be used to designate that a particular signature does not have a rule.
- Option C) Have a `@no_rrule` macro which would take a signature and generate code which returns `nothing` (as above), and automatically add the method to an internal list, which would be accessed by Zygote and Nabla.
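A rough sketch of how Options A and B might look from the rule author's and the AD system's side. Everything here is an assumption made for illustration: `toy_rrule`, `toy_no_rrule`, `use_rule_A` and `use_rule_B` are hypothetical names, not the ChainRulesCore API, and a real AD system may well check for rules differently.

```julia
using LinearAlgebra

# Hypothetical stand-in for `rrule`: the catch-all returns `nothing`,
# meaning "no rule, differentiate through the primal instead".
toy_rrule(f, args...) = nothing

# Abstractly typed rule (illustration only).
function toy_rrule(::typeof(sum), x::AbstractArray)
    sum_pullback(dy) = (nothing, fill(dy, size(x)))
    return sum(x), sum_pullback
end

# Option A: a more specific method opts out by returning `nothing`.
toy_rrule(::typeof(sum), x::Diagonal) = nothing

# Option B: a separate function whose methods mark opted-out signatures.
function toy_no_rrule end
toy_no_rrule(::typeof(sum), ::Diagonal) = nothing

# How an AD system might decide whether to use the rule (assumed logic).
use_rule_A(f, args...) = toy_rrule(f, args...) !== nothing
use_rule_B(f, args...) = !applicable(toy_no_rrule, f, args...)

D = Diagonal([1.0, 2.0])
M = [1.0 2.0; 3.0 4.0]
use_rule_A(sum, M), use_rule_A(sum, D)   # (true, false)
use_rule_B(sum, M), use_rule_B(sum, D)   # (true, false)
```

In this sketch the Option A check has to call the rule, and so runs the primal, before it can see the `nothing`, whereas the Option B check only inspects the method table. Option C would generate both kinds of method from a single signature.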
Have I missed another option we've discussed?