Skip to content

Commit 7f8b54a

Browse files
committed
sum which allows 2nd derivatives
1 parent 7593339 commit 7f8b54a

File tree

2 files changed

+60
-16
lines changed

2 files changed

+60
-16
lines changed

src/rulesets/Base/mapreduce.jl

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,48 @@
11
#####
2-
##### `sum`
2+
##### `sum(x)`
33
#####
44

55
function frule((_, ẋ), ::typeof(sum), x; dims=:)
66
return sum(x; dims=dims), sum(ẋ; dims=dims)
77
end
88

9-
function rrule(::typeof(sum), x::AbstractArray{T}; dims=:) where {T<:Number}
9+
function rrule(::typeof(sum), x::AbstractArray; dims=:)
10+
project = ProjectTo(x)
1011
y = sum(x; dims=dims)
11-
function sum_pullback(ȳ)
12-
# broadcasting the two works out the size no-matter `dims`
13-
= InplaceableThunk(
14-
x -> x .+= ȳ,
15-
@thunk(broadcast(lasttuple, x, ȳ)),
12+
function sum_pullback(dy_raw)
13+
dy = unthunk(dy_raw)
14+
x_thunk = InplaceableThunk(
15+
# Protect `dy` from broadcasting, for when `x` is an array of arrays:
16+
dx -> dx .+= (dims isa Colon ? Ref(dy) : dy),
17+
@thunk project(_unsum(x, dy, dims)) # `_unsum` handles Ref internally
1618
)
17-
return (NoTangent(), )
19+
return (NoTangent(), x_thunk)
1820
end
1921
return y, sum_pullback
2022
end
2123

24+
# This broadcasts `dy` to the shape of `x`, and should preserve e.g. CuArrays, StaticArrays.
25+
# Ideally this would only need `typeof(x)` not `x`, but `similar` only has a suitable method
26+
# when `eltype(x) == eltype(dy)`, which isn't guaranteed.
27+
_unsum(x, dy, dims) = broadcast(lasttuple, x, dy)
28+
_unsum(x, dy, ::Colon) = broadcast(lasttuple, x, Ref(dy))
29+
30+
# Allow for second derivatives of `sum`, by writing rules for `_unsum`:
31+
32+
function frule((_, _, dydot, _), ::typeof(_unsum), x, dy, dims)
33+
return _unsum(x, dy, dims), _unsum(x, dydot, dims)
34+
end
35+
36+
function rrule(::typeof(_unsum), x, dy, dims)
37+
z = _unsum(x, dy, dims)
38+
_unsum_pullback(dz) = (NoTangent(), NoTangent(), sum(unthunk(dz); dims=dims), NoTangent())
39+
return z, _unsum_pullback
40+
end
41+
42+
#####
43+
##### `sum(f, x)`
44+
#####
45+
2246
# Can't map over Adjoint/Transpose Vector
2347
function rrule(
2448
config::RuleConfig{>:HasReverseMode},

test/rulesets/Base/mapreduce.jl

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,32 @@
11
@testset "Maps and Reductions" begin
2-
@testset "sum" begin
3-
sizes = (3, 4, 7)
4-
@testset "dims = $dims" for dims in (:, 1)
5-
@testset "Array{$N, $T}" for N in eachindex(sizes), T in (Float64, ComplexF64)
6-
x = randn(T, sizes[1:N]...)
7-
test_frule(sum, x; fkwargs=(;dims=dims))
8-
test_rrule(sum, x; fkwargs=(;dims=dims))
9-
end
2+
@testset "sum(x; dims=$dims)" for dims in (:, 2, (1,3))
3+
# Forward
4+
test_frule(sum, rand(5); fkwargs=(;dims=dims))
5+
test_frule(sum, rand(ComplexF64, 2,3,4); fkwargs=(;dims=dims))
6+
7+
# Reverse
8+
test_rrule(sum, rand(5); fkwargs=(;dims=dims))
9+
test_rrule(sum, rand(ComplexF64, 2,3,4); fkwargs=(;dims=dims))
10+
11+
# Structured matrices
12+
test_rrule(sum, rand(5)'; fkwargs=(;dims=dims))
13+
y, back = rrule(sum, UpperTriangular(rand(5,5)); dims=dims)
14+
unthunk(back(y*(1+im))[2]) isa UpperTriangular{Float64}
15+
16+
# Function allowing for 2nd derivatives
17+
for x in (rand(5), rand(2,3,4))
18+
dy = maximum(x; dims=dims)
19+
test_frule(ChainRules._unsum, x, dy, dims)
20+
test_rrule(ChainRules._unsum, x, dy, dims)
21+
end
22+
23+
# Arrays of arrays
24+
for x in ([rand(ComplexF64, 3) for _ in 1:4], [rand(3) for _ in 1:2, _ in 1:3, _ in 1:4])
25+
test_rrule(sum, x; fkwargs=(;dims=dims), check_inferred=false)
26+
27+
dy = sum(x; dims=dims)
28+
ddy = rrule(ChainRules._unsum, x, dy, dims)[2](x)[3]
29+
@test size(ddy) == size(dy)
1030
end
1131
end
1232

0 commit comments

Comments
 (0)