Skip to content

Commit 2ff02b6

Browse files
committed
fix tests, add a few
1 parent 56e6a31 commit 2ff02b6

File tree

1 file changed

+13
-4
lines changed

1 file changed

+13
-4
lines changed

test/rulesets/Base/mapreduce.jl

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,23 +42,29 @@
4242
test_frule(sum, abs2, x; fkwargs=(;dims=dims))
4343
test_rrule(sum, abs2, x; fkwargs=(;dims=dims))
4444
end
45-
end
4645

47-
# Boolean -- via @non_differentiable, test that this isn't ambiguous
48-
test_rrule(sum, abs2, randn(5) .> 0; fkwargs=(;dims=dims))
46+
# Boolean -- via @non_differentiable, test that this isn't ambiguous
47+
test_rrule(sum, abs2, randn(5) .> 0; fkwargs=(;dims=dims))
48+
end
4949
end # sum abs2
5050

5151
@testset "sum(f, xs)" begin
5252
# This calls back into AD
5353
test_rrule(sum, abs, [-4.0, 2.0, 2.0])
54+
test_rrule(sum, cbrt, randn(5))
5455
test_rrule(sum, Multiplier(2.0), [2.0, 4.0, 8.0])
5556

57+
# Complex numbers
58+
test_rrule(sum, sqrt, rand(ComplexF64, 5))
59+
test_rrule(sum, abs, rand(ComplexF64, 3, 4)) # complex -> real
60+
5661
# inference fails for array of arrays
5762
test_rrule(sum, sum, [[2.0, 4.0], [4.0,1.9]]; check_inferred=false)
5863

5964
# dims kwarg
6065
test_rrule(sum, abs, [-2.0 4.0; 5.0 1.9]; fkwargs=(;dims=1))
6166
test_rrule(sum, abs, [-2.0 4.0; 5.0 1.9]; fkwargs=(;dims=2))
67+
test_rrule(sum, sqrt, rand(ComplexF64, 3, 4); fkwargs=(;dims=(1,)))
6268

6369
test_rrule(sum, abs, @SVector[1.0, -3.0])
6470

@@ -78,7 +84,10 @@
7884
@test pb(1.0)[3] isa Diagonal
7985

8086
# Boolean -- via @non_differentiable, test that this isn't ambiguous
81-
@test_skip test_rrule(sum, sqrt, randn(5) .> 0; fkwargs=(;dims=dims)) # MethodError: no method matching real(::NoTangent)
87+
test_rrule(sum, sqrt, randn(5) .> 0)
88+
test_rrule(sum, sqrt, randn(5,5) .> 0; fkwargs=(;dims=1))
89+
# ... and Bool produced by function
90+
@test_skip test_rrule(sum, iszero, randn(5)) # DimensionMismatch("second dimension of A, 1, does not match length of x, 0")
8291
end
8392

8493
@testset "prod" begin

0 commit comments

Comments
 (0)