Skip to content

Commit 5e5faae

Browse files
committed
Add rrules for mapreduce, mapfoldr, mapfoldl
1 parent 3c364f7 commit 5e5faae

File tree

2 files changed

+47
-0
lines changed

2 files changed

+47
-0
lines changed

src/rules/mapreduce.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,30 @@ function rrule(::typeof(map), f, xs...)
1515
return y, (DNERule(), ∂xs...)
1616
end
1717

18+
#####
19+
##### `mapreduce`, `mapfoldl`, `mapfoldr`
20+
#####
21+
22+
for mf in (:mapreduce, :mapfoldl, :mapfoldr)
23+
sig = :(rrule(::typeof($mf), f, op, x::AbstractArray{<:Real}))
24+
call = :($mf(f, op, x))
25+
if mf === :mapreduce
26+
insert!(sig.args, 2, Expr(:parameters, Expr(:kw, :dims, :(:))))
27+
insert!(call.args, 2, Expr(:parameters, Expr(:kw, :dims, :dims)))
28+
end
29+
body = quote
30+
y = $call
31+
∂x = Rule() do
32+
broadcast(x, ȳ) do xi, ȳi
33+
_, ∂xi = _checked_rrule(f, xi)
34+
extern(∂xi(ȳi))
35+
end
36+
end
37+
return y, (DNERule(), DNERule(), ∂x)
38+
end
39+
eval(Expr(:function, sig, body))
40+
end
41+
1842
#####
1943
##### `sum`
2044
#####

test/rules/mapreduce.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,27 @@
88
rrule_test(map, ȳ, (sin, nothing), (x, vx))
99
rrule_test(map, ȳ, (+, nothing), (x, vx), (randn(rng, n), randn(rng, n)))
1010
end
11+
@testset "mapreduce" begin
12+
rng = MersenneTwister(6)
13+
n = 10
14+
x = randn(rng, n)
15+
vx = randn(rng, n)
16+
= randn(rng)
17+
rrule_test(mapreduce, ȳ, (sin, nothing), (+, nothing), (x, vx))
18+
# With keyword arguments (not yet supported in rrule_test)
19+
X = randn(rng, n, n)
20+
y, (_, _, dx) = rrule(mapreduce, abs2, +, X; dims=2)
21+
= randn(rng, size(y))
22+
x̄_ad = dx(ȳ)
23+
x̄_fd = j′vp(central_fdm(5, 1), x->mapreduce(abs2, +, x; dims=2), ȳ, X)
24+
@test x̄_ad x̄_fd atol=1e-9 rtol=1e-9
25+
end
26+
@testset "$f" for f in (mapfoldl, mapfoldr)
27+
rng = MersenneTwister(10)
28+
n = 7
29+
x = randn(rng, n)
30+
vx = randn(rng, n)
31+
= randn(rng)
32+
rrule_test(f, ȳ, (cos, nothing), (+, nothing), (x, vx))
33+
end
1134
end

0 commit comments

Comments
 (0)