Skip to content

Commit fb5b9c8

Browse files
authored
Rewrite the introduction (#510)
1 parent 99d56b1 commit fb5b9c8

File tree

4 files changed

+171
-117
lines changed

4 files changed

+171
-117
lines changed

docs/make.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ makedocs(;
4848
"Introduction" => "index.md",
4949
"How to use ChainRules as a rule author" => [
5050
"Introduction" => "rule_author/intro.md",
51-
"Differentials" => "rule_author/differentials.md",
51+
"Tangent types" => "rule_author/differentials.md",
5252
#"`frule` and `rrule`" => "rule_author/rules.md", # TODO: a complete example
5353
"Writing good rules" => "rule_author/writing_good_rules.md",
5454
"Testing your rules" => "rule_author/testing.md",
@@ -74,8 +74,9 @@ makedocs(;
7474
],
7575
"Design" => [
7676
"Changing the Primal" => "design/changing_the_primal.md",
77-
"Many Differential Types" => "design/many_differentials.md",
77+
"Many Tangent Types" => "design/many_differentials.md",
7878
],
79+
"Videos" => "videos.md",
7980
"FAQ" => "FAQ.md",
8081
"API" => "api.md",
8182
],

docs/src/index.md

Lines changed: 85 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -1,71 +1,109 @@
11
# ChainRules
22

3-
[ChainRules](https://github.com/JuliaDiff/ChainRules.jl) provides a variety of common utilities that can be used by downstream [automatic differentiation (AD)](https://en.wikipedia.org/wiki/Automatic_differentiation) tools to define and execute forward-, reverse-, and mixed-mode primitives.
3+
[Automatic differentiation (AD)](https://en.wikipedia.org/wiki/Automatic_differentiation) is a set of techniques for obtaining derivatives of arbitrary functions.
4+
There are surprisingly many packages for doing AD in Julia.
5+
ChainRules isn't one of these packages.
46

5-
## Introduction
7+
The AD packages essentially combine derivatives of simple functions into derivatives of more complicated functions.
8+
They differ in the way they break down complicated functions into simple ones, but they all require a common set of derivatives of simple functions (rules).
69

7-
ChainRules is all about providing a rich set of rules for differentiation.
8-
When a person learns introductory calculus, they learn that the derivative (with respect to `x`) of `a*x` is `a`, and the derivative of `sin(x)` is `cos(x)`, etc.
9-
And they learn how to combine simple rules, via [the chain rule](https://en.wikipedia.org/wiki/Chain_rule), to differentiate complicated functions.
10-
ChainRules is a programmatic repository of that knowledge, with the generalizations to higher dimensions.
10+
[ChainRules](https://github.com/JuliaDiff/ChainRules.jl) is an AD-independent set of rules, and a system for defining and testing rules.
1111

12-
[Autodiff (AD)](https://en.wikipedia.org/wiki/Automatic_differentiation) tools roughly work by reducing a problem down to simple parts that they know the rules for, and then combining those rules.
13-
Knowing rules for more complicated functions speeds up the autodiff process as it doesn't have to break things down as much.
12+
!!! note "What is a rule?"
13+
A rule encodes knowledge about propagating derivatives, e.g. that the derivative (with respect to `x`) of `a*x` is `a`, and the derivative of `sin(x)` is `cos(x)`, etc.
1414

15-
**ChainRules is an AD-independent collection of rules to use in a differentiation system.**
15+
## ChainRules ecosystem organisation
1616

17+
The ChainRules ecosystem comprises:
18+
- [ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl): a system for defining rules, and a collection of tangent types.
19+
- [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl): a collection of rules for Julia Base and standard libraries.
20+
- [ChainRulesTestUtils.jl](https://github.com/JuliaDiff/ChainRulesTestUtils.jl): utilities for testing rules using finite differences.
1721

18-
!!! note "The whole field is a mess for terminology"
19-
It isn't just ChainRules, it is everyone.
20-
Internally ChainRules tries to be consistent.
21-
Help with that is always welcomed.
22+
AD systems depend on ChainRulesCore.jl to get access to tangent types and the core rule definition functionality (`frule` and `rrule`), and on ChainRules.jl to benefit from the collection of rules for Julia Base and the standard libraries.
2223

23-
!!! terminology "Primal"
24-
Often we will talk about something as _primal_.
25-
That means it is related to the original problem, not its derivative.
26-
For example in `y = foo(x)`, `foo` is the _primal_ function, and computing `foo(x)` is doing the _primal_ computation.
27-
`y` is the _primal_ return, and `x` is a _primal_ argument.
28-
`typeof(y)` and `typeof(x)` are both _primal_ types.
24+
Packages that just want to define rules only need to depend on [ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl), which is an exceptionally light dependency.
25+
They should also have a test-only dependency on [ChainRulesTestUtils.jl](https://github.com/JuliaDiff/ChainRulesTestUtils.jl) to test the rules using finite differences.
2926

27+
Note that the packages with rules do not have to depend on AD systems, and neither do the AD systems have to depend on individual packages.
3028

31-
## `frule` and `rrule`
29+
## Key functionality
30+
31+
Consider a relationship $y = f(x)$, where $f$ is some function.
32+
Computing $y$ from $x$ is the original problem, called the _primal_ computation, in contrast to the problem of computing derivatives.
33+
We say that the _primal function_ $f$ takes a _primal input_ $x$ and returns the _primal output_ $y$.
34+
35+
ChainRules rules are concerned with propagating _tangents_ of primal inputs to _tangents_ of primal outputs (`frule`, from forwards mode AD), and propagating _cotangents_ of primal outputs to _cotangents_ of primal inputs (`rrule`, from reverse mode AD).
36+
To be able to do that, ChainRules also defines a small number of tangent types to represent tangents and cotangents.
37+
38+
!!! note "Tangents and cotangents"
39+
Strictly speaking tangents, $ẋ = \frac{dx}{da}$, are propagated in `frule`s, and cotangents, $x̄ = \frac{da}{dx}$, are propagated in `rrule`s.
40+
However, in practice there is rarely a need to distinguish between the two: both are represented by the same tangent types.
41+
Thus, except when the detail might clarify, we refer to both as tangents.
3242

3343
!!! terminology "`frule` and `rrule`"
3444
`frule` and `rrule` are ChainRules specific terms.
3545
Their exact functioning is fairly ChainRules specific, though other tools have similar functions.
3646
The core notion is sometimes called _custom AD primitives_, _custom adjoints_, _custom gradients_, _custom sensitivities_.
47+
The whole field is a mess for terminology.
48+
49+
50+
### Forward-mode AD rules (`frule`s)
51+
52+
If we know the value of $ẋ = \frac{dx}{da}$ for some $a$ and we want to know $ẏ = \frac{dy}{da}$, the [chain rule](https://en.wikipedia.org/wiki/Chain_rule) tells us that $ẏ = \frac{dy}{dx} ẋ$.
53+
Intuitively, we are pushing the derivative forward.
54+
This is the basis for forward-mode AD.
55+
56+
!!! note "frule"
57+
The `frule` for $f$ encodes how to propagate the tangent of the primal input ($ẋ$) to the tangent of the primal output ($ẏ$).
58+
59+
The `frule` signature for a function `foo(args...; kwargs...)` is
60+
```julia
61+
function frule((Δself, Δargs...), ::typeof(foo), args...; kwargs...)
62+
...
63+
return y, ∂Y
64+
end
65+
```
66+
where `y = foo(args; kwargs...)` is the primal output, and `∂Y` is the result of propagating the input tangents `Δself`, `Δargs...` forwards at the point in the domain of `foo` described by `args`.
67+
This propagation is call the pushforward.
68+
Often we will think of the `frule` as having the primal computation `y = foo(args...; kwargs...)`, and the pushforward `∂Y = pushforward(Δself, Δargs...)`,
69+
even though they are not present in seperate forms in the code.
70+
71+
For example, the `frule` for `sin(x)` is:
72+
```julia
73+
function frule((_, Δx), ::typeof(sin), x)
74+
return sin(x), cos(x) * Δx
75+
end
76+
```
3777

38-
The rules are encoded as `frule`s and `rrule`s, for use in forward-mode and reverse-mode differentiation respectively.
78+
### Reverse-mode AD rules (`rrule`s)
3979

40-
The `rrule` for some function `foo`, which takes the positional arguments `args` and keyword arguments `kwargs`, is written:
80+
If we know the value of $ȳ = \frac{da}{dy}$ for some $a$ and we want to know $x̄ = \frac{da}{dx}$, the [chain rule](https://en.wikipedia.org/wiki/Chain_rule) tells us that $x̄ =ȳ \frac{dy}{dx}$.
81+
Intuitively, we are pushing the derivative backward.
82+
This is the basis for reverse-mode AD.
4183

84+
!!! note "rrule"
85+
The `rrule` for $f$ encodes how to propagate the cotangents of the primal output ($ȳ$) to the cotangent of the primal input ($x̄$).
86+
87+
The `rrule` signature for a function `foo(args...; kwargs...)` is
4288
```julia
4389
function rrule(::typeof(foo), args...; kwargs...)
4490
...
4591
return y, pullback
4692
end
4793
```
48-
where `y` (the primal result) must be equal to `foo(args...; kwargs...)`.
49-
`pullback` is a function to propagate the derivative information backwards at that point.
94+
where `y` (the primal output) must be equal to `foo(args...; kwargs...)`.
95+
`pullback` is a function to propagate the derivative information backwards at the point in the domain of `foo ` described by `args`.
5096
That pullback function is used like:
5197
`∂self, ∂args... = pullback(Δy)`
52-
53-
5498
Almost always the _pullback_ will be declared locally within the `rrule`, and will be a _closure_ over some of the other arguments, and potentially over the primal result too.
5599

56-
The `frule` is written:
100+
For example, the `rrule` for `sin(x)` is:
57101
```julia
58-
function frule((Δself, Δargs...), ::typeof(foo), args...; kwargs...)
59-
...
60-
return y, ∂Y
102+
function rrule(::typeof(sin), x)
103+
sin_pullback(Δy) = (NoTangent(), cos(x)' * Δy)
104+
return sin(x), sin_pullback
61105
end
62106
```
63-
where again `y = foo(args; kwargs...)`,
64-
and `∂Y` is the result of propagating the derivative information forwards at that point.
65-
This propagation is call the pushforward.
66-
Often we will think of the `frule` as having the primal computation `y = foo(args...; kwargs...)`, and the pushforward `∂Y = pushforward(Δself, Δargs...)`,
67-
even though they are not present in seperate forms in the code.
68-
69107

70108
!!! note "Why `rrule` returns a pullback but `frule` doesn't return a pushforward"
71109
While `rrule` takes only the arguments to the original function (the primal arguments) and returns a function (the pullback) that operates with the derivative information, the `frule` does it all at once.
@@ -75,89 +113,22 @@ even though they are not present in seperate forms in the code.
75113
In contrast, in reverse mode the derivative information needed by the pullback is about the primal function's output.
76114
Thus the reverse mode returns the pullback function which the caller (usually an AD system) keeps hold of until derivative information about the output is available.
77115

78-
## Videos
79-
80-
For people who learn better by video we have a number of videos of talks we have given about the ChainRules project.
81-
Note however, that the videos are frozen in time reflecting the state of the packages at the time they were recorded.
82-
This documentation is the continously updated canonical source.
83-
However, we have tried to note below each video notes on its correctness.
84-
85-
The talks that follow are in reverse chronological order (i.e. most recent video is first).
86-
87-
### EuroAD 2021: ChainRules.jl: AD system agnostic rules for JuliaLang
88-
Presented by Lyndon White.
89-
[Slides](https://www.slideshare.net/LyndonWhite2/euroad-2021-chainrulesjl)
90-
91-
This is the talk to watch if you want to understand why the ChainRules project exists, what its challenges are, and how those have been overcome.
92-
It is intended less for users of the package, and more for people working in the field of AD more generally.
93-
It does also serve as a nice motivation for those first coming across the package as well though.
94-
95-
```@raw html
96-
<div class="video-container">
97-
<iframe src="https://www.youtube.com/embed/B3bC49OmTdk" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe>
98-
</div>
99-
```
100-
101-
Abstract:
102-
> The ChainRules project is a suite of JuliaLang packages that define custom primitives (i.e. rules) for doing AD in JuliaLang.
103-
> Importantly it is AD system agnostic.
104-
> It has proved successful in this goal.
105-
> At present it works with about half a dozen different JuliaLang AD systems.
106-
> It has been a long journey, but as of August 2021, the core packages have now hit version 1.0.
107-
>
108-
> This talk will go through why this is useful, the particular objectives the project had, and the challenges that had to be solved.
109-
> This talk is not intended as an educational guide for users (For that see our 2021 JuliaCon talk: > Everything you need to know about ChainRules 1.0 (https://live.juliacon.org/talk/LWVB39)).
110-
> Rather this talk is to share the insights we have had, and likely (inadvertently) the mistakes we have made, with the wider autodiff community.
111-
> We believe these insights can be informative and useful to efforts in other languages and ecosystems.
112-
113-
114-
### JuliaCon 2021: Everything you need to know about ChainRules 1.0
115-
Presented by Miha Zgubič.
116-
[Slides](https://github.com/mzgubic/ChainRulesTalk/blob/master/ChainRules.pdf)
117-
118-
If you are just wanting to watch a video to learn all about ChainRules and how to use it, watch this one.
119-
120-
!!! note "Slide on opting out is incorrect"
121-
Slide 42 is incorrect (`@no_rrule sum_array(A::Diagonal)`), in the ChainRulesCore 1.0 release the following syntax is used: `@opt_out rrule(::typeof(sum_array), A::Diagonal)`. This syntax allows us to include rule config information.
122-
123-
```@raw html
124-
<div class="video-container">
125-
<iframe src="https://www.youtube.com/embed/a8ol-1l84gc" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe>
126-
</div>
127-
```
128-
129-
Abstract:
130-
> ChainRules is an automatic differentiation (AD)-independent ecosystem for forward-, reverse-, and mixed-mode primitives. It comprises ChainRules.jl, a collection of primitives for Julia Base, ChainRulesCore.jl, the utilities for defining custom primitives, and ChainRulesTestUtils.jl, the utilities to test primitives using finite differences. This talk provides brief updates on the ecosystem since last year and focuses on when and how to write and test custom primitives.
116+
### Tangent types
131117

118+
The types of (co)-tangents depend on the types of the primals.
119+
Scalar primals are represented by scalar tangents (e.g. `Float64` tangent for a `Float64` primal).
120+
Vector, matrix, and higher rank tensor primals can be represented by vector, matrix and tensor tangents.
132121

133-
### JuliaCon 2020: ChainRules.jl
134-
Presented by Lyndon White.
135-
[Slides](https://raw.githack.com/oxinabox/ChainRulesJuliaCon2020/main/out/build/index.html)
122+
ChainRules defines a [`Tangent`](@ref) tangent type to represent tangents of `struct`s, `Tuple`s, `NamedTuple`s, and `Dict`s.
136123

137-
This talk is primarily of historical interest.
138-
This was the first public presentation of ChainRules.
139-
Though the project was a few years old by this stage.
140-
A lot of things are still the same; conceptually, but a lot has changed.
141-
Most people shouldn't watch this talk now.
142-
143-
!!! warning "Outdated Terminology"
144-
A lot of terminology has changed since this presentation.
145-
- `DoesNotExist``NoTangent`
146-
- `Zero``ZeroTangent`
147-
- `Composite{P}``Tangent{T}`
148-
The talk also says differential in a lot of places where we now would say tangent.
149-
150-
```@raw html
151-
<div class="video-container">
152-
<iframe src="https://www.youtube.com/embed/B4NfkkkJ7rs" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe>
153-
</div>
154-
```
124+
Additionally, for signalling semantics, we distinguish between two tangent types representing a zero tangent.
125+
[`NoTangent`](@ref) type represent situtations in which the tangent space does not exist, e.g. an index into an array can not be perturbed.
126+
[`ZeroTangent`](@ref) is used for cases where the tangent happens to be zero, e.g. because the primal argument is not used in the computation.
155127

156-
Abstract:
157-
> The ChainRules project allows package authors to write rules for custom sensitivities (sometimes called custom adjoints) in a way that is not dependent on any particular autodiff (AD) package.
158-
> It allows authors of AD packages to access a wealth of prewritten custom sensitivities, saving them the effort of writing them all out themselves.
159-
> ChainRules is the successor to DiffRules.jl and is the native rule system currently used by ForwardDiff2, Zygote and soon ReverseDiff
128+
We also define [`Thunk`](@ref)s to allow certain optimisation.
129+
`Thunk`s are a wrapper over a computation that can potentially be avoided, depending on the downstream use.
160130

131+
See the section on [tangent types](@ref tangents) for more details.
161132

162133
## Example of using ChainRules directly
163134

docs/src/rule_author/differentials.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Differentials
1+
# [Tangent types](@id tangents)
22

33
The values that come back from pullbacks or pushforwards are not always the same type as the input/outputs of the primal function.
44
They are differentials, which correspond roughly to something able to represent the difference between two values of the primal types.

0 commit comments

Comments
 (0)