Skip to content

Commit 56e6a31

Browse files
committed
add some sum(::Array{Bool}) rules
1 parent 7f8b54a commit 56e6a31

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

src/rulesets/Base/nondiff.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,9 @@
7676
@non_differentiable similar(::AbstractArray{Bool}, ::Any...)
7777
@non_differentiable stride(::AbstractArray{Bool}, ::Any)
7878
@non_differentiable strides(::AbstractArray{Bool})
79+
@non_differentiable sum(::AbstractArray{Bool})
80+
@non_differentiable sum(::Any, ::AbstractArray{Bool})
81+
@non_differentiable sum(::typeof(abs2), ::AbstractArray{Bool}) # avoids an ambiguity
7982
@non_differentiable vcat(::AbstractArray{Bool}...)
8083
@non_differentiable vec(::AbstractArray{Bool})
8184
@non_differentiable Vector(::AbstractArray{Bool})

test/rulesets/Base/mapreduce.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,11 @@
1212
test_rrule(sum, rand(5)'; fkwargs=(;dims=dims))
1313
y, back = rrule(sum, UpperTriangular(rand(5,5)); dims=dims)
1414
unthunk(back(y*(1+im))[2]) isa UpperTriangular{Float64}
15+
@test_skip test_rrule(sum, UpperTriangular(rand(5,5)) randn(5,5); fkwargs=(;dims=dims), check_inferred=false) # Problem: in add!! Evaluated: isapprox
1516

17+
# Boolean -- via @non_differentiable
18+
test_rrule(sum, randn(5) .> 0; fkwargs=(;dims=dims))
19+
1620
# Function allowing for 2nd derivatives
1721
for x in (rand(5), rand(2,3,4))
1822
dy = maximum(x; dims=dims)
@@ -39,6 +43,9 @@
3943
test_rrule(sum, abs2, x; fkwargs=(;dims=dims))
4044
end
4145
end
46+
47+
# Boolean -- via @non_differentiable, test that this isn't ambiguous
48+
test_rrule(sum, abs2, randn(5) .> 0; fkwargs=(;dims=dims))
4249
end # sum abs2
4350

4451
@testset "sum(f, xs)" begin
@@ -69,6 +76,9 @@
6976
# make sure we preserve type for Diagonal
7077
_, pb = rrule(ADviaRuleConfig(), sum, abs, Diagonal([1.0, -3.0]))
7178
@test pb(1.0)[3] isa Diagonal
79+
80+
# Boolean -- via @non_differentiable, test that this isn't ambiguous
81+
@test_skip test_rrule(sum, sqrt, randn(5) .> 0; fkwargs=(;dims=dims)) # MethodError: no method matching real(::NoTangent)
7282
end
7383

7484
@testset "prod" begin
@@ -120,7 +130,6 @@
120130
@test unthunk(rrule(prod, UpperTriangular(ones(T,2,2)))[2](1.0)[2]) == UpperTriangular([0.0 0; 1 0])
121131

122132
# Symmetric -- at least this doesn't have zeros, still an unlikely combination
123-
124133
xs = Symmetric(rand(T,4,4))
125134
@test unthunk(rrule(prod, Symmetric(T[1 2; -333 4]))[2](1.0)[2]) == [16 8; 8 4]
126135
# TODO debug why these fail https://github.com/JuliaDiff/ChainRules.jl/issues/475

0 commit comments

Comments
 (0)