@@ -295,17 +295,21 @@ end
295
295
# ###
296
296
# ### merge
297
297
# ###
298
-
299
- function rrule (:: typeof (merge), nt1:: NamedTuple{F1} , nt2:: NamedTuple{F2} ) where {F1,F2}
300
- y = merge (nt1, nt2)
301
- function merge_pullback (dy)
302
- dnt1 = Tangent {typeof(nt1)} (;
303
- (f1 => (f1 in F2 ? ZeroTangent () : getproperty (dy, f1)) for f1 in F1). ..
304
- )
305
- dnt2 = Tangent {typeof(nt2)} (; (f2 => getproperty (dy, f2) for f2 in F2). .. )
298
+ # need to work around inability to return closures from generated functions
299
+ struct MergePullback{T1, T2}
300
+ end
301
+ (this:: MergePullback )(dy:: AbstractThunk ) = this (unthunk (dy))
302
+ (:: MergePullback )(x:: AbstractZero ) = (NoTangent (), x, x)
303
+ @generated function (:: MergePullback{T1,T2} )(dy:: Tangent ) where {F1,T1<: NamedTuple{F1} ,F2,T2<: NamedTuple{F2} }
304
+ _getproperty_kwexpr (key) = :($ key = getproperty (dy, $ (Meta. quot (key))))
305
+ quote
306
+ dnt1 = Tangent {T1} (; $ (map (_getproperty_kwexpr, setdiff (F1, F2))... ))
307
+ dnt2 = Tangent {T2} (; $ (map (_getproperty_kwexpr, F2)... ))
306
308
return (NoTangent (), dnt1, dnt2)
307
309
end
308
- merge_pullback (dy:: AbstractThunk ) = merge_pullback (unthunk (dy))
309
- merge_pullback (x:: AbstractZero ) = (NoTangent (), x, x)
310
- return y, merge_pullback
310
+ end
311
+
312
+ function rrule (:: typeof (merge), nt1:: T1 , nt2:: T2 ) where {T1<: NamedTuple , T2<: NamedTuple }
313
+ y = merge (nt1, nt2)
314
+ return y, MergePullback {T1,T2} ()
311
315
end
0 commit comments