Description
At present, Zygote will use forward-mode AD (outsourced to ForwardDiff) under two circumstances:
- Based on a heuristic for broadcasting
- Upon an explicit call to `Zygote.forwarddiff`
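For reference, the explicit route can be sketched as follows. This is a minimal illustration, not code from Zygote's own test suite; `quartic` is a hypothetical function introduced only for the example.

```julia
using Zygote

# Hypothetical scalar function, used only for illustration.
quartic(x) = x^4

# Zygote still drives the outer reverse pass, but the call wrapped in
# Zygote.forwarddiff is differentiated by ForwardDiff rather than traced.
grad = gradient(5.0) do x
    Zygote.forwarddiff(quartic, x)
end

# grad is a 1-tuple; analytically the derivative is 4 * 5^3 = 500.
```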
As shown by https://github.com/oschulz/ForwardDiffPullbacks.jl, there are a number of cases where being able to make an rrule actually run forward mode AD would be a boon for performance. One particularly salient example from Flux would be RNN pointwise broadcasts, which are currently unfused by Zygote for a massive compute + memory penalty. However, given we are simultaneously moving away from using Zygote-specific APIs downstream, defining `rrule(::typeof(pointwise_op), xs...) = Zygote.forwarddiff(...)` is a non-starter. Hence, my proposal is to expose the standard `frule_via_ad` so that downstream code can remain AD agnostic. Under the hood, this would work much the same as `Zygote.forwarddiff` or `ForwardDiffPullbacks.fwddiff` do now. It may even be possible to share some implementation details with one of those functions.
Note that this is not a request to make `frule_via_ad` differentiable in reverse mode. Users would still be responsible for writing their own rrules, but one could imagine swapping out Zygote for Diffractor (which already implements `frule_via_ad`) without making any code changes. Guarding on `RuleConfig{>:HasForwardsMode}` would be enough to ensure compatibility with ADs which do not support forward mode.
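To make the proposal concrete, here is a rough sketch of what such a downstream rule might look like. It assumes only the ChainRulesCore API (`RuleConfig`, `HasForwardsMode`, `frule_via_ad`); the scalar `pointwise_op` is a hypothetical stand-in, and a real rule for broadcasts would need to handle arrays and avoid the redundant forward passes.

```julia
using ChainRulesCore

# Hypothetical pointwise operation standing in for an RNN cell's elementwise math.
pointwise_op(x, y) = tanh(x) * y

# This method is only selected when the calling AD advertises forward-mode support.
function ChainRulesCore.rrule(
    config::RuleConfig{>:HasForwardsMode}, ::typeof(pointwise_op), x::Real, y::Real
)
    # One forward pass per input. The leading NoTangent() is the tangent
    # of the function object itself.
    z, dz_dx = frule_via_ad(config, (NoTangent(), one(x), zero(y)), pointwise_op, x, y)
    _, dz_dy = frule_via_ad(config, (NoTangent(), zero(x), one(y)), pointwise_op, x, y)

    # Real scalars only, so the pullback is a plain product with each partial.
    pointwise_op_pullback(z̄) = (NoTangent(), z̄ * dz_dx, z̄ * dz_dy)
    return z, pointwise_op_pullback
end
```

An AD whose config does not include `HasForwardsMode` never matches this method, so it would differentiate `pointwise_op` through its usual reverse-mode path.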