Skip to content

Commit 0d55c54

Browse files
authored
Merge pull request #494 from mcabbott/sum_second
Rule for `sum` allowing 2nd derivatives
2 parents 7593339 + 66612b3 commit 0d55c54

File tree

4 files changed

+83
-18
lines changed

4 files changed

+83
-18
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "1.5.0"
3+
version = "1.5.1"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/rulesets/Base/mapreduce.jl

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,48 @@
11
#####
2-
##### `sum`
2+
##### `sum(x)`
33
#####
44

55
function frule((_, ẋ), ::typeof(sum), x; dims=:)
66
return sum(x; dims=dims), sum(ẋ; dims=dims)
77
end
88

9-
function rrule(::typeof(sum), x::AbstractArray{T}; dims=:) where {T<:Number}
9+
function rrule(::typeof(sum), x::AbstractArray; dims=:)
10+
project = ProjectTo(x)
1011
y = sum(x; dims=dims)
11-
function sum_pullback(ȳ)
12-
# broadcasting the two works out the size no-matter `dims`
13-
= InplaceableThunk(
14-
x -> x .+= ȳ,
15-
@thunk(broadcast(lasttuple, x, ȳ)),
12+
function sum_pullback(dy_raw)
13+
dy = unthunk(dy_raw)
14+
x_thunk = InplaceableThunk(
15+
# Protect `dy` from broadcasting, for when `x` is an array of arrays:
16+
dx -> dx .+= (dims isa Colon ? Ref(dy) : dy),
17+
@thunk project(_unsum(x, dy, dims)) # `_unsum` handles Ref internally
1618
)
17-
return (NoTangent(), )
19+
return (NoTangent(), x_thunk)
1820
end
1921
return y, sum_pullback
2022
end
2123

24+
# This broadcasts `dy` to the shape of `x`, and should preserve e.g. CuArrays, StaticArrays.
25+
# Ideally this would only need `typeof(x)` not `x`, but `similar` only has a suitable method
26+
# when `eltype(x) == eltype(dy)`, which isn't guaranteed.
27+
_unsum(x, dy, dims) = broadcast(lasttuple, x, dy)
28+
_unsum(x, dy, ::Colon) = broadcast(lasttuple, x, Ref(dy))
29+
30+
# Allow for second derivatives of `sum`, by writing rules for `_unsum`:
31+
32+
function frule((_, _, dydot, _), ::typeof(_unsum), x, dy, dims)
33+
return _unsum(x, dy, dims), _unsum(x, dydot, dims)
34+
end
35+
36+
function rrule(::typeof(_unsum), x, dy, dims)
37+
z = _unsum(x, dy, dims)
38+
_unsum_pullback(dz) = (NoTangent(), NoTangent(), sum(unthunk(dz); dims=dims), NoTangent())
39+
return z, _unsum_pullback
40+
end
41+
42+
#####
43+
##### `sum(f, x)`
44+
#####
45+
2246
# Can't map over Adjoint/Transpose Vector
2347
function rrule(
2448
config::RuleConfig{>:HasReverseMode},

src/rulesets/Base/nondiff.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,9 @@
7676
@non_differentiable similar(::AbstractArray{Bool}, ::Any...)
7777
@non_differentiable stride(::AbstractArray{Bool}, ::Any)
7878
@non_differentiable strides(::AbstractArray{Bool})
79+
@non_differentiable sum(::AbstractArray{Bool})
80+
@non_differentiable sum(::Any, ::AbstractArray{Bool})
81+
@non_differentiable sum(::typeof(abs2), ::AbstractArray{Bool}) # avoids an ambiguity
7982
@non_differentiable vcat(::AbstractArray{Bool}...)
8083
@non_differentiable vec(::AbstractArray{Bool})
8184
@non_differentiable Vector(::AbstractArray{Bool})

test/rulesets/Base/mapreduce.jl

Lines changed: 47 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,36 @@
11
@testset "Maps and Reductions" begin
2-
@testset "sum" begin
3-
sizes = (3, 4, 7)
4-
@testset "dims = $dims" for dims in (:, 1)
5-
@testset "Array{$N, $T}" for N in eachindex(sizes), T in (Float64, ComplexF64)
6-
x = randn(T, sizes[1:N]...)
7-
test_frule(sum, x; fkwargs=(;dims=dims))
8-
test_rrule(sum, x; fkwargs=(;dims=dims))
9-
end
2+
@testset "sum(x; dims=$dims)" for dims in (:, 2, (1,3))
3+
# Forward
4+
test_frule(sum, rand(5); fkwargs=(;dims=dims))
5+
test_frule(sum, rand(ComplexF64, 2,3,4); fkwargs=(;dims=dims))
6+
7+
# Reverse
8+
test_rrule(sum, rand(5); fkwargs=(;dims=dims))
9+
test_rrule(sum, rand(ComplexF64, 2,3,4); fkwargs=(;dims=dims))
10+
11+
# Structured matrices
12+
test_rrule(sum, rand(5)'; fkwargs=(;dims=dims))
13+
y, back = rrule(sum, UpperTriangular(rand(5,5)); dims=dims)
14+
unthunk(back(y*(1+im))[2]) isa UpperTriangular{Float64}
15+
@test_skip test_rrule(sum, UpperTriangular(rand(5,5)) randn(5,5); fkwargs=(;dims=dims), check_inferred=false) # Problem: in add!! Evaluated: isapprox
16+
17+
# Boolean -- via @non_differentiable
18+
test_rrule(sum, randn(5) .> 0; fkwargs=(;dims=dims))
19+
20+
# Function allowing for 2nd derivatives
21+
for x in (rand(5), rand(2,3,4))
22+
dy = maximum(x; dims=dims)
23+
test_frule(ChainRules._unsum, x, dy, dims)
24+
test_rrule(ChainRules._unsum, x, dy, dims)
25+
end
26+
27+
# Arrays of arrays
28+
for x in ([rand(ComplexF64, 3) for _ in 1:4], [rand(3) for _ in 1:2, _ in 1:3, _ in 1:4])
29+
test_rrule(sum, x; fkwargs=(;dims=dims), check_inferred=false)
30+
31+
dy = sum(x; dims=dims)
32+
ddy = rrule(ChainRules._unsum, x, dy, dims)[2](x)[3]
33+
@test size(ddy) == size(dy)
1034
end
1135
end
1236

@@ -18,20 +42,29 @@
1842
test_frule(sum, abs2, x; fkwargs=(;dims=dims))
1943
test_rrule(sum, abs2, x; fkwargs=(;dims=dims))
2044
end
45+
46+
# Boolean -- via @non_differentiable, test that this isn't ambiguous
47+
test_rrule(sum, abs2, randn(5) .> 0; fkwargs=(;dims=dims))
2148
end
2249
end # sum abs2
2350

2451
@testset "sum(f, xs)" begin
2552
# This calls back into AD
2653
test_rrule(sum, abs, [-4.0, 2.0, 2.0])
54+
test_rrule(sum, cbrt, randn(5))
2755
test_rrule(sum, Multiplier(2.0), [2.0, 4.0, 8.0])
2856

57+
# Complex numbers
58+
test_rrule(sum, sqrt, rand(ComplexF64, 5))
59+
test_rrule(sum, abs, rand(ComplexF64, 3, 4)) # complex -> real
60+
2961
# inference fails for array of arrays
3062
test_rrule(sum, sum, [[2.0, 4.0], [4.0,1.9]]; check_inferred=false)
3163

3264
# dims kwarg
3365
test_rrule(sum, abs, [-2.0 4.0; 5.0 1.9]; fkwargs=(;dims=1))
3466
test_rrule(sum, abs, [-2.0 4.0; 5.0 1.9]; fkwargs=(;dims=2))
67+
test_rrule(sum, sqrt, rand(ComplexF64, 3, 4); fkwargs=(;dims=(1,)))
3568

3669
test_rrule(sum, abs, @SVector[1.0, -3.0])
3770

@@ -49,6 +82,12 @@
4982
# make sure we preserve type for Diagonal
5083
_, pb = rrule(ADviaRuleConfig(), sum, abs, Diagonal([1.0, -3.0]))
5184
@test pb(1.0)[3] isa Diagonal
85+
86+
# Boolean -- via @non_differentiable, test that this isn't ambiguous
87+
test_rrule(sum, sqrt, randn(5) .> 0)
88+
test_rrule(sum, sqrt, randn(5,5) .> 0; fkwargs=(;dims=1))
89+
# ... and Bool produced by function
90+
@test_skip test_rrule(sum, iszero, randn(5)) # DimensionMismatch("second dimension of A, 1, does not match length of x, 0")
5291
end
5392

5493
@testset "prod" begin
@@ -100,7 +139,6 @@
100139
@test unthunk(rrule(prod, UpperTriangular(ones(T,2,2)))[2](1.0)[2]) == UpperTriangular([0.0 0; 1 0])
101140

102141
# Symmetric -- at least this doesn't have zeros, still an unlikely combination
103-
104142
xs = Symmetric(rand(T,4,4))
105143
@test unthunk(rrule(prod, Symmetric(T[1 2; -333 4]))[2](1.0)[2]) == [16 8; 8 4]
106144
# TODO debug why these fail https://github.com/JuliaDiff/ChainRules.jl/issues/475

0 commit comments

Comments
 (0)