# Design Notes: Why can you change the primal computation?
These design notes are to help you understand ChainRules.jl's [`rrule`](@ref) function.
It explains why we have a `rrule` function that returns both the primal result (i.e. the output of the forward pass) and the pullback as a closure.
It might surprise some AD authors, who would expect just a function that performs the pullback, that `rrule` computes the primal result as well as the pullback.
In particular, `rrule` allows you to _change_ how the primal result is computed.
We will illustrate in this document why being able to change the computation of the primal is crucial for efficient AD.
!!! note "What about `frule`?"
    Discussion here is focused on reverse mode and `rrule`.
    Similar concerns do apply to forward mode and `frule`.
    In forward mode these concerns lead to the fusing of the `pushforward` into `frule`.
    All the examples given here also apply in forward mode.
    In fact, in forward mode there are even more opportunities to take advantage of sharing work between the primal and derivative computations.
    A particularly notable example is efficiently calculating the pushforward of solving a differential equation, by expanding the system of equations to also include the derivatives before solving it.
## The Journey to `rrule`
Let's imagine a different system for rules, one that doesn't let you define the computation of the primal.
This kind of system is what a lot of AD frameworks have.
It is what [Nabla.jl](https://github.com/invenia/Nabla.jl/) had originally.[^1]
We will have a primal (i.e. forward) pass that directly executes the primal function and records the primal _function_, its _inputs_, and its _output_ onto the tape.[^2]
Then during the gradient (i.e. reverse) pass it has a function which receives those records from the tape along with the sensitivity of the output, and gives back the sensitivity of the input.
We will call this function `pullback_at`, as it pulls back the sensitivity at a given primal point.
To make this concrete:
```julia
y = f(x)  # primal program
x̄ = pullback_at(f, x, y, ȳ)
```
Let's illustrate this with examples for `sin` and for the [logistic sigmoid](https://en.wikipedia.org/wiki/Logistic_function#Derivative).
```julia
σ(x) = 1 / (1 + exp(-x))  # logistic sigmoid

pullback_at(::typeof(sin), x, y, ȳ) = ȳ * cos(x)
pullback_at(::typeof(σ), x, y, ȳ) = ȳ * y * σ(-x)  # i.e. ȳ * σ(x) * σ(-x)
```
Notice that in `pullback_at` we are not only using the input `x` but also the primal output `y`.
This is a nice bit of symmetry that shows up around `exp`.
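To make the flow concrete, here is a small usage sketch (the values are hypothetical, and in a real system the recording onto and reading from the tape would be done by the AD framework, not by hand):

```julia
x, ȳ = 1.2, 1.0  # a primal input, and an output sensitivity to pull back

y = σ(x)                      # primal pass: σ, x and y get recorded onto the tape
x̄ = pullback_at(σ, x, y, ȳ)  # gradient pass: replays the record to get x̄
```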
Now let's consider why we implement `rrule`s in the first place.
One key reason is to insert domain knowledge so as to compute the derivative more efficiently than AD would just by breaking everything down into `+`, `*`, etc.[^3]
So we are talking about a 30-40% speed-up from these kinds of optimizations.[^4]
It is faster to compute `sin` and `cos` at the same time via `sincos` than it is to compute them one after the other.
And it is faster to reuse the `exp(x)` in computing `σ(x)` and `σ(-x)`.
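To see the shared work concretely, here is a quick sketch (`σ_pair` is a name made up for this example):

```julia
s, c = sincos(1.2)  # one combined call is cheaper than separate sin and cos calls

# σ(x) = ex / (1 + ex) and σ(-x) = 1 / (1 + ex), so a single exp serves both:
σ_pair(x) = (ex = exp(x); (ex / (1 + ex), 1 / (1 + ex)))
```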
How can we incorporate this insight into our system?
We know we can compute both of these in the primal, because they only depend on `x` and not on `ȳ`; but there is nowhere to put them that is accessible to both the primal pass and the gradient pass code.
What if we introduced some variable called `intermediates` that is also recorded onto the tape during the primal pass?
```julia
function augmented_primal(::typeof(σ), x)
    ex = exp(x)
    y = ex / (1 + ex)
    return y, (; ex=ex)  # use a NamedTuple for the intermediates
end

pullback_at(::typeof(σ), x, y, ȳ, intermediates) = ȳ * y / (1 + intermediates.ex)
```
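To spell out how this is used, here is a small sketch (the tape bookkeeping is again left implicit):

```julia
x, ȳ = 1.2, 1.0

# primal pass: compute y and record the intermediates (alongside σ, x and y) onto the tape
y, intermediates = augmented_primal(σ, x)

# gradient pass: the recorded intermediates are handed back to the pullback
x̄ = pullback_at(σ, x, y, ȳ, intermediates)  # == ȳ * σ(x) * σ(-x)
```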
Cool!
That lets us do what we wanted.
We have made a net decrease in the time it takes to run the primal and gradient passes.
We have now answered the title question of why we want to be able to modify the primal pass.
We will go into that more later and have some more usage examples, but first let's continue to see how we go from `augmented_primal` and `pullback_at` to [`rrule`](@ref).
One thing we notice when looking at `pullback_at` is that it really is starting to have a lot of arguments.
It had a fair few already, and now we are adding `intermediates` as well, making it even more unwieldy.
Not to mention these are fairly simple examples: the `sin` and `σ` functions have one input and no keyword arguments.
Furthermore, we often don't even use all of the arguments to `pullback_at`.
The new code for pulling back `sin`, which uses `sincos` and `intermediates`, no longer needs `x`, and it never needed `y` (though sigmoid `σ` does).
And storing all these things on the tape (inputs, outputs, sensitivities, intermediates) uses up extra memory.
What if we generalized the idea of the `intermediate` named tuple, and had a struct that just held anything we might want to put on the tape?
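Here is a sketch of what such a `PullbackMemory` could look like for the sigmoid example (the keyword constructor and the field-forwarding details are illustrative, not exact definitions):

```julia
struct PullbackMemory{P,S}
    primal_function::P
    state::S  # a NamedTuple of everything the pullback will need
end
PullbackMemory(primal_function; state...) = PullbackMemory(primal_function, (; state...))

# let `pb.ex` etc. read from the stored state:
function Base.getproperty(pb::PullbackMemory, name::Symbol)
    name in (:primal_function, :state) && return getfield(pb, name)
    return getproperty(getfield(pb, :state), name)
end

# make the memory callable: calling it applies the pullback
(pb::PullbackMemory{typeof(σ)})(ȳ) = ȳ * pb.y / (1 + pb.ex)

function augmented_primal(::typeof(σ), x)
    ex = exp(x)
    y = ex / (1 + ex)
    return y, PullbackMemory(σ; y=y, ex=ex)
end
```

With that, the primal pass records a single object, and the gradient pass just calls it:

```julia
y, pb = augmented_primal(σ, 1.2)
x̄ = pb(1.0)
```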
Let's recap what we have done here.
We now have an object `pb` that acts on the cotangent of the output of the primal, `ȳ`, to give us the cotangent of the input of the primal function, `x̄`.
_`pb` is not just the **memory** of state required for the `pullback`, it **is** the pullback._
We have one final thing to do, which is to think about how we make the code easy to modify.
Let's go back and think about the changes we would have to make to go from our original way of writing, which only used the inputs/outputs, to one that used the intermediate state.
(NB: there is actually a further optimization that can be made to the logistic sigmoid, to avoid remembering two things and just remember one.
As an exercise to the reader, consider how the code would need to be changed and where.)
We need to make a series of changes:

* update what work is done in the primal, to compute the intermediate values.
* update what is stored in the `PullbackMemory`.
* update the function that applies the pullback so it uses the new thing that was stored.
It's important that these parts all stay in sync.
It's not too bad for this simple example with just one or two things to remember.
For more complicated multi-argument functions, like those we will show below, you often end up needing to remember half a dozen things, such as sizes and indices relating to each input/output.
So it gets a little more fiddly to make sure you remember everything you need, and give it the same name in both places.
_Is there a way we can automatically just have all the things we use remembered for us?_
Surprisingly for such a specific request, there actually is: a closure.
A closure in Julia is a callable structure that automatically contains a field for every object from its parent scope that is used in its body.
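For example (a tiny illustration; reading a closure's captured variable as a field is shown only to make the point, it is not something rule code needs to do):

```julia
make_adder(a) = x -> x + a  # the returned closure captures `a`

add2 = make_adder(2)
add2(10)  # == 12
add2.a    # == 2; the captured variable is stored as a field of the closure
```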
There are [incredible ways to abuse this](https://invenia.github.io/blog/2019/10/30/julialang-features-part-1#closures-give-us-classic-object-oriented-programming); but here we can use closures exactly as they are intended.
Replacing `PullbackMemory` with a closure that works the same way lets us avoid having to manually control what is remembered _and_ lets us avoid separately writing the call overload.
```julia
function augmented_primal(::typeof(σ), x)
    ex = exp(x)
    y = ex / (1 + ex)
    pb = ȳ -> ȳ * y / (1 + ex)  # pullback closure; closes over `y` and `ex`
    return y, pb
end
```
All that is left is a rename and some extra conventions around multiple outputs.
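For instance, the sigmoid example written against that interface might look roughly like this (a sketch following ChainRulesCore conventions, not an exact rule shipped by ChainRules; `NoTangent()` is the conventional tangent returned for the function object itself):

```julia
using ChainRulesCore

function ChainRulesCore.rrule(::typeof(σ), x)
    ex = exp(x)
    y = ex / (1 + ex)
    σ_pullback(ȳ) = (NoTangent(), ȳ * y / (1 + ex))
    return y, σ_pullback
end
```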
This has been a journey into how we get to [`rrule`](@ref) as it is defined in `ChainRulesCore`.
We started with an unaugmented primal function and a `pullback_at` function that only saw the inputs and outputs of the primal.
We realized a key limitation of this was that we couldn't share computational work between the primal and gradient passes.
To solve this we introduced the notion of some `intermediate` state that is shared from the primal to the pullback.
We successively improved that idea, first by making it a type that held everything that is needed for the pullback: the `PullbackMemory`, which we then made callable, so it was itself the pullback.
Finally, we replaced that separate callable structure with a closure, which kept everything in one place and made it more convenient.