@@ -149,6 +149,7 @@ Base.iterate(::One, ::Any) = nothing
149
149
# ####
150
150
# #### `AbstractThunk
151
151
# ####
152
+
152
153
abstract type AbstractThunk <: AbstractDifferential end
153
154
154
155
Base. Broadcast. broadcastable (x:: AbstractThunk ) = broadcastable (extern (x))
@@ -237,8 +238,17 @@ macro thunk(body)
237
238
return :(Thunk ($ (esc (func))))
238
239
end
239
240
241
+ """
242
+ unthunk(x)
243
+
244
+ `unthunk` removes 1 layer of thunking from an `AbstractThunk`,
245
+ and on all other types is the `identity` function.
246
+ """
247
+ unthunk (x) = x
248
+ unthunk (x:: Thunk ) = x ()
249
+
240
250
# have to define this here after `@thunk` and `Thunk` is defined
241
- Base. conj (x:: AbstractThunk ) = @thunk (conj (extern (x)))
251
+ Base. conj (x:: AbstractThunk ) = @thunk (conj (unthunk (x)))
242
252
243
253
(x:: Thunk )() = x. f ()
244
254
@inline unthunk (x:: Thunk ) = x ()
@@ -284,6 +294,73 @@ function itself, when that function is not a closure.
284
294
"""
285
295
const NO_FIELDS = DoesNotExist ()
286
296
297
+
298
+ """
299
+ Composite{P, T} <: AbstractDifferential
300
+
301
+ This type represents the differential for a `struct`/`NamedTuple`, or `Tuple`.
302
+ `P` is the the corresponding primal type that this is a differential for.
303
+
304
+ `Composite{P}` should have fields (technically properties), that match to a subset of the
305
+ fields of the primal type; and each should be a differential type matching to the primal
306
+ type of that field.
307
+ Fields of the P that are not present in the Composite are treated as `Zero`.
308
+
309
+ `T` is an implementation detail representing the backing data structure.
310
+ For Tuple it will be a Tuple, and for everything else it will be a `NamedTuple`.
311
+ It should not be passed in by user.
312
+ """
313
+ struct Composite{P, T} <: AbstractDifferential
314
+ # Note: If T is a Tuple, then P is also a Tuple
315
+ # (but potentially a different one, as it doesn't contain differentials)
316
+ backing:: T
317
+ end
318
+
319
+ function Composite {P} (; kwargs... ) where P
320
+ backing = (; kwargs... ) # construct as NamedTuple
321
+ return Composite {P, typeof(backing)} (backing)
322
+ end
323
+
324
+ function Composite {P} (args... ) where P
325
+ return Composite {P, typeof(args)} (args)
326
+ end
327
+
328
+ function Base. show (io:: IO , comp:: Composite{P} ) where P
329
+ print (io, " Composite{" )
330
+ show (io, P)
331
+ print (io, " }" )
332
+ # allow Tuple or NamedTuple `show` to do the rendering of brackets etc
333
+ show (io, backing (comp))
334
+ end
335
+
336
+ Base. convert (:: Type{<:NamedTuple} , comp:: Composite{<:Any, <:NamedTuple} ) = backing (comp)
337
+ Base. convert (:: Type{<:Tuple} , comp:: Composite{<:Any, <:Tuple} ) = backing (comp)
338
+
339
+ Base. getindex (comp:: Composite , idx) = getindex (backing (comp), idx)
340
+ Base. getproperty (comp:: Composite , idx:: Int ) = getproperty (backing (comp), idx) # for Tuple
341
+ Base. getproperty (comp:: Composite , idx:: Symbol ) = getproperty (backing (comp), idx)
342
+ Base. propertynames (comp:: Composite ) = propertynames (backing (comp))
343
+
344
+ Base. iterate (comp:: Composite , args... ) = iterate (backing (comp), args... )
345
+ Base. length (comp:: Composite ) = length (backing (comp))
346
+ Base. eltype (:: Type{<:Composite{<:Any, T}} ) where T = eltype (T)
347
+
348
+ function Base. map (f, comp:: Composite{P, <:Tuple} ) where P
349
+ vals:: Tuple = map (f, backing (comp))
350
+ return Composite {P, typeof(vals)} (vals)
351
+ end
352
+ function Base. map (f, comp:: Composite{P, <:NamedTuple{L}} ) where {P, L}
353
+ vals = map (f, Tuple (backing (comp)))
354
+ named_vals = NamedTuple {L, typeof(vals)} (vals)
355
+ return Composite {P, typeof(named_vals)} (named_vals)
356
+ end
357
+
358
+ Base. conj (comp:: Composite ) = map (conj, comp)
359
+
360
+ extern (comp:: Composite ) = backing (map (extern, comp)) # gives a NamedTuple or Tuple
361
+
362
+ #= =============================================================================#
363
+
287
364
"""
288
365
refine_differential(𝒟::Type, der)
289
366
0 commit comments