#=
In some weird ideal sense, the fallback for e.g. `frule` should actually be "get
- the derivative via forward-mode AD". This is necessary to enable mixed-mode
+ the derivative via forward-ode AD". This is necessary to enable mixed-mode
rules, where e.g. `frule` is used within a `rrule` definition. For example,
broadcasted functions may not themselves be forward-mode *primitives*, but are
often forward-mode *differentiable*.
@@ -33,7 +33,7 @@ my_frule(args...) = Cassette.overdub(MyChainRuleCtx(), frule, args...)
function Cassette.execute(::MyChainRuleCtx, ::typeof(frule), f, x::Number)
r = frule(f, x)
if isa(r, Nothing)
- fx, df = (f(x), Rule(Δx -> ForwardDiff.derivative(f, x) * Δx))
+ fx, df = (f(x), (_, Δx) -> ForwardDiff.derivative(f, x) * Δx)
else
fx, df = r
end
Expressing `x` as the tuple `(x₁, x₂, ...)` and the output tuple of `f(x...)`
as `Ω`, return the tuple:

- (Ω, (rule_for_ΔΩ₁::AbstractRule, rule_for_ΔΩ₂::AbstractRule, ...))
+ (Ω, (ṡelf, ẋ₁, ẋ₂, ...) -> Ω̇₁, Ω̇₂, ...)

- where each returned propagation rule `rule_for_ΔΩᵢ` can be invoked as
+ The second return value is the propagation rule, or the pushforward.
+ It takes in differentials corresponding to the inputs (`ẋ₁, ẋ₂, ...`)
+ and `ṡelf`, the internal values of the function (for closures).

- rule_for_ΔΩᵢ(Δx₁, Δx₂, ...)
-
- to yield `Ωᵢ`'s corresponding differential `ΔΩᵢ`. To illustrate, if all involved
- values are real-valued scalars, this differential can be written as:
-
- ΔΩᵢ = ∂Ωᵢ_∂x₁ * Δx₁ + ∂Ωᵢ_∂x₂ * Δx₂ + ...

If no method matching `frule(f, xs...)` has been defined, then return `nothing`.
@@ -68,12 +64,12 @@ unary input, unary output scalar function:
```
julia> x = rand();

- julia> sinx, dsin = frule(sin, x);
+ julia> sinx, sin_pushforward = frule(sin, x);

julia> sinx == sin(x)
true

- julia> dsin(1) == cos(x)
+ julia> sin_pushforward(NamedTuple(), 1) == cos(x)
true
```
@@ -82,19 +78,16 @@ unary input, binary output scalar function:
```
julia> x = rand();

- julia> sincosx, (dsin, dcos) = frule(sincos, x);
+ julia> sincosx, sincos_pushforward = frule(sincos, x);

julia> sincosx == sincos(x)
true

- julia> dsin(1) == cos(x)
- true
-
- julia> dcos(1) == -sin(x)
+ julia> sincos_pushforward(NamedTuple(), 1) == (cos(x), -sin(x))
true
```

- See also: [`rrule`](@ref), [`AbstractRule`](@ref), [`@scalar_rule`](@ref)
+ See also: [`rrule`](@ref), [`@scalar_rule`](@ref)
"""
frule(::Any, ::Vararg{Any}; kwargs...) = nothing
@@ -104,16 +97,11 @@ frule(::Any, ::Vararg{Any}; kwargs...) = nothing
Expressing `x` as the tuple `(x₁, x₂, ...)` and the output tuple of `f(x...)`
as `Ω`, return the tuple:

- (Ω, (rule_for_Δx₁::AbstractRule, rule_for_Δx₂::AbstractRule, ...))
-
- where each returned propagation rule `rule_for_Δxᵢ` can be invoked as
+ (Ω, (Ω̄₁, Ω̄₂, ...) -> (s̄elf, x̄₁, x̄₂, ...))

- rule_for_Δxᵢ(ΔΩ₁, ΔΩ₂, ...)
-
- to yield `xᵢ`'s corresponding differential `Δxᵢ`. To illustrate, if all involved
- values are real-valued scalars, this differential can be written as:
-
- Δxᵢ = ∂Ω₁_∂xᵢ * ΔΩ₁ + ∂Ω₂_∂xᵢ * ΔΩ₂ + ...
+ The second return value is the propagation rule, or pullback.
+ It takes in differentials corresponding to the outputs (`Ω̄₁, Ω̄₂, ...`) and returns
+ `x̄₁, x̄₂, ...` for the inputs plus `s̄elf` for the function's internal values (closures).

If no method matching `rrule(f, xs...)` has been defined, then return `nothing`.
@@ -124,12 +112,12 @@ unary input, unary output scalar function:
```
julia> x = rand();

- julia> sinx, dx = rrule(sin, x);
+ julia> sinx, sin_pullback = rrule(sin, x);

julia> sinx == sin(x)
true

- julia> dx(1) == cos(x)
+ julia> sin_pullback(1) == (NO_FIELDS, cos(x))
true
```
@@ -138,18 +126,15 @@ binary input, unary output scalar function:
```
julia> x, y = rand(2);

- julia> hypotxy, (dx, dy) = rrule(hypot, x, y);
+ julia> hypotxy, hypot_pullback = rrule(hypot, x, y);

julia> hypotxy == hypot(x, y)
true

- julia> dx(1) == (x / hypot(x, y))
- true
-
- julia> dy(1) == (y / hypot(x, y))
+ julia> hypot_pullback(1) == (NO_FIELDS, (x / hypot(x, y)), (y / hypot(x, y)))
true
```

- See also: [`frule`](@ref), [`AbstractRule`](@ref), [`@scalar_rule`](@ref)
+ See also: [`frule`](@ref), [`@scalar_rule`](@ref)
"""
rrule(::Any, ::Vararg{Any}; kwargs...) = nothing
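Review note: the docstrings in this diff replace the per-output and per-input `AbstractRule` tuples with a single closure per rule. The sketch below shows one way that convention could look in plain Julia, with no package loaded. The names `sketch_frule`, `sketch_rrule`, `NoFields`, and `NO_FIELDS_SKETCH` are hypothetical stand-ins invented for this note, not the package's definitions, and the derivative formulas are written by hand rather than taken from the changed files.

```julia
# Hypothetical stand-in for the `NO_FIELDS` marker that appears in the doctests;
# defined here only so the sketch runs without any package.
struct NoFields end
const NO_FIELDS_SKETCH = NoFields()

# Forward mode: return the primal value and ONE pushforward closure.
# Its first argument is the differential for the function itself, which is
# ignored here because `sincos` has no fields.
function sketch_frule(::typeof(sincos), x::Real)
    sinx, cosx = sincos(x)
    pushforward(_dself, dx) = (cosx * dx, -sinx * dx)
    return (sinx, cosx), pushforward
end

# Reverse mode: return the primal value and ONE pullback closure that maps the
# output differential to a tuple (self differential, input differentials...).
function sketch_rrule(::typeof(hypot), x::Real, y::Real)
    Ω = hypot(x, y)
    pullback(Ωbar) = (NO_FIELDS_SKETCH, Ωbar * x / Ω, Ωbar * y / Ω)
    return Ω, pullback
end

# Fallbacks mirror the diff: no matching rule means `nothing`.
sketch_frule(::Any, ::Vararg{Any}) = nothing
sketch_rrule(::Any, ::Vararg{Any}) = nothing

# Usage: seed the pushforward with dx = 1 and the pullback with Ωbar = 1.
Ω, pushforward = sketch_frule(sincos, 0.0)
dΩ = pushforward(NamedTuple(), 1.0)

h, pullback = sketch_rrule(hypot, 3.0, 4.0)
dself, xbar, ybar = pullback(1.0)
```

Returning one closure means a multi-output function such as `sincos` yields all output differentials from a single call, which matches the consolidated doctest `sincos_pushforward(NamedTuple(), 1) == (cos(x), -sin(x))`.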