Skip to content

Commit fc555b1

Browse files
committed
Specialize rrules for sum, add rrule for identity
This includes some of the specialized methods for `sum` from Nabla.
1 parent 5e5faae commit fc555b1

File tree

5 files changed

+59
-17
lines changed

5 files changed

+59
-17
lines changed

src/rules/base.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,3 +77,7 @@
7777
frule(::typeof(*), x, y) = x * y, Rule((Δx, Δy) -> Δx * y + x * Δy)
7878

7979
rrule(::typeof(*), x, y) = x * y, (Rule(ΔΩ -> ΔΩ * y'), Rule(ΔΩ -> x' * ΔΩ))
80+
81+
frule(::typeof(identity), x) = x, Rule(identity)
82+
83+
rrule(::typeof(identity), x) = x, Rule(identity)

src/rules/mapreduce.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,19 @@ end
4646
frule(::typeof(sum), x) = (sum(x), Rule(sum))
4747

4848
rrule(::typeof(sum), x) = (sum(x), Rule(cast))
49+
50+
function rrule(::typeof(sum), f, x::AbstractArray{<:Real}; dims=:)
51+
y, (_, _, ∂x) = rrule(mapreduce, f, Base.add_sum, x; dims=dims)
52+
return y, (DNERule(), ∂x)
53+
end
54+
55+
function rrule(::typeof(sum), x::AbstractArray{<:Real}; dims=:)
56+
y, (_, ∂x) = rrule(sum, identity, x; dims=dims)
57+
return y, ∂x
58+
end
59+
60+
function rrule(::typeof(sum), ::typeof(abs2), x::AbstractArray{<:Real}; dims=:)
61+
y = sum(abs2, x; dims=dims)
62+
∂x = Rule(ȳ -> 2.* x)
63+
return y, (DNERule(), ∂x)
64+
end

test/rules/base.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,5 +97,11 @@ end
9797
@test dy === x / h * cy.value[2]
9898
end
9999
end
100+
@testset "identity" begin
101+
rng = MersenneTwister(1)
102+
n = 4
103+
rrule_test(identity, randn(rng), (randn(rng), randn(rng)))
104+
rrule_test(identity, randn(rng, 4), (randn(rng, 4), randn(rng, 4)))
105+
end
100106
end
101107
# TODO: Non-trig stuff

test/rules/linalg/dense.jl

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,23 +4,6 @@ function generate_well_conditioned_matrix(rng, N)
44
end
55

66
@testset "linalg" begin
7-
@testset "sum" begin
8-
@testset "Vector" begin
9-
rng, M = MersenneTwister(123456), 3
10-
frule_test(sum, (randn(rng, M), randn(rng, M)))
11-
rrule_test(sum, randn(rng), (randn(rng, M), randn(rng, M)))
12-
end
13-
@testset "Matrix" begin
14-
rng, M, N = MersenneTwister(123456), 3, 4
15-
frule_test(sum, (randn(rng, M, N), randn(rng, M, N)))
16-
rrule_test(sum, randn(rng), (randn(rng, M, N), randn(rng, M, N)))
17-
end
18-
@testset "Array{T, 3}" begin
19-
rng, M, N, P = MersenneTwister(123456), 3, 7, 11
20-
frule_test(sum, (randn(rng, M, N, P), randn(rng, M, N, P)))
21-
rrule_test(sum, randn(rng), (randn(rng, M, N, P), randn(rng, M, N, P)))
22-
end
23-
end
247
@testset "dot" begin
258
@testset "Vector" begin
269
rng, M = MersenneTwister(123456), 3

test/rules/mapreduce.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,37 @@
3131
= randn(rng)
3232
rrule_test(f, ȳ, (cos, nothing), (+, nothing), (x, vx))
3333
end
34+
@testset "sum" begin
35+
@testset "Vector" begin
36+
rng, M = MersenneTwister(123456), 3
37+
frule_test(sum, (randn(rng, M), randn(rng, M)))
38+
rrule_test(sum, randn(rng), (randn(rng, M), randn(rng, M)))
39+
end
40+
@testset "Matrix" begin
41+
rng, M, N = MersenneTwister(123456), 3, 4
42+
frule_test(sum, (randn(rng, M, N), randn(rng, M, N)))
43+
rrule_test(sum, randn(rng), (randn(rng, M, N), randn(rng, M, N)))
44+
end
45+
@testset "Array{T, 3}" begin
46+
rng, M, N, P = MersenneTwister(123456), 3, 7, 11
47+
frule_test(sum, (randn(rng, M, N, P), randn(rng, M, N, P)))
48+
rrule_test(sum, randn(rng), (randn(rng, M, N, P), randn(rng, M, N, P)))
49+
end
50+
@testset "function argument" begin
51+
rng = MersenneTwister(1)
52+
n = 8
53+
rrule_test(sum, randn(rng), (cos, nothing), (randn(rng, n), randn(rng, n)))
54+
rrule_test(sum, randn(rng), (abs2, nothing), (randn(rng, n), randn(rng, n)))
55+
end
56+
@testset "keyword arguments" begin
57+
rng = MersenneTwister(33)
58+
n = 4
59+
X = randn(rng, n, n)
60+
y, dX = rrule(sum, X; dims=2)
61+
= randn(rng, size(y))
62+
x̄_ad = dX(ȳ)
63+
x̄_fd = j′vp(central_fdm(5, 1), x->sum(x, dims=2), ȳ, X)
64+
@test x̄_ad x̄_fd atol=1e-9 rtol=1e-9
65+
end
66+
end
3467
end

0 commit comments

Comments
 (0)