@@ -149,6 +149,7 @@ Base.iterate(::One, ::Any) = nothing
149
149
# ####
150
150
# #### `AbstractThunk
151
151
# ####
152
+
152
153
abstract type AbstractThunk <: AbstractDifferential end
153
154
154
155
Base. Broadcast. broadcastable (x:: AbstractThunk ) = broadcastable (extern (x))
@@ -237,8 +238,17 @@ macro thunk(body)
237
238
return :(Thunk ($ (esc (func))))
238
239
end
239
240
241
+ """
242
+ unthunk(x)
243
+
244
+ `unthunk` removes 1 layer of thunking from an `AbstractThunk`,
245
+ and on all other types is the `identity` function.
246
+ """
247
+ unthunk (x) = x
248
+ unthunk (x:: Thunk ) = x ()
249
+
240
250
# have to define this here after `@thunk` and `Thunk` is defined
241
- Base. conj (x:: AbstractThunk ) = @thunk (conj (extern (x)))
251
+ Base. conj (x:: AbstractThunk ) = @thunk (conj (unthunk (x)))
242
252
243
253
(x:: Thunk )() = x. f ()
244
254
@inline unthunk (x:: Thunk ) = x ()
@@ -286,61 +296,68 @@ const NO_FIELDS = DoesNotExist()
286
296
287
297
288
298
"""
289
- Composite{Primal , T} <: AbstractDifferential
299
+ Composite{P , T} <: AbstractDifferential
290
300
291
301
This type represents the differential for a `struct`/`NamedTuple`, or `Tuple`.
292
- `Primal ` is the the corresponding primal type that this is a differential for.
302
+ `P ` is the the corresponding primal type that this is a differential for.
293
303
294
- `Composite{Primal }` should have fields (technically properties), that match to a subset of the
304
+ `Composite{P }` should have fields (technically properties), that match to a subset of the
295
305
fields of the primal type; and each should be a differential type matching to the primal
296
306
type of that field.
297
- Fields of the Primal that are not present in the Composite are treated as `Zero`.
307
+ Fields of the P that are not present in the Composite are treated as `Zero`.
298
308
299
- `T` is an implementation detail representing the backing datastructure .
309
+ `T` is an implementation detail representing the backing data structure .
300
310
For Tuple it will be a Tuple, and for everything else it will be a `NamedTuple`.
301
311
It should not be passed in by user.
302
312
"""
303
- struct Composite{Primal, T} <: AbstractDifferential
313
+ struct Composite{P, T} <: AbstractDifferential
314
+ # Note: If T is a Tuple, then P is also a Tuple
315
+ # (but potentially a different one, as it doesn't contain differentials)
304
316
backing:: T
305
317
end
306
318
307
-
308
- function Composite {Primal} (;kwargs... ) where Primal
309
- backing = (; kwargs... )
310
- return Composite {Primal, typeof(backing)} (backing)
319
+ function Composite {P} (; kwargs... ) where P
320
+ backing = (; kwargs... ) # construct as NamedTuple
321
+ return Composite {P, typeof(backing)} (backing)
311
322
end
312
323
313
- function Composite {Primal } (args... ) where Primal
314
- return Composite {Primal , typeof(args)} (args)
324
+ function Composite {P } (args... ) where P
325
+ return Composite {P , typeof(args)} (args)
315
326
end
316
327
317
- function Base. show (io:: IO , comp:: Composite{Primal} )
328
+ function Base. show (io:: IO , comp:: Composite{P} ) where P
318
329
print (io, " Composite{" )
319
- show (io, Primal )
330
+ show (io, P )
320
331
print (io, " }" )
321
332
# allow Tuple or NamedTuple `show` to do the rendering of brackets etc
322
- show (io, comp . backing)
333
+ show (io, backing (comp) )
323
334
end
324
335
325
- # TODO think about this, for if we are missing fields
326
- # Base.convert(::Type{Primal}, comp::Composite{Primal})
327
- Base. convert (:: Type{<:NamedTuple} , comp:: Composite{<:Any, <:NamedTuple} ) = comp. backing
328
- Base. convert (:: Type{<:Tuple} , comp:: Composite{<:Any, <:Tuple} ) = comp. backing
336
+ Base. convert (:: Type{<:NamedTuple} , comp:: Composite{<:Any, <:NamedTuple} ) = backing (comp)
337
+ Base. convert (:: Type{<:Tuple} , comp:: Composite{<:Any, <:Tuple} ) = backing (comp)
338
+
339
+ Base. getindex (comp:: Composite , idx) = getindex (backing (comp), idx)
340
+ Base. getproperty (comp:: Composite , idx:: Int ) = getproperty (backing (comp), idx) # for Tuple
341
+ Base. getproperty (comp:: Composite , idx:: Symbol ) = getproperty (backing (comp), idx)
342
+ Base. propertynames (comp:: Composite ) = propertynames (backing (comp))
329
343
330
- Base. getindex (comp:: Composite , idx) = getindex (comp. backing)
331
- Base. getproperty (comp:: Composite , idx) = getproperty (comp. backing, idx)
332
- Base. propertynames (comp:: Composite ) = propertynames (comp. backing)
333
- Base. iterate (comp:: Compositem , args... ) = iterate (comp. backing, args... )
334
- Base. length (comp:: Composite ) = length (comp. backing)
344
+ Base. iterate (comp:: Composite , args... ) = iterate (backing (comp), args... )
345
+ Base. length (comp:: Composite ) = length (backing (comp))
346
+ Base. eltype (:: Type{<:Composite{<:Any, T}} ) where T = eltype (T)
335
347
336
- map (f, comp:: Composite{Primal, <:Tuple} ) where Primal = Composite {Primal} (map (f, comp. backing))
337
- function map (f, comp:: Composite{Primal, <:NamedTuple{L}} ) where {Primal, L}
338
- vals = map (f, Tuple (comp. backing))
348
+ function Base. map (f, comp:: Composite{P, <:Tuple} ) where P
349
+ vals:: Tuple = map (f, backing (comp))
350
+ return Composite {P, typeof(vals)} (vals)
351
+ end
352
+ function Base. map (f, comp:: Composite{P, <:NamedTuple{L}} ) where {P, L}
353
+ vals = map (f, Tuple (backing (comp)))
339
354
named_vals = NamedTuple {L, typeof(vals)} (vals)
340
- return Composite {Primal } (named_vals)
355
+ return Composite {P, typeof(named_vals) } (named_vals)
341
356
end
342
357
343
- Base. conj (comp:: Composite{Primal} ) = map (conj, comp)
358
+ Base. conj (comp:: Composite ) = map (conj, comp)
359
+
360
+ extern (comp:: Composite ) = backing (map (extern, comp)) # gives a NamedTuple or Tuple
344
361
345
362
#= =============================================================================#
346
363
0 commit comments