Skip to content

Commit fa7f884

Browse files
authored
Merge pull request #795 from JuliaDiff/ox/merge
Add rrule for NamedTuple merge
2 parents 9dd39bd + 49d5ae7 commit fa7f884

File tree

2 files changed

+28
-0
lines changed

2 files changed

+28
-0
lines changed

src/rulesets/Base/base.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,3 +291,26 @@ function rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(task_local_storage
291291
end
292292
return y, task_local_storage_pullback
293293
end
294+
295+
####
296+
#### merge
297+
####
298+
# need to work around inability to return closures from generated functions
299+
struct MergePullback{T1,T2} end
300+
(this::MergePullback)(dy::AbstractThunk) = this(unthunk(dy))
301+
(::MergePullback)(x::AbstractZero) = (NoTangent(), x, x)
302+
@generated function (::MergePullback{T1,T2})(
303+
dy::Tangent
304+
) where {F1,T1<:NamedTuple{F1},F2,T2<:NamedTuple{F2}}
305+
_getproperty_kwexpr(key) = :($key = getproperty(dy, $(Meta.quot(key))))
306+
quote
307+
dnt1 = Tangent{T1}(; $(map(_getproperty_kwexpr, setdiff(F1, F2))...))
308+
dnt2 = Tangent{T2}(; $(map(_getproperty_kwexpr, F2)...))
309+
return (NoTangent(), dnt1, dnt2)
310+
end
311+
end
312+
313+
function rrule(::typeof(merge), nt1::T1, nt2::T2) where {T1<:NamedTuple,T2<:NamedTuple}
314+
y = merge(nt1, nt2)
315+
return y, MergePullback{T1,T2}()
316+
end

test/rulesets/Base/base.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,4 +258,9 @@ end
258258
test_rrule(map, Multiplier(4.5), (6.7, 8.9), (0.1, 0.2, 0.3), check_inferred=false)
259259
end
260260
end
261+
262+
@testset "merge NamedTuple" begin
263+
test_rrule(merge, (; a=1.0), (; b=2.0))
264+
test_rrule(merge, (; a=1.0), (; a=2.0))
265+
end
261266
end

0 commit comments

Comments
 (0)