Skip to content

Commit 987ee45

Browse files
mcabbottMichael Abbottoxinaboxdevmotion
authored
Improve rules for sum (#336)
* tweaks to sum rules * Apply suggestions from code review Co-authored-by: Lyndon White <oxinabox@ucc.asn.au> * comma * v0.7.62 * gone cold one more time... Co-authored-by: David Widmann <devmotion@users.noreply.github.com> Co-authored-by: Michael Abbott <me@escbook> Co-authored-by: Lyndon White <oxinabox@ucc.asn.au> Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
1 parent daff86a commit 987ee45

File tree

2 files changed

+13
-6
lines changed

2 files changed

+13
-6
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "0.7.62"
3+
version = "0.7.63"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/rulesets/Base/mapreduce.jl

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,10 @@ function rrule(::typeof(sum), x::AbstractArray{T}; dims=:) where {T<:Number}
1010
y = sum(x; dims=dims)
1111
function sum_pullback(ȳ)
1212
# broadcasting the two works out the size no-matter `dims`
13-
= broadcast(x, ȳ) do xi, ȳi
14-
ȳi
15-
end
13+
= InplaceableThunk(
14+
@thunk(broadcast(lasttuple, x, ȳ)),
15+
x -> x .+=
16+
)
1617
return (NO_FIELDS, x̄)
1718
end
1819
return y, sum_pullback
@@ -29,7 +30,9 @@ function frule(
2930
∂y = if dims isa Colon
3031
2 * real(dot(x, ẋ))
3132
elseif VERSION v"1.2" # multi-iterator mapreduce introduced in v1.2
32-
2 * mapreduce(_realconjtimes, +, x, ẋ; dims=dims)
33+
mapreduce(+, x, ẋ; dims=dims) do xi, dxi
34+
2 * _realconjtimes(xi, dxi)
35+
end
3336
else
3437
2 * sum(_realconjtimes.(x, ẋ); dims=dims)
3538
end
@@ -44,7 +47,11 @@ function rrule(
4447
) where {T<:Union{Real,Complex}}
4548
y = sum(abs2, x; dims=dims)
4649
function sum_abs2_pullback(ȳ)
47-
return (NO_FIELDS, DoesNotExist(), 2 .* real.(ȳ) .* x)
50+
x_thunk = InplaceableThunk(
51+
@thunk(2 .* real.(ȳ) .* x),
52+
dx -> dx .+= 2 .* real.(ȳ) .* x
53+
)
54+
return (NO_FIELDS, DoesNotExist(), x_thunk)
4855
end
4956
return y, sum_abs2_pullback
5057
end

0 commit comments

Comments
 (0)