Skip to content

Commit f17dcdc

Browse files
authored
Merge pull request #124 from JuliaDiff/ox/notuple
Don't allow sum of anything in rrule
2 parents e7ad852 + 8fbe739 commit f17dcdc

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "0.2.2-DEV"
3+
version = "0.2.3"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/rulesets/Base/mapreduce.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ function frule(::typeof(sum), x)
5757
return sum(x), sum_pushforward
5858
end
5959

60-
function rrule(::typeof(sum), x)
60+
function rrule(::typeof(sum), x::AbstractArray{<:Real})
6161
function sum_pullback(ȳ)
6262
return (NO_FIELDS, @thunk(fill(ȳ, size(x))))
6363
end
@@ -67,15 +67,15 @@ end
6767
function rrule(::typeof(sum), f, x::AbstractArray{<:Real}; dims=:)
6868
y, mr_pullback = rrule(mapreduce, f, Base.add_sum, x; dims=dims)
6969
function sum_pullback(ȳ)
70-
NO_FIELDS, DNE(), last(mr_pullback(ȳ))
70+
return NO_FIELDS, DNE(), last(mr_pullback(ȳ))
7171
end
7272
return y, sum_pullback
7373
end
7474

7575
function rrule(::typeof(sum), x::AbstractArray{<:Real}; dims=:)
7676
y, inner_pullback = rrule(sum, identity, x; dims=dims)
7777
function sum_pullback(ȳ)
78-
NO_FIELDS, last(inner_pullback(ȳ))
78+
return NO_FIELDS, last(inner_pullback(ȳ))
7979
end
8080
return y, sum_pullback
8181
end

0 commit comments

Comments
 (0)