Skip to content

Commit 62d3be3

Browse files
committed
remove reference to Rules from docstrings
1 parent 765ecfc commit 62d3be3

File tree

3 files changed

+23
-38
lines changed

3 files changed

+23
-38
lines changed

src/operations.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ storing the result in `Δ`.
1919
Note: this function may not actually store the result in `Δ` if `Δ` is immutable,
2020
so it is best to always call this as `Δ = accumulate!(Δ, ∂)` just in-case.
2121
22-
This function is overloadable by [`InplaceableThunk`s](@ref).
23-
See also: [`accumulate`](@ref), [`store!`](@ref), [`AbstractRule`](@ref)
22+
This function is overloadable by using a [`InplaceThunk`](@ref).
23+
See also: [`accumulate`](@ref), [`store!`](@ref).
2424
"""
2525
function accumulate!(Δ, ∂)
2626
return materialize!(Δ, broadcastable(cast(Δ) + ∂))

src/rule_definition_tools.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ is equivalent to:
151151
152152
For examples, see ChainRulesCore' `rules` directory.
153153
154-
See also: [`frule`](@ref), [`rrule`](@ref), [`AbstractRule`](@ref)
154+
See also: [`frule`](@ref), [`rrule`](@ref).
155155
"""
156156
macro scalar_rule(call, maybe_setup, partials...)
157157
############################################################################

src/rules.jl

