Skip to content

Commit 44f1a5e

Browse files
committed
Rewrite as a generated functor
1 parent d40e282 commit 44f1a5e

File tree

2 files changed

+17
-13
lines changed

2 files changed

+17
-13
lines changed

src/rulesets/Base/base.jl

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -295,17 +295,21 @@ end
295295
####
296296
#### merge
297297
####
298-
299-
function rrule(::typeof(merge), nt1::NamedTuple{F1}, nt2::NamedTuple{F2}) where {F1,F2}
300-
y = merge(nt1, nt2)
301-
function merge_pullback(dy)
302-
dnt1 = Tangent{typeof(nt1)}(;
303-
(f1 => (f1 in F2 ? ZeroTangent() : getproperty(dy, f1)) for f1 in F1)...
304-
)
305-
dnt2 = Tangent{typeof(nt2)}(; (f2 => getproperty(dy, f2) for f2 in F2)...)
298+
# need to work around inability to return closures from generated functions
299+
struct MergePullback{T1, T2}
300+
end
301+
(this::MergePullback)(dy::AbstractThunk) = this(unthunk(dy))
302+
(::MergePullback)(x::AbstractZero) = (NoTangent(), x, x)
303+
@generated function(::MergePullback{T1,T2})(dy::Tangent) where {F1,T1<:NamedTuple{F1},F2,T2<:NamedTuple{F2}}
304+
_getproperty_kwexpr(key) = :($key = getproperty(dy, $(Meta.quot(key))))
305+
quote
306+
dnt1 = Tangent{T1}(; $(map(_getproperty_kwexpr, setdiff(F1, F2))...))
307+
dnt2 = Tangent{T2}(; $(map(_getproperty_kwexpr, F2)...))
306308
return (NoTangent(), dnt1, dnt2)
307309
end
308-
merge_pullback(dy::AbstractThunk) = merge_pullback(unthunk(dy))
309-
merge_pullback(x::AbstractZero) = (NoTangent(), x, x)
310-
return y, merge_pullback
310+
end
311+
312+
function rrule(::typeof(merge), nt1::T1, nt2::T2) where {T1<:NamedTuple, T2<:NamedTuple}
313+
y = merge(nt1, nt2)
314+
return y, MergePullback{T1,T2}()
311315
end

test/rulesets/Base/base.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ end
259259
end
260260

261261
@testset "merge NamedTuple" begin
262-
test_rrule(merge, (; a=1.0), (; b=2.0); check_inferred=false)
263-
test_rrule(merge, (; a=1.0), (; a=2.0); check_inferred=false)
262+
test_rrule(merge, (;a=1.0), (;b=2.0))
263+
test_rrule(merge, (;a=1.0), (;a=2.0))
264264
end
265265
end

0 commit comments

Comments
 (0)