You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
[ChainRules](https://github.com/JuliaDiff/ChainRules.jl) provides a variety of common utilities that can be used by downstream [automatic differentiation (AD)](https://en.wikipedia.org/wiki/Automatic_differentiation) tools to define and execute forward-, reverse-, and mixed-mode primitives.
3
+
[Automatic differentiation (AD)](https://en.wikipedia.org/wiki/Automatic_differentiation) is a set of techniques for obtaining derivatives of arbitrary functions.
4
+
There are surprisingly many packages for doing AD in Julia.
5
+
ChainRules isn't one of these packages.
4
6
5
-
## Introduction
7
+
The AD packages essentially combine derivatives of simple functions into derivatives of more complicated functions.
8
+
They differ in the way they break down complicated functions into simple ones, but they all require a common set of derivatives of simple functions (rules).
6
9
7
-
ChainRules is all about providing a rich set of rules for differentiation.
8
-
When a person learns introductory calculus, they learn that the derivative (with respect to `x`) of `a*x` is `a`, and the derivative of `sin(x)` is `cos(x)`, etc.
9
-
And they learn how to combine simple rules, via [the chain rule](https://en.wikipedia.org/wiki/Chain_rule), to differentiate complicated functions.
10
-
ChainRules is a programmatic repository of that knowledge, with the generalizations to higher dimensions.
10
+
[ChainRules](https://github.com/JuliaDiff/ChainRules.jl) is an AD-independent set of rules, and a system for defining and testing rules.
11
11
12
-
[Autodiff (AD)](https://en.wikipedia.org/wiki/Automatic_differentiation) tools roughly work by reducing a problem down to simple parts that they know the rules for, and then combining those rules.
13
-
Knowing rules for more complicated functions speeds up the autodiff process as it doesn't have to break things down as much.
12
+
!!! note "What is a rule?"
13
+
A rule encodes knowledge about propagating derivatives, e.g. that the derivative (with respect to `x`) of `a*x` is `a`, and the derivative of `sin(x)` is `cos(x)`, etc.
14
14
15
-
**ChainRules is an AD-independent collection of rules to use in a differentiation system.**
15
+
## ChainRules ecosystem organisation
16
16
17
+
The ChainRules ecosystem comprises:
18
+
-[ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl): a system for defining rules, and a collection of tangent types.
19
+
-[ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl): a collection of rules for Julia Base and standard libraries.
20
+
-[ChainRulesTestUtils.jl](https://github.com/JuliaDiff/ChainRulesTestUtils.jl): utilities for testing rules using finite differences.
17
21
18
-
!!! note "The whole field is a mess for terminology"
19
-
It isn't just ChainRules, it is everyone.
20
-
Internally ChainRules tries to be consistent.
21
-
Help with that is always welcomed.
22
+
AD systems depend on ChainRulesCore.jl to get access to tangent types and the core rule definition functionality (`frule` and `rrule`), and on ChainRules.jl to benefit from the collection of rules for Julia Base and the standard libraries.
22
23
23
-
!!! terminology "Primal"
24
-
Often we will talk about something as _primal_.
25
-
That means it is related to the original problem, not its derivative.
26
-
For example in `y = foo(x)`, `foo` is the _primal_ function, and computing `foo(x)` is doing the _primal_ computation.
27
-
`y` is the _primal_ return, and `x` is a _primal_ argument.
28
-
`typeof(y)` and `typeof(x)` are both _primal_ types.
24
+
Packages that just want to define rules only need to depend on [ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl), which is an exceptionally light dependency.
25
+
They should also have a test-only dependency on [ChainRulesTestUtils.jl](https://github.com/JuliaDiff/ChainRulesTestUtils.jl) to test the rules using finite differences.
29
26
27
+
Note that the packages with rules do not have to depend on AD systems, and neither do the AD systems have to depend on individual packages.
30
28
31
-
## `frule` and `rrule`
29
+
## Key functionality
30
+
31
+
Consider a relationship $y = f(x)$, where $f$ is some function.
32
+
Computing $y$ from $x$ is the original problem, called the _primal_ computation, in contrast to the problem of computing derivatives.
33
+
We say that the _primal function_ $f$ takes a _primal input_ $x$ and returns the _primal output_ $y$.
34
+
35
+
ChainRules rules are concerned with propagating _tangents_ of primal inputs to _tangents_ of primal outputs (`frule`, from forwards mode AD), and propagating _cotangents_ of primal outputs to _cotangents_ of primal inputs (`rrule`, from reverse mode AD).
36
+
To be able to do that, ChainRules also defines a small number of tangent types to represent tangents and cotangents.
37
+
38
+
!!! note "Tangents and cotangents"
39
+
Strictly speaking tangents, $ẋ = \frac{dx}{da}$, are propagated in `frule`s, and cotangents, $x̄ = \frac{da}{dx}$, are propagated in `rrule`s.
40
+
However, in practice there is rarely a need to distinguish between the two: both are represented by the same tangent types.
41
+
Thus, except when the detail might clarify, we refer to both as tangents.
32
42
33
43
!!! terminology "`frule` and `rrule`"
34
44
`frule` and `rrule` are ChainRules specific terms.
35
45
Their exact functioning is fairly ChainRules specific, though other tools have similar functions.
36
46
The core notion is sometimes called _custom AD primitives_, _custom adjoints_, _custom gradients_, _custom sensitivities_.
47
+
The whole field is a mess for terminology.
48
+
49
+
50
+
### Forward-mode AD rules (`frule`s)
51
+
52
+
If we know the value of $ẋ = \frac{dx}{da}$ for some $a$ and we want to know $ẏ = \frac{dy}{da}$, the [chain rule](https://en.wikipedia.org/wiki/Chain_rule) tells us that $ẏ = \frac{dy}{dx} ẋ$.
53
+
Intuitively, we are pushing the derivative forward.
54
+
This is the basis for forward-mode AD.
55
+
56
+
!!! note "frule"
57
+
The `frule` for $f$ encodes how to propagate the tangent of the primal input ($ẋ$) to the tangent of the primal output ($ẏ$).
58
+
59
+
The `frule` signature for a function `foo(args...; kwargs...)` is
where `y = foo(args; kwargs...)` is the primal output, and `∂Y` is the result of propagating the input tangents `Δself`, `Δargs...` forwards at the point in the domain of `foo` described by `args`.
67
+
This propagation is call the pushforward.
68
+
Often we will think of the `frule` as having the primal computation `y = foo(args...; kwargs...)`, and the pushforward `∂Y = pushforward(Δself, Δargs...)`,
69
+
even though they are not present in seperate forms in the code.
70
+
71
+
For example, the `frule` for `sin(x)` is:
72
+
```julia
73
+
functionfrule((_, Δx), ::typeof(sin), x)
74
+
returnsin(x), cos(x) * Δx
75
+
end
76
+
```
37
77
38
-
The rules are encoded as `frule`s and `rrule`s, for use in forward-mode and reverse-mode differentiation respectively.
78
+
### Reverse-mode AD rules (`rrule`s)
39
79
40
-
The `rrule` for some function `foo`, which takes the positional arguments `args` and keyword arguments `kwargs`, is written:
80
+
If we know the value of $ȳ = \frac{da}{dy}$ for some $a$ and we want to know $x̄ = \frac{da}{dx}$, the [chain rule](https://en.wikipedia.org/wiki/Chain_rule) tells us that $x̄ =ȳ \frac{dy}{dx}$.
81
+
Intuitively, we are pushing the derivative backward.
82
+
This is the basis for reverse-mode AD.
41
83
84
+
!!! note "rrule"
85
+
The `rrule` for $f$ encodes how to propagate the cotangents of the primal output ($ȳ$) to the cotangent of the primal input ($x̄$).
86
+
87
+
The `rrule` signature for a function `foo(args...; kwargs...)` is
42
88
```julia
43
89
functionrrule(::typeof(foo), args...; kwargs...)
44
90
...
45
91
return y, pullback
46
92
end
47
93
```
48
-
where `y` (the primal result) must be equal to `foo(args...; kwargs...)`.
49
-
`pullback` is a function to propagate the derivative information backwards at that point.
94
+
where `y` (the primal output) must be equal to `foo(args...; kwargs...)`.
95
+
`pullback` is a function to propagate the derivative information backwards at the point in the domain of `foo ` described by `args`.
50
96
That pullback function is used like:
51
97
`∂self, ∂args... = pullback(Δy)`
52
-
53
-
54
98
Almost always the _pullback_ will be declared locally within the `rrule`, and will be a _closure_ over some of the other arguments, and potentially over the primal result too.
and `∂Y` is the result of propagating the derivative information forwards at that point.
65
-
This propagation is call the pushforward.
66
-
Often we will think of the `frule` as having the primal computation `y = foo(args...; kwargs...)`, and the pushforward `∂Y = pushforward(Δself, Δargs...)`,
67
-
even though they are not present in seperate forms in the code.
68
-
69
107
70
108
!!! note "Why `rrule` returns a pullback but `frule` doesn't return a pushforward"
71
109
While `rrule` takes only the arguments to the original function (the primal arguments) and returns a function (the pullback) that operates with the derivative information, the `frule` does it all at once.
@@ -75,89 +113,22 @@ even though they are not present in seperate forms in the code.
75
113
In contrast, in reverse mode the derivative information needed by the pullback is about the primal function's output.
76
114
Thus the reverse mode returns the pullback function which the caller (usually an AD system) keeps hold of until derivative information about the output is available.
77
115
78
-
## Videos
79
-
80
-
For people who learn better by video we have a number of videos of talks we have given about the ChainRules project.
81
-
Note however, that the videos are frozen in time reflecting the state of the packages at the time they were recorded.
82
-
This documentation is the continously updated canonical source.
83
-
However, we have tried to note below each video notes on its correctness.
84
-
85
-
The talks that follow are in reverse chronological order (i.e. most recent video is first).
86
-
87
-
### EuroAD 2021: ChainRules.jl: AD system agnostic rules for JuliaLang
> The ChainRules project is a suite of JuliaLang packages that define custom primitives (i.e. rules) for doing AD in JuliaLang.
103
-
> Importantly it is AD system agnostic.
104
-
> It has proved successful in this goal.
105
-
> At present it works with about half a dozen different JuliaLang AD systems.
106
-
> It has been a long journey, but as of August 2021, the core packages have now hit version 1.0.
107
-
>
108
-
> This talk will go through why this is useful, the particular objectives the project had, and the challenges that had to be solved.
109
-
> This talk is not intended as an educational guide for users (For that see our 2021 JuliaCon talk: > Everything you need to know about ChainRules 1.0 (https://live.juliacon.org/talk/LWVB39)).
110
-
> Rather this talk is to share the insights we have had, and likely (inadvertently) the mistakes we have made, with the wider autodiff community.
111
-
> We believe these insights can be informative and useful to efforts in other languages and ecosystems.
112
-
113
-
114
-
### JuliaCon 2021: Everything you need to know about ChainRules 1.0
If you are just wanting to watch a video to learn all about ChainRules and how to use it, watch this one.
119
-
120
-
!!! note "Slide on opting out is incorrect"
121
-
Slide 42 is incorrect (`@no_rrule sum_array(A::Diagonal)`), in the ChainRulesCore 1.0 release the following syntax is used: `@opt_out rrule(::typeof(sum_array), A::Diagonal)`. This syntax allows us to include rule config information.
> ChainRules is an automatic differentiation (AD)-independent ecosystem for forward-, reverse-, and mixed-mode primitives. It comprises ChainRules.jl, a collection of primitives for Julia Base, ChainRulesCore.jl, the utilities for defining custom primitives, and ChainRulesTestUtils.jl, the utilities to test primitives using finite differences. This talk provides brief updates on the ecosystem since last year and focuses on when and how to write and test custom primitives.
116
+
### Tangent types
131
117
118
+
The types of (co)-tangents depend on the types of the primals.
119
+
Scalar primals are represented by scalar tangents (e.g. `Float64` tangent for a `Float64` primal).
120
+
Vector, matrix, and higher rank tensor primals can be represented by vector, matrix and tensor tangents.
Additionally, for signalling semantics, we distinguish between two tangent types representing a zero tangent.
125
+
[`NoTangent`](@ref) type represent situtations in which the tangent space does not exist, e.g. an index into an array can not be perturbed.
126
+
[`ZeroTangent`](@ref) is used for cases where the tangent happens to be zero, e.g. because the primal argument is not used in the computation.
155
127
156
-
Abstract:
157
-
> The ChainRules project allows package authors to write rules for custom sensitivities (sometimes called custom adjoints) in a way that is not dependent on any particular autodiff (AD) package.
158
-
> It allows authors of AD packages to access a wealth of prewritten custom sensitivities, saving them the effort of writing them all out themselves.
159
-
> ChainRules is the successor to DiffRules.jl and is the native rule system currently used by ForwardDiff2, Zygote and soon ReverseDiff
128
+
We also define [`Thunk`](@ref)s to allow certain optimisation.
129
+
`Thunk`s are a wrapper over a computation that can potentially be avoided, depending on the downstream use.
160
130
131
+
See the section on [tangent types](@ref tangents) for more details.
0 commit comments