Skip to content

Commit a05ebb2

Browse files
committed
Add rrule for NamedTuple merge
1 parent 2c6621c commit a05ebb2

File tree

2 files changed

+27
-0
lines changed

2 files changed

+27
-0
lines changed

src/rulesets/Base/base.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,3 +291,25 @@ function rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(task_local_storage
291291
end
292292
return y, task_local_storage_pullback
293293
end
294+
295+
296+
####
297+
#### merge
298+
####
299+
300+
function rrule(::typeof(merge), nt1::NamedTuple{F1}, nt2::NamedTuple{F2}) where {F1, F2}
301+
y = merge(nt1, nt2)
302+
function merge_pullback(dy)
303+
dnt1 = Tangent{typeof(nt1)}(;
304+
(f1 => (f1 in F2 ? ZeroTangent() : getproperty(dy, f1)) for f1 in F1)...
305+
)
306+
dnt2 = Tangent{typeof(nt2)}(;
307+
(f2 => getproperty(dy, f2) for f2 in F2)...
308+
)
309+
return (NoTangent(), dnt1, dnt2)
310+
end
311+
merge_pullback(dy::AbstractThunk) = merge_pullback(unthunk(dy))
312+
merge_pullback(x::AbstractZero) = (NoTangent(), x, x)
313+
314+
return y, merge_pullback
315+
end

test/rulesets/Base/base.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,4 +257,9 @@ end
257257
test_rrule(map, Multiplier(4.5), (6.7, 8.9), (0.1, 0.2, 0.3), check_inferred=false)
258258
end
259259
end
260+
261+
@testset "merge NamedTuple" begin
262+
test_rrule(merge, (;a=1.0), (;b=2.0), check_inferred=false)
263+
test_rrule(merge, (;a=1.0), (;a=2.0), check_inferred=false)
264+
end
260265
end

0 commit comments

Comments
 (0)