@@ -173,6 +173,24 @@ Base.iterate(x::One) = (x, nothing)
173
173
Base. iterate (:: One , :: Any ) = nothing
174
174
175
175
176
+ # ####
177
+ # #### `AbstractThunk
178
+ # ####
179
+ abstract type AbstractThunk <: AbstractDifferential end
180
+
181
+ Base. Broadcast. broadcastable (x:: AbstractThunk ) = broadcastable (extern (x))
182
+
183
+ @inline function Base. iterate (x:: AbstractThunk )
184
+ externed = extern (x)
185
+ element, state = iterate (externed)
186
+ return element, (externed, state)
187
+ end
188
+
189
+ @inline function Base. iterate (:: AbstractThunk , (externed, state))
190
+ element, new_state = iterate (externed, state)
191
+ return element, (externed, new_state)
192
+ end
193
+
176
194
# ####
177
195
# #### `Thunk`
178
196
# ####
@@ -181,8 +199,9 @@ Base.iterate(::One, ::Any) = nothing
181
199
Thunk(()->v)
182
200
A thunk is a deferred computation.
183
201
It wraps a zero argument closure that when invoked returns a differential.
202
+ `@thunk(v)` is a macro that expands into `Thunk(()->v)`.
184
203
185
- Calling that thunk, calls the wrapped closure.
204
+ Calling a thunk, calls the wrapped closure.
186
205
`extern`ing thunks applies recursively, it also externs the differial that the closure returns.
187
206
If you do not want that, then simply call the thunk
188
207
@@ -199,31 +218,87 @@ Thunk(var"##8#10"())
199
218
julia> t()()
200
219
3
201
220
```
221
+
222
+ ### When to `@thunk`?
223
+ When writing `rrule`s (and to a lesser exent `frule`s), it is important to `@thunk`
224
+ appropriately.
225
+ Propagation rule's that return multiple derivatives are not able to do all the computing themselves.
226
+ By `@thunk`ing the work required for each, they then compute only what is needed.
227
+
228
+ #### So why not thunk everything?
229
+ `@thunk` creates a closure over the expression, which (effectively) creates a `struct`
230
+ with a field for each variable used in the expression, and call overloaded.
231
+
232
+ Do not use `@thunk` if this would be equal or more work than actually evaluating the expression itself. Examples being:
233
+ - The expression wrapping something in a `struct`, such as `Adjoint(x)` or `Diagonal(x)`
234
+ - The expression being a constant
235
+ - The expression being itself a `thunk`
236
+ - The expression being from another `rrule` or `frule` (it would be `@thunk`ed if required by the defining rule already)
202
237
"""
203
- struct Thunk{F} <: AbstractDifferential
238
+ struct Thunk{F} <: AbstractThunk
204
239
f:: F
205
240
end
206
241
207
242
macro thunk (body)
208
243
return :(Thunk (() -> $ (esc (body))))
209
244
end
210
245
246
+ # have to define this here after `@thunk` and `Thunk` is defined
247
+ Base. conj (x:: AbstractThunk ) = @thunk (conj (extern (x)))
248
+
249
+
211
250
(x:: Thunk )() = x. f ()
212
251
@inline extern (x:: Thunk ) = extern (x ())
213
252
214
- Base. Broadcast . broadcastable ( x:: Thunk ) = broadcastable ( extern (x) )
253
+ Base. show (io :: IO , x:: Thunk ) = println (io, " Thunk( $( repr (x . f)) ) " )
215
254
216
- @inline function Base. iterate (x:: Thunk )
217
- externed = extern (x)
218
- element, state = iterate (externed)
219
- return element, (externed, state)
255
+ """
256
+ InplaceableThunk(val::Thunk, add!::Function)
257
+
258
+ A wrapper for a `Thunk`, that allows it to define an inplace `add!` function,
259
+ which is used internally in `accumulate!(Δ, ::InplaceableThunk)`.
260
+
261
+ `add!` should be defined such that: `ithunk.add!(Δ) = Δ .+= ithunk.val`
262
+ but it should do this more efficently than simply doing this directly.
263
+ (Otherwise one can just use a normal `Thunk`).
264
+
265
+ Most operations on an `InplaceableThunk` treat it just like a normal `Thunk`;
266
+ and destroy its inplacability.
267
+ """
268
+ struct InplaceableThunk{T<: Thunk , F} <: AbstractThunk
269
+ val:: T
270
+ add!:: F
220
271
end
221
272
222
- @inline function Base. iterate (:: Thunk , (externed, state))
223
- element, new_state = iterate (externed, state)
224
- return element, (externed, new_state)
273
+ (x:: InplaceableThunk )() = x. val ()
274
+ @inline extern (x:: InplaceableThunk ) = extern (x. val)
275
+
276
+ function Base. show (io:: IO , x:: InplaceableThunk )
277
+ println (io, " InplaceableThunk($(repr (x. val)) , $(repr (x. add!)) )" )
225
278
end
226
279
227
- Base. conj (x:: Thunk ) = @thunk (conj (extern (x)))
280
+ # The real reason we have this:
281
+ accumulate! (Δ, ∂:: InplaceableThunk ) = ∂. add! (Δ)
282
+ store! (Δ, ∂:: InplaceableThunk ) = ∂. add! ((Δ.*= false )) # zero it, then add to it.
228
283
229
- Base. show (io:: IO , x:: Thunk ) = println (io, " Thunk($(repr (x. f)) )" )
284
+ """
285
+ NO_FIELDS
286
+
287
+ Constant for the reverse-mode derivative with respect to a structure that has no fields.
288
+ The most notable use for this is for the reverse-mode derivative with respect to the
289
+ function itself, when that function is not a closure.
290
+ """
291
+ const NO_FIELDS = DNE ()
292
+
293
+ """
294
+ refine_differential(𝒟::Type, der)
295
+
296
+ Converts, if required, a differential object `der`
297
+ (e.g. a `Number`, `AbstractDifferential`, `Matrix`, etc.),
298
+ to another differential that is more suited for the domain given by the type 𝒟.
299
+ Often this will behave as the identity function on `der`.
300
+ """
301
+ function refine_differential (:: Type{<:Union{<:Real, AbstractArray{<:Real}}} , w:: Wirtinger )
302
+ return wirtinger_primal (w) + wirtinger_conjugate (w)
303
+ end
304
+ refine_differential (:: Any , der) = der # most of the time leave it alone.
0 commit comments