Skip to content

Commit 25ca9ee

Browse files
author
Miha Zgubic
committed
fix tests
1 parent 5a68a61 commit 25ca9ee

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

src/rulesets/Base/mapreduce.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,15 @@ function rrule(
8989
end
9090

9191
# https://github.com/JuliaDiff/ChainRules.jl/issues/522
92-
@opt_out ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(sum), x::AbstractArray, y::AbstractArray; dims=:)
92+
# The rule above assumes `f` is callable. Arrays are not, this came up when summing
93+
# arrays with weights in StatsBase
94+
@opt_out ChainRulesCore.rrule(
95+
config::RuleConfig{>:HasReverseMode},
96+
::typeof(sum),
97+
x::AbstractArray,
98+
y::AbstractArray;
99+
dims=:
100+
)
93101

94102
function frule(
95103
(_, _, Δx),

test/rulesets/Base/mapreduce.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@
9191
end
9292

9393
# https://github.com/JuliaDiff/ChainRules.jl/issues/522
94-
@begin "sum(xs, weights) (#522)" begin
94+
@testset "sum(xs, weights) (#522)" begin
9595
xs = rand(5)
9696
weights = rand(5)
9797
Base.sum(xs::AbstractArray, weights::AbstractArray) = dot(xs, weights)

0 commit comments

Comments
 (0)