@@ -148,8 +148,8 @@ x̄ = pullback_at(f, x, y, ȳ, intermediates)
148
148
```
149
149
``` julia
150
150
function augmented_primal (:: typeof (sin), x)
151
- y, cx = sincos (x)
152
- return y, (; cx= cx) # use a NamedTuple for the intermediates
151
+ y, cx = sincos (x)
152
+ return y, (; cx= cx) # use a NamedTuple for the intermediates
153
153
end
154
154
155
155
pullback_at (:: typeof (sin), x, y, ȳ, intermediates) = ȳ * intermediates. cx
@@ -163,9 +163,9 @@ pullback_at(::typeof(sin), x, y, ȳ, intermediates) = ȳ * intermediates.cx
163
163
```
164
164
``` julia
165
165
function augmented_primal (:: typeof (σ), x)
166
- ex = exp (x)
167
- y = ex / (1 + ex)
168
- return y, (; ex= ex) # use a NamedTuple for the intermediates
166
+ ex = exp (x)
167
+ y = ex / (1 + ex)
168
+ return y, (; ex= ex) # use a NamedTuple for the intermediates
169
169
end
170
170
171
171
pullback_at (:: typeof (σ), x, y, ȳ, intermediates) = ȳ * y / (1 + intermediates. ex)
@@ -189,8 +189,8 @@ And storing all these things on the tape — inputs, outputs, sensitivities, int
189
189
What if we generalized the idea of the ` intermediate ` named tuple, and had ` augmented_primal ` return a struct that just held anything we might want put on the tape.
190
190
``` julia
191
191
struct PullbackMemory{P, S}
192
- primal_function:: P
193
- state:: S
192
+ primal_function:: P
193
+ state:: S
194
194
end
195
195
# convenience constructor:
196
196
PullbackMemory (primal_function; state... ) = PullbackMemory (primal_function, state)
@@ -211,8 +211,8 @@ which is much cleaner.
211
211
```
212
212
``` julia
213
213
function augmented_primal (:: typeof (sin), x)
214
- y, cx = sincos (x)
215
- return y, PullbackMemory (sin; cx= cx)
214
+ y, cx = sincos (x)
215
+ return y, PullbackMemory (sin; cx= cx)
216
216
end
217
217
218
218
pullback_at (pb:: PullbackMemory{typeof(sin)} , ȳ) = ȳ * pb. cx
@@ -226,9 +226,9 @@ pullback_at(pb::PullbackMemory{typeof(sin)}, ȳ) = ȳ * pb.cx
226
226
```
227
227
``` julia
228
228
function augmented_primal (:: typeof (σ), x)
229
- ex = exp (x)
230
- y = ex / (1 + ex)
231
- return y, PullbackMemory (σ; y= y, ex= ex)
229
+ ex = exp (x)
230
+ y = ex / (1 + ex)
231
+ return y, PullbackMemory (σ; y= y, ex= ex)
232
232
end
233
233
234
234
pullback_at (pb:: PullbackMemory{typeof(σ)} , ȳ) = ȳ * pb. y / (1 + pb. ex)
@@ -256,8 +256,8 @@ x̄ = pb(ȳ)
256
256
```
257
257
``` julia
258
258
function augmented_primal (:: typeof (sin), x)
259
- y, cx = sincos (x)
260
- return y, PullbackMemory (sin; cx= cx)
259
+ y, cx = sincos (x)
260
+ return y, PullbackMemory (sin; cx= cx)
261
261
end
262
262
(pb:: PullbackMemory{typeof(sin)} )(ȳ) = ȳ * pb. cx
263
263
```
271
271
```
272
272
``` julia
273
273
function augmented_primal (:: typeof (σ), x)
274
- ex = exp (x)
275
- y = ex / (1 + ex)
276
- return y, PullbackMemory (σ; y= y, ex= ex)
274
+ ex = exp (x)
275
+ y = ex / (1 + ex)
276
+ return y, PullbackMemory (σ; y= y, ex= ex)
277
277
end
278
278
279
279
(pb:: PullbackMemory{typeof(σ)} )(ȳ) = ȳ * pb. y / (1 + pb. ex)
@@ -295,16 +295,16 @@ Let's go back and think about the changes we would have make to go from our orig
295
295
To rewrite that original formulation in the new pullback form we have:
296
296
``` julia
297
297
function augmented_primal (:: typeof (sin), x)
298
- y = sin (x)
299
- return y, PullbackMemory (sin; x= x)
298
+ y = sin (x)
299
+ return y, PullbackMemory (sin; x= x)
300
300
end
301
301
(pb:: PullbackMemory )(ȳ) = ȳ * cos (pb. x)
302
302
```
303
303
To go from that to:
304
304
``` julia
305
305
function augmented_primal (:: typeof (sin), x)
306
- y, cx = sincos (x)
307
- return y, PullbackMemory (sin; cx= cx)
306
+ y, cx = sincos (x)
307
+ return y, PullbackMemory (sin; cx= cx)
308
308
end
309
309
(pb:: PullbackMemory )(ȳ) = ȳ * pb. cx
310
310
```
@@ -317,17 +317,17 @@ end
317
317
```
318
318
``` julia
319
319
function augmented_primal (:: typeof (σ), x)
320
- y = σ (x)
321
- return y, PullbackMemory (σ; y= y, x= x)
320
+ y = σ (x)
321
+ return y, PullbackMemory (σ; y= y, x= x)
322
322
end
323
323
(pb:: PullbackMemory{typeof(σ)} )(ȳ) = ȳ * pb. y * σ (- pb. x)
324
324
```
325
325
to get to:
326
326
``` julia
327
327
function augmented_primal (:: typeof (σ), x)
328
- ex = exp (x)
329
- y = ex/ (1 + ex)
330
- return y, PullbackMemory (σ; y= y, ex= ex)
328
+ ex = exp (x)
329
+ y = ex/ (1 + ex)
330
+ return y, PullbackMemory (σ; y= y, ex= ex)
331
331
end
332
332
(pb:: PullbackMemory{typeof(σ)} )(ȳ) = ȳ * pb. y/ (1 + pb. ex)
333
333
```
@@ -356,9 +356,9 @@ Replacing `PullbackMemory` with a closure that works the same way lets us avoid
356
356
```
357
357
``` julia
358
358
function augmented_primal (:: typeof (sin), x)
359
- y, cx = sincos (x)
360
- pb = ȳ -> cx * ȳ # pullback closure. closes over `cx`
361
- return y, pb
359
+ y, cx = sincos (x)
360
+ pb = ȳ -> cx * ȳ # pullback closure. closes over `cx`
361
+ return y, pb
362
362
end
363
363
```
364
364
``` @raw html
@@ -370,10 +370,10 @@ end
370
370
```
371
371
``` julia
372
372
function augmented_primal (:: typeof (σ), x)
373
- ex = exp (x)
374
- y = ex / (1 + ex)
375
- pb = ȳ -> ȳ * y / (1 + ex) # pullback closure. closes over `y` and `ex`
376
- return y, pb
373
+ ex = exp (x)
374
+ y = ex / (1 + ex)
375
+ pb = ȳ -> ȳ * y / (1 + ex) # pullback closure. closes over `y` and `ex`
376
+ return y, pb
377
377
end
378
378
```
379
379
``` @raw html
0 commit comments