Skip to content

Commit 91c5d92

Browse files
authored
Merge pull request JuliaDiff#59 from JuliaDiff/aa/mapreduce
Add a few more reduction rrules
2 parents 3c364f7 + 5fb2923 commit 91c5d92

File tree

8 files changed

+142
-18
lines changed

8 files changed

+142
-18
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ Cassette = "7057c7e9-c182-5462-911a-8362d720325c"
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
88
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
99
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
10+
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1011

1112
[compat]
1213
Cassette = "^0.2"

src/ChainRules.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ module ChainRules
33
using Cassette
44
using LinearAlgebra
55
using LinearAlgebra.BLAS
6+
using Statistics
67
using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broadcastable
78

89
if VERSION < v"1.3.0-DEV.142"

src/rules/base.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,3 +77,7 @@
7777
frule(::typeof(*), x, y) = x * y, Rule((Δx, Δy) -> Δx * y + x * Δy)
7878

7979
rrule(::typeof(*), x, y) = x * y, (Rule(ΔΩ -> ΔΩ * y'), Rule(ΔΩ -> x' * ΔΩ))
80+
81+
frule(::typeof(identity), x) = x, Rule(identity)
82+
83+
rrule(::typeof(identity), x) = x, Rule(identity)

src/rules/mapreduce.jl

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,72 @@ function rrule(::typeof(map), f, xs...)
1515
return y, (DNERule(), ∂xs...)
1616
end
1717

18+
#####
19+
##### `mapreduce`, `mapfoldl`, `mapfoldr`
20+
#####
21+
22+
for mf in (:mapreduce, :mapfoldl, :mapfoldr)
23+
sig = :(rrule(::typeof($mf), f, op, x::AbstractArray{<:Real}))
24+
call = :($mf(f, op, x))
25+
if mf === :mapreduce
26+
insert!(sig.args, 2, Expr(:parameters, Expr(:kw, :dims, :(:))))
27+
insert!(call.args, 2, Expr(:parameters, Expr(:kw, :dims, :dims)))
28+
end
29+
body = quote
30+
y = $call
31+
∂x = Rule() do
32+
broadcast(x, ȳ) do xi, ȳi
33+
_, ∂xi = _checked_rrule(f, xi)
34+
extern(∂xi(ȳi))
35+
end
36+
end
37+
return y, (DNERule(), DNERule(), ∂x)
38+
end
39+
eval(Expr(:function, sig, body))
40+
end
41+
1842
#####
1943
##### `sum`
2044
#####
2145

2246
frule(::typeof(sum), x) = (sum(x), Rule(sum))
2347

2448
rrule(::typeof(sum), x) = (sum(x), Rule(cast))
49+
50+
function rrule(::typeof(sum), f, x::AbstractArray{<:Real}; dims=:)
51+
y, (_, _, ∂x) = rrule(mapreduce, f, Base.add_sum, x; dims=dims)
52+
return y, (DNERule(), ∂x)
53+
end
54+
55+
function rrule(::typeof(sum), x::AbstractArray{<:Real}; dims=:)
56+
y, (_, ∂x) = rrule(sum, identity, x; dims=dims)
57+
return y, ∂x
58+
end
59+
60+
function rrule(::typeof(sum), ::typeof(abs2), x::AbstractArray{<:Real}; dims=:)
61+
y = sum(abs2, x; dims=dims)
62+
∂x = Rule(ȳ -> 2.* x)
63+
return y, (DNERule(), ∂x)
64+
end
65+
66+
#####
67+
##### `mean`
68+
#####
69+
70+
_denom(x, dims::Colon) = length(x)
71+
_denom(x, dims::Integer) = size(x, dims)
72+
_denom(x, dims) = mapreduce(i->size(x, i), Base.mul_prod, unique(dims), init=1)
73+
74+
# TODO: We have `mean(f, x; dims)` as of 1.3.0-DEV.36
75+
76+
function rrule(::typeof(mean), x::AbstractArray{<:Real}; dims=:)
77+
_, dx = rrule(sum, x; dims=dims)
78+
n = _denom(x, dims)
79+
return mean(x; dims=dims), Rule(ȳ -> dx(ȳ) / n)
80+
end
81+
82+
function rrule(::typeof(mean), f, x::AbstractArray{<:Real})
83+
_, (_, dx) = rrule(sum, f, x)
84+
n = _denom(x, :)
85+
return mean(f, x), (DNERule(), Rule(ȳ -> dx(ȳ) / n))
86+
end

test/rules/base.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,5 +97,11 @@ end
9797
@test dy === x / h * cy.value[2]
9898
end
9999
end
100+
@testset "identity" begin
101+
rng = MersenneTwister(1)
102+
n = 4
103+
rrule_test(identity, randn(rng), (randn(rng), randn(rng)))
104+
rrule_test(identity, randn(rng, 4), (randn(rng, 4), randn(rng, 4)))
105+
end
100106
end
101107
# TODO: Non-trig stuff

test/rules/linalg/dense.jl

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,23 +4,6 @@ function generate_well_conditioned_matrix(rng, N)
44
end
55

