287
287
# Since this works like a zero-array in broadcasting, it should also accept a number:
288
288
(project:: ProjectTo{<:Tangent{<:Ref}} )(dx:: Number ) = project (Ref (dx))
289
289
290
- # Tuple
290
+ # Tuple and NamedTuple
291
291
function ProjectTo (x:: Tuple )
292
292
elements = map (ProjectTo, x)
293
293
if elements isa NTuple{<: Any ,ProjectTo{<: AbstractZero }}
@@ -296,10 +296,22 @@ function ProjectTo(x::Tuple)
296
296
return ProjectTo {Tangent{typeof(x)}} (; elements= elements)
297
297
end
298
298
end
299
+ function ProjectTo (x:: NamedTuple )
300
+ elements = map (ProjectTo, x)
301
+ if Tuple (elements) isa NTuple{<: Any ,ProjectTo{<: AbstractZero }}
302
+ return ProjectTo {NoTangent} ()
303
+ else
304
+ return ProjectTo {Tangent{typeof(x)}} (; elements... )
305
+ end
306
+ end
307
+
299
308
# This method means that projection is re-applied to the contents of a Tangent.
300
309
# We're not entirely sure whether this is every necessary; but it should be safe,
301
310
# and should often compile away:
302
- (project:: ProjectTo{<:Tangent{<:Tuple}} )(dx:: Tangent ) = project (backing (dx))
311
+ function (project:: ProjectTo{<:Tangent{<:Union{Tuple,NamedTuple}}} )(dx:: Tangent )
312
+ return project (backing (dx))
313
+ end
314
+
303
315
function (project:: ProjectTo{<:Tangent{<:Tuple}} )(dx:: Tuple )
304
316
len = length (project. elements)
305
317
if length (dx) != len
@@ -310,6 +322,45 @@ function (project::ProjectTo{<:Tangent{<:Tuple}})(dx::Tuple)
310
322
dy = map ((f, x) -> f (x), project. elements, dx)
311
323
return project_type (project)(dy... )
312
324
end
325
+ function (project:: ProjectTo{<:Tangent{<:NamedTuple}} )(dx:: NamedTuple )
326
+ dy = _project_namedtuple (backing (project), dx)
327
+ return project_type (project)(; dy... )
328
+ end
329
+
330
+ # Diffractor returns not necessarily a named tuple with all keys and of the same order as
331
+ # the projector
332
+ # Thus we can't use `map`
333
+ function _project_namedtuple (f:: NamedTuple{fn,ft} , x:: NamedTuple{xn,xt} ) where {fn,ft,xn,xt}
334
+ if @generated
335
+ vals = Any[
336
+ if xn[i] in fn
337
+ :(getfield (f, $ (QuoteNode (xn[i])))(getfield (x, $ (QuoteNode (xn[i])))))
338
+ else
339
+ throw (
340
+ ArgumentError (
341
+ " named tuple with keys(x) == $fn cannot have a gradient with key $(xn[i]) " ,
342
+ ),
343
+ )
344
+ end for i in 1 : length (xn)
345
+ ]
346
+ :(NamedTuple {$xn} (($ (vals... ),)))
347
+ else
348
+ vals = ntuple (Val (length (xn))) do i
349
+ name = xn[i]
350
+ if name in fn
351
+ getfield (f, name)(getfield (x, name))
352
+ else
353
+ throw (
354
+ ArgumentError (
355
+ " named tuple with keys(x) == $fn cannot have a gradient with key $(xn[i]) " ,
356
+ ),
357
+ )
358
+ end
359
+ end
360
+ NamedTuple {xn} (vals)
361
+ end
362
+ end
363
+
313
364
function (project:: ProjectTo{<:Tangent{<:Tuple}} )(dx:: AbstractArray )
314
365
for d in 1 : ndims (dx)
315
366
if size (dx, d) != get (length (project. elements), d, 1 )
0 commit comments