Lines changed: 20 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
#=
66
In some weird ideal sense, the fallback for e.g. `frule` should actually be "get
7-
the derivative via forward-mode AD". This is necessary to enable mixed-mode
7+
the derivative via forward-ode AD". This is necessary to enable mixed-mode
88
rules, where e.g. `frule` is used within a `rrule` definition. For example,
99
broadcasted functions may not themselves be forward-mode *primitives*, but are
1010
often forward-mode *differentiable*.
@@ -33,7 +33,7 @@ my_frule(args...) = Cassette.overdub(MyChainRuleCtx(), frule, args...)
3333
function Cassette.execute(::MyChainRuleCtx, ::typeof(frule), f, x::Number)
3434
r = frule(f, x)
3535
if isa(r, Nothing)
36-
fx, df = (f(x), Rule(Δx -> ForwardDiff.derivative(f, x) * Δx))
36+
fx, df = (f(x), (_, Δx) -> ForwardDiff.derivative(f, x) * Δx)
3737
else
3838
fx, df = r
3939
end
@@ -48,16 +48,12 @@ end
4848
Expressing `x` as the tuple `(x₁, x₂, ...)` and the output tuple of `f(x...)`
4949
as `Ω`, return the tuple:
5050
51-
(Ω, (rule_for_ΔΩ₁::AbstractRule, rule_for_ΔΩ₂::AbstractRule, ...))
51+
(Ω, (ṡelf, ẋ₁, ẋ₂, ...) -> Ω̇₁, Ω̇₂, ...)
5252
53-
where each returned propagation rule `rule_for_ΔΩᵢ` can be invoked as
53+
The second return value is the propagation rule, or the pushforward.
54+
It takes in differentials corresponding to the inputs (`ẋ₁, ẋ₂, ...`)
55+
and `ṡelf` the internal values of the function (for closures).
5456
55-
rule_for_ΔΩᵢ(Δx₁, Δx₂, ...)
56-
57-
to yield `Ωᵢ`'s corresponding differential `ΔΩᵢ`. To illustrate, if all involved
58-
values are real-valued scalars, this differential can be written as:
59-
60-
ΔΩᵢ = ∂Ωᵢ_∂x₁ * Δx₁ + ∂Ωᵢ_∂x₂ * Δx₂ + ...
6157
6258
If no method matching `frule(f, xs...)` has been defined, then return `nothing`.
6359
@@ -68,12 +64,12 @@ unary input, unary output scalar function:
6864
```
6965
julia> x = rand();
7066
71-
julia> sinx, dsin = frule(sin, x);
67+
julia> sinx, sin_pushforward = frule(sin, x);
7268
7369
julia> sinx == sin(x)
7470
true
7571
76-
julia> dsin(1) == cos(x)
72+
julia> sin_pushforward(NamedTuple(), 1) == cos(x)
7773
true
7874
```
7975
@@ -82,19 +78,16 @@ unary input, binary output scalar function:
8278
```
8379
julia> x = rand();
8480
85-
julia> sincosx, (dsin, dcos) = frule(sincos, x);
81+
julia> sincosx, sincos_pushforward = frule(sincos, x);
8682
8783
julia> sincosx == sincos(x)
8884
true
8985
90-
julia> dsin(1) == cos(x)
91-
true
92-
93-
julia> dcos(1) == -sin(x)
86+
julia> sincos_pushforward(NamedTuple(), 1) == (cos(x), -sin(x))
9487
true
9588
```
9689
97-
See also: [`rrule`](@ref), [`AbstractRule`](@ref), [`@scalar_rule`](@ref)
90+
See also: [`rrule`](@ref), [`@scalar_rule`](@ref)
9891
"""
9992
frule(::Any, ::Vararg{Any}; kwargs...) = nothing
10093

@@ -104,16 +97,11 @@ frule(::Any, ::Vararg{Any}; kwargs...) = nothing
10497
Expressing `x` as the tuple `(x₁, x₂, ...)` and the output tuple of `f(x...)`
10598
as `Ω`, return the tuple:
10699
107-
(Ω, (rule_for_Δx₁::AbstractRule, rule_for_Δx₂::AbstractRule, ...))
108-
109-
where each returned propagation rule `rule_for_Δxᵢ` can be invoked as
100+
(Ω, (Ω̄₁, Ω̄₂, ...) -> (s̄elf, x̄₁, x̄₂, ...))
110101
111-
rule_for_Δxᵢ(ΔΩ₁, ΔΩ₂, ...)
112-
113-
to yield `xᵢ`'s corresponding differential `Δxᵢ`. To illustrate, if all involved
114-
values are real-valued scalars, this differential can be written as:
115-
116-
Δxᵢ = ∂Ω₁_∂xᵢ * ΔΩ₁ + ∂Ω₂_∂xᵢ * ΔΩ₂ + ...
102+
Where the second return value is the the propagation rule or pullback.
103+
It takes in differentials corresponding to the outputs (`x̄₁, x̄₂, ...`),
104+
and `s̄elf`, the internal values of the function itself (for closures)
117105
118106
If no method matching `rrule(f, xs...)` has been defined, then return `nothing`.
119107
@@ -124,12 +112,12 @@ unary input, unary output scalar function:
124112
```
125113
julia> x = rand();
126114
127-
julia> sinx, dx = rrule(sin, x);
115+
julia> sinx, sin_pullback = rrule(sin, x);
128116
129117
julia> sinx == sin(x)
130118
true
131119
132-
julia> dx(1) == cos(x)
120+
julia> sin_pullback(1) == (NO_FIELDS, cos(x))
133121
true
134122
```
135123
@@ -138,18 +126,15 @@ binary input, unary output scalar function:
138126
```
139127
julia> x, y = rand(2);
140128
141-
julia> hypotxy, (dx, dy) = rrule(hypot, x, y);
129+
julia> hypotxy, hypot_pullback = rrule(hypot, x, y);
142130
143131
julia> hypotxy == hypot(x, y)
144132
true
145133
146-
julia> dx(1) == (x / hypot(x, y))
147-
true
148-
149-
julia> dy(1) == (y / hypot(x, y))
134+
julia> hypot_pullback(1) == (NO_FIELDS, (x / hypot(x, y)), (y / hypot(x, y)))
150135
true
151136
```
152137
153-
See also: [`frule`](@ref), [`AbstractRule`](@ref), [`@scalar_rule`](@ref)
138+
See also: [`frule`](@ref), [`@scalar_rule`](@ref)
154139
"""
155140
rrule(::Any, ::Vararg{Any}; kwargs...) = nothing

0 commit comments

Comments
 (0)