|
2 | 2 | ##### `frule`/`rrule`
|
3 | 3 | #####
|
4 | 4 |
|
5 |
| -#= |
6 |
| -In some weird ideal sense, the fallback for e.g. `frule` should actually be "get |
7 |
| -the derivative via forward-mode AD". This is necessary to enable mixed-mode |
8 |
| -rules, where e.g. `frule` is used within a `rrule` definition. For example, |
9 |
| -broadcasted functions may not themselves be forward-mode *primitives*, but are |
10 |
| -often forward-mode *differentiable*. |
11 |
| -
|
12 |
| -ChainRulesCore, by design, is decoupled from any specific AD implementation. How, |
13 |
| -then, do we know which AD to fall back to when there isn't a primitive defined? |
14 |
| -
|
15 |
| -Well, if you're a greedy AD implementation, you can just overload `frule` and/or |
16 |
| -`rrule` to use your AD directly. However, this won't play nice with other AD |
17 |
| -packages doing the same thing, and thus could cause load-order-dependent |
18 |
| -problems for downstream users. |
19 |
| -
|
20 |
| -It turns out, Cassette solves this problem nicely by allowing AD authors to |
21 |
| -overload the fallbacks w.r.t. their own context. Example using ForwardDiff: |
22 |
| -
|
23 |
| -``` |
24 |
| -using ChainRulesCore, ForwardDiff, Cassette |
25 |
| -
|
26 |
| -Cassette.@context MyChainRuleCtx |
27 |
| -
|
28 |
| -# ForwardDiff, itself, can call `my_frule` instead of |
29 |
| -# `frule` to utilize the ForwardDiff-injected ChainRulesCore |
30 |
| -# infrastructure |
31 |
| -my_frule(args...) = Cassette.overdub(MyChainRuleCtx(), frule, args...) |
32 |
| -
|
33 |
| -function Cassette.execute(::MyChainRuleCtx, ::typeof(frule), f, x::Number) |
34 |
| - r = frule(f, x) |
35 |
| - if isa(r, Nothing) |
36 |
| - fx, df = (f(x), (_, Δx) -> ForwardDiff.derivative(f, x) * Δx) |
37 |
| - else |
38 |
| - fx, df = r |
39 |
| - end |
40 |
| - return fx, df |
41 |
| -end |
42 |
| -``` |
43 |
| -=# |
44 |
| - |
45 | 5 | """
|
46 | 6 | frule(f, x...)
|
47 | 7 |
|
|
0 commit comments