
Implement frule_via_ad? #1222

@ToucheSir


At present, Zygote uses forward-mode AD (outsourced to ForwardDiff) in two circumstances:

  1. Based on a heuristic for broadcasting
  2. Upon an explicit call to Zygote.forwarddiff
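For reference, the explicit route (case 2) looks roughly like the sketch below. It assumes only that Zygote is installed; f and x are made-up illustrative values, not taken from Flux.

```julia
using Zygote

# Hypothetical elementwise function, standing in for a pointwise op.
f(x) = tanh.(x) .* x

x = [0.5, -1.0, 2.0]

# Zygote.forwarddiff(f, x) returns f(x) unchanged, but tells Zygote to
# differentiate f using ForwardDiff (forward mode) instead of tracing it
# with its usual reverse mode.
g, = Zygote.gradient(x -> sum(Zygote.forwarddiff(f, x)), x)

# g holds d/dx sum(tanh.(x) .* x), i.e. (1 - tanh(x)^2) * x + tanh(x) elementwise.
```

The catch, as described next, is that this call ties downstream rules to a Zygote-specific API.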

As shown by https://github.com/oschulz/ForwardDiffPullbacks.jl, there are a number of cases where letting an rrule run forward-mode AD internally would significantly improve performance. One particularly salient example from Flux is RNN pointwise broadcasts, which Zygote currently leaves unfused, at a large compute and memory cost. However, since we are simultaneously moving away from Zygote-specific APIs downstream, defining rrule(::typeof(pointwise_op), xs...) = Zygote.forwarddiff(...) is a non-starter.

I therefore propose that Zygote expose the standard ChainRulesCore frule_via_ad, so that downstream code can remain AD-agnostic. Under the hood, this would work much like Zygote.forwarddiff or ForwardDiffPullbacks.fwddiff do now, and it may even be possible to share some implementation details with one of those functions.

Note that this is not a request to make frule_via_ad differentiable in reverse mode. Users would still be responsible for writing their own rrules, but one could imagine swapping Zygote for Diffractor (which already implements frule_via_ad) without making any code changes. Restricting such rrules to RuleConfig{>:HasForwardsMode} would be enough to ensure compatibility with ADs that do not support forward mode.
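To make the proposal concrete, here is a minimal sketch of the kind of AD-agnostic rrule this would enable, written against ChainRulesCore. pointwise_op and ToyForwardConfig are hypothetical names. ToyForwardConfig is a toy RuleConfig whose frule_via_ad is backed by ForwardDiff, standing in for what Zygote would provide if this issue were implemented (Diffractor's config could take its place). It handles only real scalar arguments and ignores the tangent of f itself.

```julia
using ChainRulesCore, ForwardDiff

# Hypothetical scalar pointwise op, e.g. one elementwise step of an RNN cell.
pointwise_op(x, y) = tanh(x) * y

# Downstream rrule: only matches ADs whose config advertises forward mode.
# ADs without it never hit this method and differentiate pointwise_op themselves.
function ChainRulesCore.rrule(config::RuleConfig{>:HasForwardsMode},
                              ::typeof(pointwise_op), x::Real, y::Real)
    # One forward-mode pass per input: seed a unit tangent on that input only.
    Ω, ∂x = frule_via_ad(config, (NoTangent(), one(x), zero(y)), pointwise_op, x, y)
    _, ∂y = frule_via_ad(config, (NoTangent(), zero(x), one(y)), pointwise_op, x, y)
    function pointwise_op_pullback(Ω̄)
        Δ = unthunk(Ω̄)
        return NoTangent(), Δ * ∂x, Δ * ∂y
    end
    return Ω, pointwise_op_pullback
end

# --- Toy AD side (hypothetical): what an AD with forward mode would supply. ---
struct ToyForwardConfig <: RuleConfig{Union{HasForwardsMode,NoReverseMode}} end

function ChainRulesCore.frule_via_ad(::ToyForwardConfig, ȧrgs, f, args::Real...)
    ẋs = Base.tail(ȧrgs)  # drop the tangent of f itself (assumed to have no fields)
    # Directional derivative of f at args along ẋs, via a ForwardDiff dual number.
    g(t) = f(map((a, ẋ) -> a + t * ẋ, args, ẋs)...)
    return f(args...), ForwardDiff.derivative(g, 0.0)
end

# Usage: the pullback returns (NoTangent(), x̄, ȳ).
Ω, back = rrule(ToyForwardConfig(), pointwise_op, 0.5, 2.0)
_, x̄, ȳ = back(1.0)  # x̄ ≈ (1 - tanh(0.5)^2) * 2.0, ȳ ≈ tanh(0.5)
```

Seeding one input at a time costs one forward pass per argument; a real implementation would more likely push all seeds through a single chunked pass, as ForwardDiff does for gradients.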
