Commit cd7516c

oxinabox committed
Apply suggestions from code review

Co-authored-by: Nick Robinson <npr251@gmail.com>
Co-authored-by: Seth Axen <seth.axen@gmail.com>
1 parent 0efe5e5 commit cd7516c

File tree: 1 file changed, +51 −58 lines

docs/src/design/changing_the_primal.md

@@ -1,38 +1,37 @@
 # Design Notes: Why can you change the primal computation?
 
-These design notes are to help you understand why ChainRules why [`rrule`](@ref).
-It explains why we have a function, that includes the computation of the primal result, and that returns the pullback as a closure.
-This structure is perhaps surprising to some AD authors, who might expect just a function that performs the pullback.
-It is particularly notable that you are able to change the primal computation when defining the `rrule`.
-We will illustrate in this document why this is a crucial capacity.
+These design notes are to help you understand ChainRules.jl's [`rrule`](@ref) function.
+It explains why we have a `rrule` function that returns both the primal result (i.e. the output of the forward pass) and the pullback as a closure.
+It might surprise some AD authors, who would expect just a function that performs the pullback, that `rrule` computes the primal result as well.
+In particular, `rrule` allows you to _change_ how the primal result is computed.
+We will illustrate in this document why being able to change the computation of the primal is crucial for efficient AD.
 
 
-!!! note what about `frule`?
+!!! note "What about `frule`?"
     Discussion here is focused on reverse mode and `rrule`.
-    Similar concerns to apply to to forward mode and `frule`.
+    Similar concerns do apply to forward mode and `frule`.
     In forward mode these concerns lead to the fusing of the `pushforward` into `frule`.
-    All the examples given here also apply in forwards mode.
-    In fact in forwards mode there are even more opportunities to take advantage of sharing work between the primal and derivative computations.
+    All the examples given here also apply in forward mode.
+    In fact in forward mode there are even more opportunities to take advantage of sharing work between the primal and derivative computations.
     A particularly notable example is in efficiently calculating the pushforward of solving a differential equation via expanding the system of equations to also include the derivatives before solving it.
 
 
 
 
-
 ## The Journey to `rrule`
 
-Let's imagine a different system for rules, one that doesn't let you do this.
+Let's imagine a different system for rules, one that doesn't let you define the computation of the primal.
 This system is what a lot of AD systems have.
-It is what [Nabla.jl](https://github.com/invenia/Nabla.jl/)[^1] had originally.
-We will have a primal (i.e. forward) pass that directly executes the primal function and just records its _inputs_ and its _output_ (as well as the _primal function_ itself) onto the tape.[^2].
+It is what [Nabla.jl](https://github.com/invenia/Nabla.jl/) had originally.[^1]
+We will have a primal (i.e. forward) pass that directly executes the primal function and just records the primal _function_, its _inputs_ and its _output_ onto the tape.[^2]
 Then during the gradient (i.e. reverse) pass it has a function which receives those records from the tape along with the sensitivity of the output, and gives back the sensitivity of the input.
 We will call this function `pullback_at`, as it pulls back the sensitivity at a given primal point.
 To make this concrete:
 ```julia
 y = f(x) # primal program
 x̄ = pullback_at(f, x, y, ȳ)
 ```
-To illustrate this we will use throughout this document examples for `sin` and for the [logistic sigmoid](https://en.wikipedia.org/wiki/Logistic_function#Derivative).
+Let's illustrate this with examples for `sin` and for the [logistic sigmoid](https://en.wikipedia.org/wiki/Logistic_function#Derivative).
 
 ```@raw html
 <details><summary>Example for `sin`</summary>
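To make the `pullback_at` scheme concrete before the collapsed examples, here is a minimal runnable sketch for `sin`, using the rule visible in the next hunk's context line; the input `x = 1.2` and seed `ȳ = 1.0` are our own illustrative choices.

```julia
# Minimal sketch of the tape-style API described above.
pullback_at(::typeof(sin), x, y, ȳ) = ȳ * cos(x)

x = 1.2
y = sin(x)                      # primal (forward) pass
ȳ = 1.0                         # seed sensitivity for the output
x̄ = pullback_at(sin, x, y, ȳ)   # gradient (reverse) pass
x̄ ≈ cos(1.2)                    # true, since d/dx sin(x) = cos(x)
```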
@@ -51,18 +50,18 @@ pullback_at(::typeof(sin), x, y, ȳ) = ȳ * cos(x)
 ```
 
 ```julia
-σ(x) = 1/(1 + exp(-x)) # = exp(x)/(1+exp(x))
+σ(x) = 1/(1 + exp(-x)) # = exp(x) / (1 + exp(x))
 y = σ(x)
-pullback_at(::typeof(σ), x, y, ȳ) = ȳ * y * σ(-x) # i.e. ȳ * σ(x) * σ(-x)
+pullback_at(::typeof(σ), x, y, ȳ) = ȳ * y * σ(-x) # = ȳ * σ(x) * σ(-x)
 ```
-Notice that here we are in the `pullback_at` not only using `x` but also `y` the primal output.
+Notice that in `pullback_at` we are not only using the input `x` but also the primal output `y`.
 This is a nice bit of symmetry that shows up around `exp`.
 ```@raw html
 </details>
 ```
 
 Now let's consider why we implement `rrule`s in the first place.
-One key reason [^3] is to allow us to insert our domain knowledge to do better than the AD would do just by breaking everything down into `+`, `*`, etc.
+One key reason is to insert domain knowledge so as to compute the derivative more efficiently than AD would just by breaking everything down into `+`, `*`, etc.[^3]
 What insights do we have about `sin` and `cos`?
 What about using `sincos`?
 ```@raw html
@@ -130,12 +129,12 @@ julia> 5.367 + 1.255 + 1.256
 </details>
 ```
 
-So we are talking about a 30-40%[^4] speed-up from these optimizations.
-It's faster to compute `sin` and `cos` at the same time via `sincos` than it is to compute them one after the other.
+So we are talking about a 30-40% speed-up from these optimizations.[^4]
+
+It is faster to compute `sin` and `cos` at the same time via `sincos` than it is to compute them one after the other.
 And it is faster to reuse the `exp(x)` in computing `σ(x)` and `σ(-x)`.
 How can we incorporate this insight into our system?
-We know we can compute both of these in the primal, because they only depend on `x` --- we don't need to know `ȳ`.
-But there is nowhere to put it that is accessible both to the primal pass and the gradient pass code.
+We know we can compute both of these in the primal — because they only depend on `x` and not on `ȳ` — but there is nowhere to put them that is accessible both to the primal pass and the gradient pass code.
 
 
 What if we introduced some variable called `intermediates` that is also recorded onto the tape during the primal pass?
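For `sin`, the idea looks roughly like the following; the `pullback_at` method is the one shown in the next hunk's context line, and the `augmented_primal` body is a sketch of the collapsed lines, assuming it uses `sincos` as discussed above.

```julia
# Sketch: the primal pass now also returns intermediates to be taped.
function augmented_primal(::typeof(sin), x)
    y, cx = sincos(x)    # computes sin(x) and cos(x) together, sharing work
    return y, (; cx=cx)  # NamedTuple of intermediates for the tape
end

pullback_at(::typeof(sin), x, y, ȳ, intermediates) = ȳ * intermediates.cx
```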
@@ -170,11 +169,11 @@ pullback_at(::typeof(sin), x, y, ȳ, intermediates) = ȳ * intermediates.cx
 ```julia
 function augmented_primal(::typeof(σ), x)
     ex = exp(x)
-    y = ex/(1 + ex)
+    y = ex / (1 + ex)
     return y, (; ex=ex) # use a NamedTuple for the intermediates
 end
 
-pullback_at(::typeof(σ), x, y, ȳ, intermediates) = ȳ * y /(1 + intermediates.ex)
+pullback_at(::typeof(σ), x, y, ȳ, intermediates) = ȳ * y / (1 + intermediates.ex)
 ```
 ```@raw html
 </details>
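Putting the two passes together, usage would look something like the following sketch (the variable values are our own):

```julia
x, ȳ = 0.7, 1.0
y, intermediates = augmented_primal(σ, x)    # primal pass; tapes `ex = exp(x)`
x̄ = pullback_at(σ, x, y, ȳ, intermediates)   # reuses `ex`; no second exp call
x̄ ≈ σ(x) * σ(-x)                             # true: the sigmoid derivative
```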
@@ -183,26 +182,23 @@ pullback_at(::typeof(σ), x, y, ȳ, intermediates) = ȳ * y / (1 + intermediates.ex)
 Cool!
 That lets us do what we wanted.
 We have a net decrease in the time it takes to run the primal and gradient passes.
-We have now demonstrated the title question of why we needed to be able to modify the primal pass.
-We will go into that more later and have some more usage examples.
-But first let's continue to see how we go from that `augmented_primal` to `pullback_at` to [`rrule`](@ref).
+We have now answered the title question of why we want to be able to modify the primal pass.
+We will go into that more later and have some more usage examples, but first let's continue to see how we go from `augmented_primal` and `pullback_at` to [`rrule`](@ref).
 
 One thing we notice when looking at `pullback_at` is that it really is starting to have a lot of arguments.
-It had a fair few already, and now we are adding `intermediates` as well.
-Not to mention this is a fairly simple function, only 1 input, no keyword arguments.
-Furthermore, we don't even use all of them all the time.
-The original new code for pulling back `sin` no longer needs the `x`, and it never needed `y` (though `sigmoid` does).
-The function signature for `pullback_at` was already rather long with the primal inputs and outputs and the sensitivity.
-Now it is even longer since we added `intermediates`.
-Further, storing all all those things on the tape is using up extra memory.
+It had a fair few already, and now we are adding `intermediates` as well, making it even more unwieldy.
+Not to mention these are fairly simple examples: the `sin` and `σ` functions have 1 input and no keyword arguments.
+Furthermore, we often don't even use all of the arguments to `pullback_at`.
+The new code for pulling back `sin` — which uses `sincos` and `intermediates` — no longer needs `x`, and it never needed `y` (though sigmoid `σ` does).
+And storing all these things on the tape — inputs, outputs, sensitivities, intermediates — is using up extra memory.
 What if we generalized the idea of the `intermediate` named tuple, and had a struct that just held anything we might want to put on the tape?
 ```julia
 struct PullbackMemory{P, S}
     primal_function::P
     state::S
 end
 # convenience constructor:
-Memory(primal_function; state...) = PullbackMemory(primal_function, state)
+PullbackMemory(primal_function; state...) = PullbackMemory(primal_function, state)
 # convenience accessor so that `m.x` is same as `m.state.x`
 Base.getproperty(m::PullbackMemory, propname) = getproperty(getfield(m, :state), propname)
 ```
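A quick sketch of how this struct behaves. We pass the state as an explicit NamedTuple here, since the keyword convenience constructor as written collects its keywords into a `Base.Pairs` object, which the `getproperty` forwarding may not handle on all Julia versions.

```julia
# Constructing and reading a PullbackMemory directly.
pb = PullbackMemory(sin, (; cx=cos(1.2)))  # state as a NamedTuple
pb.cx                                      # 0.3623..., via the getproperty overload
getfield(pb, :primal_function)             # sin; getfield, since getproperty is overloaded
```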
@@ -236,23 +232,23 @@ pullback_at(pb::PullbackMemory{typeof(sin)}, ȳ) = ȳ * pb.cx
 ```julia
 function augmented_primal(::typeof(σ), x)
     ex = exp(x)
-    y = ex/(1 + ex)
+    y = ex / (1 + ex)
     return y, PullbackMemory(σ; y=y, ex=ex)
 end
 
-pullback_at(pb::PullbackMemory{typeof(σ)}, ȳ) = ȳ * pb.y/(1 + pb.ex)
+pullback_at(pb::PullbackMemory{typeof(σ)}, ȳ) = ȳ * pb.y / (1 + pb.ex)
 ```
 ```@raw html
 </details>
 ```
 
-I think that looks pretty nice.
+That now looks much simpler; `pullback_at` only ever has 2 arguments.
 
-One way we could make it look a bit nicer for usage is if the `PullbackMemory` was actually a callable object. `pullback_at` only has the 2 arguments.
+One way we could make it nicer to use is by making `PullbackMemory` a callable object.
 Conceptually, the `PullbackMemory` is a fixed thing: it is the contents of the tape for a particular operation.
 It is fully determined by the end of the primal pass.
 Then during the gradient (reverse) pass, the `PullbackMemory` is used to compute the `x̄` corresponding to a given `ȳ`.
-So it makes sense to have `PullbackMemory` being a callable object that acts on the sensitivity.
+So it makes sense to make `PullbackMemory` a callable object that acts on the sensitivity.
 We can do that via call overloading:
 ```julia
 y = f(x) # primal program
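As a self-contained sketch of this call-overloading pattern for `sin` (the diff shows the matching `σ` definitions in the next hunk; the `sin` methods here are our own reconstruction):

```julia
# The taped object is now applied like a function: pb(ȳ) replaces
# pullback_at(pb, ȳ).
(pb::PullbackMemory{typeof(sin)})(ȳ) = ȳ * pb.cx

function augmented_primal(::typeof(sin), x)
    y, cx = sincos(x)
    return y, PullbackMemory(sin, (; cx=cx))  # state as a NamedTuple
end

y, pb = augmented_primal(sin, 1.2)  # forward pass returns the callable memory
x̄ = pb(1.0)                         # x̄ ≈ cos(1.2)
```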
@@ -281,22 +277,21 @@ end
 ```julia
 function augmented_primal(::typeof(σ), x)
     ex = exp(x)
-    y = ex/(1 + ex)
+    y = ex / (1 + ex)
     return y, PullbackMemory(σ; y=y, ex=ex)
 end
 
-(pb::PullbackMemory{typeof(σ)})(ȳ) = ȳ * pb.y/(1 + pb.ex)
+(pb::PullbackMemory{typeof(σ)})(ȳ) = ȳ * pb.y / (1 + pb.ex)
 ```
 ```@raw html
 </details>
 ```
 
-Those looking closely will spot what we have done here.
-We now have an object (`pb`) that acts on the cotangent of the output of the primal (`ȳ`) to give us the cotangent of the input of the primal function (`x̄`).
+Let's recap what we have done here.
+We now have an object `pb` that acts on the cotangent of the output of the primal `ȳ` to give us the cotangent of the input of the primal function `x̄`.
 _`pb` is not just the **memory** of state required for the `pullback`, it **is** the pullback._
 
-We have one final thing to do.
-Let's think about making the code easy to modify.
+We have one final thing to do, which is to think about how we make the code easy to modify.
 Let's go back and think about the changes we would have to make to go from our original way of writing, that only used the inputs/outputs, to one that used the intermediate state.
 
 ```@raw html
@@ -348,20 +343,18 @@ end
 (NB: there is actually a further optimization that can be made to the logistic sigmoid, to avoid remembering two things and just remember one.
 As an exercise to the reader, consider how the code would need to be changed and where.)
 
-We need to make a series of changes.
-We need to update what work is done in the primal to compute the intermediate values.
-We need to update what was stored in the `PullbackMemory`.
-And we need to update the the function that applies the pullback so it uses the new thing that was stored.
+We need to make a series of changes:
+* update what work is done in the primal, to compute the intermediate values.
+* update what is stored in the `PullbackMemory`.
+* update the function that applies the pullback so it uses the new thing that was stored.
 It's important these parts all stay in sync.
 It's not too bad for this simple example with just one or two things to remember.
-For more complicated multi-argument functions, like will be talked about below, you often end up needing to remember half a dozen things, like sizes and indices relating to each input/output.
-So it gets a little more fiddly to make sure you remember all the things you need to and give them the same name in both places.
+For more complicated multi-argument functions, which we will show below, you often end up needing to remember half a dozen things, like sizes and indices relating to each input/output, so it gets a little more fiddly to make sure you remember all the things you need to and give them the same name in both places.
 _Is there a way we can automatically just have all the things we use remembered for us?_
+Surprisingly for such a specific request, there actually is: a closure.
 
-Surprisingly for such a specific request, there actually is.
-This is a closure.
 A closure in Julia is a callable structure that automatically contains a field for every object from its parent scope that is used in its body.
-There are [incredible ways to abuse this](https://invenia.github.io/blog/2019/10/30/julialang-features-part-1#closures-give-us-classic-object-oriented-programming); but here we can in-fact use closures exactly as they are intended.
+There are [incredible ways to abuse this](https://invenia.github.io/blog/2019/10/30/julialang-features-part-1#closures-give-us-classic-object-oriented-programming); but here we can use closures exactly as they are intended.
 Replacing `PullbackMemory` with a closure that works the same way lets us avoid having to manually control what is remembered _and_ lets us avoid separately writing the call overload.
 
 ```@raw html
@@ -384,8 +377,8 @@ end
 ```julia
 function augmented_primal(::typeof(σ), x)
     ex = exp(x)
-    y = ex/(1 + ex)
-    pb = ȳ -> ȳ * y/(1 + ex) # pullback closure. closes over `y` and `ex`
+    y = ex / (1 + ex)
+    pb = ȳ -> ȳ * y / (1 + ex) # pullback closure; closes over `y` and `ex`
     return y, pb
 end
 ```
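The matching `sin` version, a sketch mirroring the `σ` definition above (the actual `sin` lines are collapsed out of this diff), shows how little is left to write:

```julia
function augmented_primal(::typeof(sin), x)
    y, cx = sincos(x)
    pb = ȳ -> ȳ * cx  # the closure captures `cx` for us automatically
    return y, pb
end

y, pb = augmented_primal(sin, 1.2)
pb(1.0) ≈ cos(1.2)  # true
```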
@@ -399,7 +392,7 @@ All that is left is a rename and some extra conventions around multiple outputs
 
 This has been a journey into how we get to [`rrule`](@ref) as it is defined in `ChainRulesCore`.
 We started with an unaugmented primal function and a `pullback_at` function that only saw the inputs and outputs of the primal.
-We realized a key limitation of this was that we couldn't share computational work between the primal and and gradient passes.
+We realized a key limitation of this was that we couldn't share computational work between the primal and gradient passes.
 To solve this we introduced the notion of some `intermediate` that is shared from the primal to the pullback.
 We successively improved that idea, first by making it a type that held everything that is needed for the pullback: the `PullbackMemory`, which we then made callable, so it was itself the pullback.
 Finally, we replaced that separate callable structure with a closure, which kept everything in one place and made it more convenient.
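For reference, here is roughly what the finished rule looks like once renamed to `rrule`. This is a sketch, not the exact definition shipped by ChainRules.jl: real rules also return a tangent for the function itself (spelled `NoTangent()` in current ChainRulesCore), and `sin_pullback` is our own name for the closure.

```julia
using ChainRulesCore

function ChainRulesCore.rrule(::typeof(sin), x)
    y, cx = sincos(x)                        # shared work in the primal pass
    sin_pullback(ȳ) = (NoTangent(), ȳ * cx)  # tangent for `sin` itself, then x̄
    return y, sin_pullback
end
```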
