Skip to content

Commit 0ff02a1

Browse files
committed
Don't allow sum of anything in rrule
1 parent 9e4cb76 commit 0ff02a1

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

src/rulesets/Base/mapreduce.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ function frule(::typeof(sum), x)
5757
return sum(x), sum_pushforward
5858
end
5959

60-
function rrule(::typeof(sum), x)
60+
function rrule(::typeof(sum), x::AbstractArray{<:Real})
6161
function sum_pullback(ȳ)
6262
return (NO_FIELDS, @thunk(fill(ȳ, size(x))))
6363
end
@@ -67,15 +67,15 @@ end
6767
function rrule(::typeof(sum), f, x::AbstractArray{<:Real}; dims=:)
6868
y, mr_pullback = rrule(mapreduce, f, Base.add_sum, x; dims=dims)
6969
function sum_pullback(ȳ)
70-
NO_FIELDS, DNE(), last(mr_pullback(ȳ))
70+
return NO_FIELDS, DNE(), last(mr_pullback(ȳ))
7171
end
7272
return y, sum_pullback
7373
end
7474

7575
function rrule(::typeof(sum), x::AbstractArray{<:Real}; dims=:)
7676
y, inner_pullback = rrule(sum, identity, x; dims=dims)
7777
function sum_pullback(ȳ)
78-
NO_FIELDS, last(inner_pullback(ȳ))
78+
return NO_FIELDS, last(inner_pullback(ȳ))
7979
end
8080
return y, sum_pullback
8181
end

0 commit comments

Comments
 (0)