|
1 | 1 | #####
|
2 |
| -##### `sum` |
| 2 | +##### `sum(x)` |
3 | 3 | #####
|
4 | 4 |
|
5 | 5 | function frule((_, ẋ), ::typeof(sum), x; dims=:)
|
6 | 6 | return sum(x; dims=dims), sum(ẋ; dims=dims)
|
7 | 7 | end
|
8 | 8 |
|
9 |
| -function rrule(::typeof(sum), x::AbstractArray{T}; dims=:) where {T<:Number} |
| 9 | +function rrule(::typeof(sum), x::AbstractArray; dims=:) |
| 10 | + project = ProjectTo(x) |
10 | 11 | y = sum(x; dims=dims)
|
11 |
| - function sum_pullback(ȳ) |
12 |
| - # broadcasting the two works out the size no-matter `dims` |
13 |
| - x̄ = InplaceableThunk( |
14 |
| - x -> x .+= ȳ, |
15 |
| - @thunk(broadcast(last∘tuple, x, ȳ)), |
| 12 | + function sum_pullback(dy_raw) |
| 13 | + dy = unthunk(dy_raw) |
| 14 | + x_thunk = InplaceableThunk( |
| 15 | + # Protect `dy` from broadcasting, for when `x` is an array of arrays: |
| 16 | + dx -> dx .+= (dims isa Colon ? Ref(dy) : dy), |
| 17 | + @thunk project(_unsum(x, dy, dims)) # `_unsum` handles Ref internally |
16 | 18 | )
|
17 |
| - return (NoTangent(), x̄) |
| 19 | + return (NoTangent(), x_thunk) |
18 | 20 | end
|
19 | 21 | return y, sum_pullback
|
20 | 22 | end
|
21 | 23 |
|
| 24 | +# This broadcasts `dy` to the shape of `x`, and should preserve e.g. CuArrays, StaticArrays. |
| 25 | +# Ideally this would only need `typeof(x)` not `x`, but `similar` only has a suitable method |
| 26 | +# when `eltype(x) == eltype(dy)`, which isn't guaranteed. |
| 27 | +_unsum(x, dy, dims) = broadcast(last∘tuple, x, dy) |
| 28 | +_unsum(x, dy, ::Colon) = broadcast(last∘tuple, x, Ref(dy)) |
| 29 | + |
| 30 | +# Allow for second derivatives of `sum`, by writing rules for `_unsum`: |
| 31 | + |
| 32 | +function frule((_, _, dydot, _), ::typeof(_unsum), x, dy, dims) |
| 33 | + return _unsum(x, dy, dims), _unsum(x, dydot, dims) |
| 34 | +end |
| 35 | + |
| 36 | +function rrule(::typeof(_unsum), x, dy, dims) |
| 37 | + z = _unsum(x, dy, dims) |
| 38 | + _unsum_pullback(dz) = (NoTangent(), NoTangent(), sum(unthunk(dz); dims=dims), NoTangent()) |
| 39 | + return z, _unsum_pullback |
| 40 | +end |
| 41 | + |
| 42 | +##### |
| 43 | +##### `sum(f, x)` |
| 44 | +##### |
| 45 | + |
22 | 46 | # Can't map over Adjoint/Transpose Vector
|
23 | 47 | function rrule(
|
24 | 48 | config::RuleConfig{>:HasReverseMode},
|
|
0 commit comments