66
@testset "linalg" begin
7-
@testset "sum" begin
8-
@testset "Vector" begin
9-
rng, M = MersenneTwister(123456), 3
10-
frule_test(sum, (randn(rng, M), randn(rng, M)))
11-
rrule_test(sum, randn(rng), (randn(rng, M), randn(rng, M)))
12-
end
13-
@testset "Matrix" begin
14-
rng, M, N = MersenneTwister(123456), 3, 4
15-
frule_test(sum, (randn(rng, M, N), randn(rng, M, N)))
16-
rrule_test(sum, randn(rng), (randn(rng, M, N), randn(rng, M, N)))
17-
end
18-
@testset "Array{T, 3}" begin
19-
rng, M, N, P = MersenneTwister(123456), 3, 7, 11
20-
frule_test(sum, (randn(rng, M, N, P), randn(rng, M, N, P)))
21-
rrule_test(sum, randn(rng), (randn(rng, M, N, P), randn(rng, M, N, P)))
22-
end
23-
end
247
@testset "dot" begin
258
@testset "Vector" begin
269
rng, M = MersenneTwister(123456), 3

test/rules/mapreduce.jl

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,71 @@
88
rrule_test(map, ȳ, (sin, nothing), (x, vx))
99
rrule_test(map, ȳ, (+, nothing), (x, vx), (randn(rng, n), randn(rng, n)))
1010
end
11+
@testset "mapreduce" begin
12+
rng = MersenneTwister(6)
13+
n = 10
14+
x = randn(rng, n)
15+
vx = randn(rng, n)
16+
= randn(rng)
17+
rrule_test(mapreduce, ȳ, (sin, nothing), (+, nothing), (x, vx))
18+
# With keyword arguments (not yet supported in rrule_test)
19+
X = randn(rng, n, n)
20+
y, (_, _, dx) = rrule(mapreduce, abs2, +, X; dims=2)
21+
= randn(rng, size(y))
22+
x̄_ad = dx(ȳ)
23+
x̄_fd = j′vp(central_fdm(5, 1), x->mapreduce(abs2, +, x; dims=2), ȳ, X)
24+
@test x̄_ad x̄_fd atol=1e-9 rtol=1e-9
25+
end
26+
@testset "$f" for f in (mapfoldl, mapfoldr)
27+
rng = MersenneTwister(10)
28+
n = 7
29+
x = randn(rng, n)
30+
vx = randn(rng, n)
31+
= randn(rng)
32+
rrule_test(f, ȳ, (cos, nothing), (+, nothing), (x, vx))
33+
end
34+
@testset "sum" begin
35+
@testset "Vector" begin
36+
rng, M = MersenneTwister(123456), 3
37+
frule_test(sum, (randn(rng, M), randn(rng, M)))
38+
rrule_test(sum, randn(rng), (randn(rng, M), randn(rng, M)))
39+
end
40+
@testset "Matrix" begin
41+
rng, M, N = MersenneTwister(123456), 3, 4
42+
frule_test(sum, (randn(rng, M, N), randn(rng, M, N)))
43+
rrule_test(sum, randn(rng), (randn(rng, M, N), randn(rng, M, N)))
44+
end
45+
@testset "Array{T, 3}" begin
46+
rng, M, N, P = MersenneTwister(123456), 3, 7, 11
47+
frule_test(sum, (randn(rng, M, N, P), randn(rng, M, N, P)))
48+
rrule_test(sum, randn(rng), (randn(rng, M, N, P), randn(rng, M, N, P)))
49+
end
50+
@testset "function argument" begin
51+
rng = MersenneTwister(1)
52+
n = 8
53+
rrule_test(sum, randn(rng), (cos, nothing), (randn(rng, n), randn(rng, n)))
54+
rrule_test(sum, randn(rng), (abs2, nothing), (randn(rng, n), randn(rng, n)))
55+
end
56+
@testset "keyword arguments" begin
57+
rng = MersenneTwister(33)
58+
n = 4
59+
X = randn(rng, n, n)
60+
y, dX = rrule(sum, X; dims=2)
61+
= randn(rng, size(y))
62+
x̄_ad = dX(ȳ)
63+
x̄_fd = j′vp(central_fdm(5, 1), x->sum(x, dims=2), ȳ, X)
64+
@test x̄_ad x̄_fd atol=1e-9 rtol=1e-9
65+
end
66+
end
67+
@testset "mean" begin
68+
rng = MersenneTwister(999)
69+
n = 9
70+
rrule_test(mean, randn(rng), (abs2, nothing), (randn(rng, n), randn(rng, n)))
71+
X = randn(rng, n, n)
72+
y, dX = rrule(mean, X; dims=1)
73+
= randn(rng, size(y))
74+
X̄_ad = dX(ȳ)
75+
X̄_fd = j′vp(central_fdm(5, 1), x->mean(x, dims=1), ȳ, X)
76+
@test X̄_ad X̄_fd rtol=1e-9 atol=1e-9
77+
end
1178
end

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# TODO: more tests!
22

3-
using ChainRules, Test, FDM, LinearAlgebra, LinearAlgebra.BLAS, Random
3+
using ChainRules, Test, FDM, LinearAlgebra, LinearAlgebra.BLAS, Random, Statistics
44
using ChainRules: extern, accumulate, accumulate!, store!, @scalar_rule,
55
Wirtinger, wirtinger_primal, wirtinger_conjugate, add_wirtinger, mul_wirtinger,
66
Zero, add_zero, mul_zero, One, add_one, mul_one, Casted, cast, add_casted, mul_casted,

0 commit comments

Comments
 (0)