Skip to content

Commit 6bf0f30

Browse files
authored
Merge pull request #524 from JuliaDiff/mz/optout
`@opt_out of sum(array, array)`
2 parents c637457 + ab07ced commit 6bf0f30

File tree

3 files changed

+24
-1
lines changed

3 files changed

+24
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "1.11.4"
3+
version = "1.11.5"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/rulesets/Base/mapreduce.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,17 @@ function rrule(
8888
return y, sum_pullback
8989
end
9090

91+
# https://github.com/JuliaDiff/ChainRules.jl/issues/522
92+
# The rule above assumes `f` is callable. Arrays are not, this came up when summing
93+
# arrays with weights in StatsBase
94+
@opt_out ChainRulesCore.rrule(
95+
config::RuleConfig{>:HasReverseMode},
96+
::typeof(sum),
97+
x::AbstractArray,
98+
y::AbstractArray;
99+
dims=:
100+
)
101+
91102
function frule(
92103
(_, _, Δx),
93104
::typeof(sum),

test/rulesets/Base/mapreduce.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
# for sum(xs, weights) (#522)
2+
Base.sum(xs::AbstractArray, weights::AbstractArray) = dot(xs, weights)
3+
struct SumRuleConfig <: RuleConfig{Union{HasReverseMode}} end
4+
15
@testset "Maps and Reductions" begin
26
@testset "sum(x; dims=$dims)" for dims in (:, 2, (1,3))
37
# Forward
@@ -90,6 +94,14 @@
9094
@test_skip test_rrule(sum, iszero, randn(5)) # DimensionMismatch("second dimension of A, 1, does not match length of x, 0")
9195
end
9296

97+
# https://github.com/JuliaDiff/ChainRules.jl/issues/522
98+
@testset "sum(xs, weights) (#522)" begin
99+
xs = rand(5)
100+
weights = rand(5)
101+
102+
@test rrule(SumRuleConfig(), Base.sum, xs, weights) isa Nothing
103+
end
104+
93105
@testset "prod" begin
94106
@testset "Array{$T}" for T in [Float64, ComplexF64]
95107
@testset "size = $sz, dims = $dims" for (sz, dims) in [

0 commit comments

Comments
 (0)