Commit 0a01ae0

authored
Add a pedagogical example (#516)
1 parent 1778091 commit 0a01ae0

File tree

6 files changed

+92
-3
lines changed

docs/make.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ makedocs(;
5050
"Introduction" => "index.md",
5151
"How to use ChainRules as a rule author" => [
5252
"Introduction" => "rule_author/intro.md",
53+
"Pedagogical example" => "rule_author/example.md",
5354
"Tangent types" => "rule_author/tangents.md",
5455
#"`frule` and `rrule`" => "rule_author/rules.md", # TODO: a complete example
5556
"Writing good rules" => "rule_author/writing_good_rules.md",

docs/src/index.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ They differ in the way they break down complicated functions into simple ones, b
1010
[ChainRules](https://github.com/JuliaDiff/ChainRules.jl) is an AD-independent set of rules, and a system for defining and testing rules.
1111

1212
!!! note "What is a rule?"
13-
A rule encodes knowledge about propagating derivatives, e.g. that the derivative (with respect to `x`) of `a*x` is `a`, and the derivative of `sin(x)` is `cos(x)`, etc.
13+
A rule encodes knowledge about propagating derivatives, e.g. that the derivative with respect to `x` of `a*x` is `a`, and the derivative of `sin(x)` is `cos(x)`, etc.
1414

1515
## ChainRules ecosystem organisation
1616

docs/src/rule_author/example.md

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
# Pedagogical Example
2+
3+
This pedagogical example will show you how to write an `rrule`.
4+
See the [On writing good `rrule` / `frule` methods](@ref) section for more tips and gotchas.
5+
If you want to learn about `frule`s, you should still read and understand this example, as many concepts are shared, and then look for real-world `frule` examples in ChainRules.jl.
6+
7+
## The primal
8+
9+
We define a struct `Foo`
10+
```julia
11+
struct Foo
12+
A::Matrix
13+
c::Float64
14+
end
15+
```
16+
and a function that multiplies `Foo` with an `AbstractArray`:
17+
```julia
18+
function foo_mul(foo::Foo, b::AbstractArray)
19+
return foo.A * b
20+
end
21+
```
22+
Note that field `c` is ignored in the calculation.
23+
24+
## The `rrule`
25+
26+
The `rrule` method for our primal computation should extend the `ChainRulesCore.rrule` function.
27+
```julia
28+
function ChainRulesCore.rrule(::typeof(foo_mul), foo::Foo, b::AbstractArray)
29+
y = foo_mul(foo, b)
30+
function foo_mul_pullback(ȳ)
31+
f̄ = NoTangent()
32+
f̄oo = Tangent{Foo}(; A=ȳ * b', c=ZeroTangent())
33+
b̄ = @thunk(foo.A' * ȳ)
34+
return f̄, f̄oo, b̄
35+
end
36+
return y, foo_mul_pullback
37+
end
38+
```
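As a quick sanity check, the rule can be called by hand. The following is a minimal, self-contained sketch that repeats the definitions above and assumes ChainRulesCore is installed; the input values and the name `pb` are chosen here purely for illustration, and `unthunk` is used to force the delayed computation.

```julia
using ChainRulesCore

struct Foo
    A::Matrix
    c::Float64
end

foo_mul(foo::Foo, b::AbstractArray) = foo.A * b

function ChainRulesCore.rrule(::typeof(foo_mul), foo::Foo, b::AbstractArray)
    y = foo_mul(foo, b)
    function foo_mul_pullback(ȳ)
        f̄ = NoTangent()
        f̄oo = Tangent{Foo}(; A=ȳ * b', c=ZeroTangent())
        b̄ = @thunk(foo.A' * ȳ)
        return f̄, f̄oo, b̄
    end
    return y, foo_mul_pullback
end

# Illustrative inputs (not part of the original example)
foo = Foo([1.0 2.0; 3.0 4.0], 0.5)
b = [1.0, 1.0]

y, pb = rrule(foo_mul, foo, b)   # y == [3.0, 7.0]
f̄, f̄oo, b̄ = pb([1.0, 0.0])      # seed the pullback with a cotangent for y

f̄oo.A        # == [1.0 1.0; 0.0 0.0], i.e. ȳ * b'
f̄oo.c        # ZeroTangent()
unthunk(b̄)   # == [1.0, 2.0], i.e. foo.A' * ȳ, computed only now
```

Calling the pullback with a seed that selects the first output shows which inputs influence `y[1]`.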
39+
Now let's examine the rule in more detail:
40+
```julia
41+
function ChainRulesCore.rrule(::typeof(foo_mul), foo::Foo, b::AbstractArray)
42+
...
43+
return y, foo_mul_pullback
44+
end
45+
```
46+
The `rrule` dispatches on the `typeof` of the function we are writing the `rrule` for, as well as the types of its arguments.
47+
Read more about writing rules for constructors and callable objects [here](@ref structs).
48+
The `rrule` returns the primal result `y`, and the pullback function.
49+
It is a _very_ good idea to name your pullback functions, so that they are helpful when they appear in a stacktrace.
50+
```julia
51+
y = foo_mul(foo, b)
52+
```
53+
This line computes the primal result.
54+
It is possible to change the primal computation so that work can be shared between the primal and the pullback.
55+
See e.g. [the rule for `sort`](https://github.com/JuliaDiff/ChainRules.jl/blob/a75193768775975fac5578c89d1e5f50d7f358c2/src/rulesets/Base/sort.jl#L19-L35), where the sorting is done only once.
56+
```julia
57+
function foo_mul_pullback(ȳ)
58+
...
59+
return f̄, f̄oo, b̄
60+
end
61+
```
62+
The pullback function takes in the tangent of the primal output (`ȳ`) and returns the tangents of the primal inputs.
63+
Note that it returns a tangent for the primal function in addition to the tangents of primal arguments.
64+
65+
Finally, computing the tangents of primal inputs:
66+
```julia
67+
f̄ = NoTangent()
68+
```
69+
The function `foo_mul` has no fields (i.e. it is not a closure) and cannot be perturbed.
70+
Therefore its tangent (`f̄`) is a `NoTangent`.
71+
```julia
72+
f̄oo = Tangent{Foo}(; A=ȳ * b', c=ZeroTangent())
73+
```
74+
The struct `foo::Foo` gets a `Tangent{Foo}` structural tangent, which stores the tangents of fields of `foo`.
75+
76+
The tangent of the field `A` is `ȳ * b'`.
77+
78+
The tangent of the field `c` is `ZeroTangent()`, because `c` can be perturbed but has no effect on the primal output.
79+
```julia
80+
b̄ = @thunk(foo.A' * ȳ)
81+
```
82+
The tangent of `b` is `foo.A' * ȳ`, but we have wrapped it into a `Thunk`, a tangent type that represents delayed computation.
83+
The idea is that in case the tangent is not used anywhere, the computation never happens.
84+
Use [`InplaceableThunk`](@ref) if you are interested in [accumulating gradients inplace](@ref grad_acc).
85+
Note that in practice one would also `@thunk` the `f̄oo.A` tangent, but it was omitted in this example for clarity.
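The expressions for the tangents of `A` and `b` can be checked with a short derivation. This is a sketch that assumes real-valued arrays and uses the Frobenius inner product; the symbols `dA` and `db` (small perturbations of `foo.A` and `b`) are notation introduced here and do not appear in the code.

```math
\begin{aligned}
\langle X, Y \rangle &= \operatorname{tr}(X^\top Y), \\
y &= A b, \qquad dy = dA\, b + A\, db, \\
\langle \bar{y}, dy \rangle
  &= \langle \bar{y}, dA\, b \rangle + \langle \bar{y}, A\, db \rangle \\
  &= \langle \bar{y}\, b^\top, dA \rangle + \langle A^\top \bar{y}, db \rangle .
\end{aligned}
```

Reading off the factors paired with `dA` and `db` gives `ȳ * b'` for the tangent of `A` and `foo.A' * ȳ` for the tangent of `b`.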
86+
87+
As a final note, since `b` is an `AbstractArray`, its tangent `b̄` should be projected to the right subspace.
88+
See the [`ProjectTo` the primal subspace](@ref projectto) section for more information and an example that motivates the projection operation.
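One way the projection might be added to the rule is sketched below. This is an illustrative variant rather than the version committed above: it assumes `ChainRulesCore.ProjectTo` is available, and the name `project_b` and the input values are chosen here for the example.

```julia
using ChainRulesCore

struct Foo
    A::Matrix
    c::Float64
end

foo_mul(foo::Foo, b::AbstractArray) = foo.A * b

function ChainRulesCore.rrule(::typeof(foo_mul), foo::Foo, b::AbstractArray)
    y = foo_mul(foo, b)
    # Build the projector once, on the forward pass, from the primal input
    project_b = ProjectTo(b)
    function foo_mul_pullback(ȳ)
        f̄ = NoTangent()
        f̄oo = Tangent{Foo}(; A=ȳ * b', c=ZeroTangent())
        # Project the raw tangent onto the subspace matching `b`
        b̄ = @thunk(project_b(foo.A' * ȳ))
        return f̄, f̄oo, b̄
    end
    return y, foo_mul_pullback
end

# Illustrative inputs (not part of the original example)
_, pb = rrule(foo_mul, Foo([1.0 2.0; 3.0 4.0], 0.5), [1.0, 1.0])
_, _, b̄ = pb([1.0, 0.0])
unthunk(b̄)  # the projection runs when the thunk is forced
```

For a dense `Vector{Float64}` the projection leaves the tangent unchanged. It matters when `b` has more structure than the raw result of `foo.A' * ȳ`.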

docs/src/rule_author/intro.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,6 @@ This section also outlines some ChainRules superpowers that can be considered ad
1010
Most users can ignore these.
1111
However:
1212
- If you are writing rules with abstractly typed arguments, read about [`ProjectTo`](@ref projectto).
13-
- If you want to opt out of using the abstractly typed rule for certain argument types, read [`@opt_out`](@ref opt_out).
13+
- If you want to opt out of using the abstractly typed rule for certain argument types, read about [`@opt_out`](@ref opt_out).
1414
- If you are writing rules for higher order functions, read about [calling back into AD](@ref config).
1515
- If you want to accumulate gradients inplace to avoid extra allocations, read about [gradient accumulation](@ref grad_acc).

docs/src/rule_author/rules.md

Whitespace-only changes.

docs/src/rule_author/writing_good_rules.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ Examples being:
7171
- There is only one derivative being returned, so from the fact that the user called
7272
`frule`/`rrule` they clearly will want to use that one.
7373

74-
## Structs: constructors and functors
74+
## [Structs: constructors and functors](@id structs)
7575

7676
To define an `frule` or `rrule` for a _function_ `foo` we dispatch on the type of `foo`, which is `typeof(foo)`.
7777
For example, the `rrule` signature would be like:
