Skip to content

Commit 5a68a61

Browse files
author
Miha Zgubic
committed
@opt_out of sum(array, array)
1 parent a130b8f commit 5a68a61

File tree

3 files changed

+14
-1
lines changed

3 files changed

+14
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "1.11.3"
3+
version = "1.11.4"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/rulesets/Base/mapreduce.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,9 @@ function rrule(
8888
return y, sum_pullback
8989
end
9090

91+
# https://github.com/JuliaDiff/ChainRules.jl/issues/522
92+
@opt_out ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(sum), x::AbstractArray, y::AbstractArray; dims=:)
93+
9194
function frule(
9295
(_, _, Δx),
9396
::typeof(sum),

test/rulesets/Base/mapreduce.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,16 @@
9090
@test_skip test_rrule(sum, iszero, randn(5)) # DimensionMismatch("second dimension of A, 1, does not match length of x, 0")
9191
end
9292

93+
# https://github.com/JuliaDiff/ChainRules.jl/issues/522
94+
@begin "sum(xs, weights) (#522)" begin
95+
xs = rand(5)
96+
weights = rand(5)
97+
Base.sum(xs::AbstractArray, weights::AbstractArray) = dot(xs, weights)
98+
struct MyRuleConfig <: RuleConfig{Union{HasReverseMode}} end
99+
100+
@test rrule(MyRuleConfig(), Base.sum, xs, weights) isa Nothing
101+
end
102+
93103
@testset "prod" begin
94104
@testset "Array{$T}" for T in [Float64, ComplexF64]
95105
@testset "size = $sz, dims = $dims" for (sz, dims) in [

0 commit comments

Comments
 (0)