|
1 | 1 | @testset "Maps and Reductions" begin
|
2 |
| - @testset "sum" begin |
3 |
| - sizes = (3, 4, 7) |
4 |
| - @testset "dims = $dims" for dims in (:, 1) |
5 |
| - @testset "Array{$N, $T}" for N in eachindex(sizes), T in (Float64, ComplexF64) |
6 |
| - x = randn(T, sizes[1:N]...) |
7 |
| - test_frule(sum, x; fkwargs=(;dims=dims)) |
8 |
| - test_rrule(sum, x; fkwargs=(;dims=dims)) |
9 |
| - end |
| 2 | + @testset "sum(x; dims=$dims)" for dims in (:, 2, (1,3)) |
| 3 | + # Forward |
| 4 | + test_frule(sum, rand(5); fkwargs=(;dims=dims)) |
| 5 | + test_frule(sum, rand(ComplexF64, 2,3,4); fkwargs=(;dims=dims)) |
| 6 | + |
| 7 | + # Reverse |
| 8 | + test_rrule(sum, rand(5); fkwargs=(;dims=dims)) |
| 9 | + test_rrule(sum, rand(ComplexF64, 2,3,4); fkwargs=(;dims=dims)) |
| 10 | + |
| 11 | + # Structured matrices |
| 12 | + test_rrule(sum, rand(5)'; fkwargs=(;dims=dims)) |
| 13 | + y, back = rrule(sum, UpperTriangular(rand(5,5)); dims=dims) |
| 14 | + unthunk(back(y*(1+im))[2]) isa UpperTriangular{Float64} |
| 15 | + @test_skip test_rrule(sum, UpperTriangular(rand(5,5)) ⊢ randn(5,5); fkwargs=(;dims=dims), check_inferred=false) # Problem: in add!! Evaluated: isapprox |
| 16 | + |
| 17 | + # Boolean -- via @non_differentiable |
| 18 | + test_rrule(sum, randn(5) .> 0; fkwargs=(;dims=dims)) |
| 19 | + |
| 20 | + # Function allowing for 2nd derivatives |
| 21 | + for x in (rand(5), rand(2,3,4)) |
| 22 | + dy = maximum(x; dims=dims) |
| 23 | + test_frule(ChainRules._unsum, x, dy, dims) |
| 24 | + test_rrule(ChainRules._unsum, x, dy, dims) |
| 25 | + end |
| 26 | + |
| 27 | + # Arrays of arrays |
| 28 | + for x in ([rand(ComplexF64, 3) for _ in 1:4], [rand(3) for _ in 1:2, _ in 1:3, _ in 1:4]) |
| 29 | + test_rrule(sum, x; fkwargs=(;dims=dims), check_inferred=false) |
| 30 | + |
| 31 | + dy = sum(x; dims=dims) |
| 32 | + ddy = rrule(ChainRules._unsum, x, dy, dims)[2](x)[3] |
| 33 | + @test size(ddy) == size(dy) |
10 | 34 | end
|
11 | 35 | end
|
12 | 36 |
|
|
18 | 42 | test_frule(sum, abs2, x; fkwargs=(;dims=dims))
|
19 | 43 | test_rrule(sum, abs2, x; fkwargs=(;dims=dims))
|
20 | 44 | end
|
| 45 | + |
| 46 | + # Boolean -- via @non_differentiable, test that this isn't ambiguous |
| 47 | + test_rrule(sum, abs2, randn(5) .> 0; fkwargs=(;dims=dims)) |
21 | 48 | end
|
22 | 49 | end # sum abs2
|
23 | 50 |
|
24 | 51 | @testset "sum(f, xs)" begin
|
25 | 52 | # This calls back into AD
|
26 | 53 | test_rrule(sum, abs, [-4.0, 2.0, 2.0])
|
| 54 | + test_rrule(sum, cbrt, randn(5)) |
27 | 55 | test_rrule(sum, Multiplier(2.0), [2.0, 4.0, 8.0])
|
28 | 56 |
|
| 57 | + # Complex numbers |
| 58 | + test_rrule(sum, sqrt, rand(ComplexF64, 5)) |
| 59 | + test_rrule(sum, abs, rand(ComplexF64, 3, 4)) # complex -> real |
| 60 | + |
29 | 61 | # inference fails for array of arrays
|
30 | 62 | test_rrule(sum, sum, [[2.0, 4.0], [4.0,1.9]]; check_inferred=false)
|
31 | 63 |
|
32 | 64 | # dims kwarg
|
33 | 65 | test_rrule(sum, abs, [-2.0 4.0; 5.0 1.9]; fkwargs=(;dims=1))
|
34 | 66 | test_rrule(sum, abs, [-2.0 4.0; 5.0 1.9]; fkwargs=(;dims=2))
|
| 67 | + test_rrule(sum, sqrt, rand(ComplexF64, 3, 4); fkwargs=(;dims=(1,))) |
35 | 68 |
|
36 | 69 | test_rrule(sum, abs, @SVector[1.0, -3.0])
|
37 | 70 |
|
|
49 | 82 | # make sure we preserve type for Diagonal
|
50 | 83 | _, pb = rrule(ADviaRuleConfig(), sum, abs, Diagonal([1.0, -3.0]))
|
51 | 84 | @test pb(1.0)[3] isa Diagonal
|
| 85 | + |
| 86 | + # Boolean -- via @non_differentiable, test that this isn't ambiguous |
| 87 | + test_rrule(sum, sqrt, randn(5) .> 0) |
| 88 | + test_rrule(sum, sqrt, randn(5,5) .> 0; fkwargs=(;dims=1)) |
| 89 | + # ... and Bool produced by function |
| 90 | + @test_skip test_rrule(sum, iszero, randn(5)) # DimensionMismatch("second dimension of A, 1, does not match length of x, 0") |
52 | 91 | end
|
53 | 92 |
|
54 | 93 | @testset "prod" begin
|
|
100 | 139 | @test unthunk(rrule(prod, UpperTriangular(ones(T,2,2)))[2](1.0)[2]) == UpperTriangular([0.0 0; 1 0])
|
101 | 140 |
|
102 | 141 | # Symmetric -- at least this doesn't have zeros, still an unlikely combination
|
103 |
| - |
104 | 142 | xs = Symmetric(rand(T,4,4))
|
105 | 143 | @test unthunk(rrule(prod, Symmetric(T[1 2; -333 4]))[2](1.0)[2]) == [16 8; 8 4]
|
106 | 144 | # TODO debug why these fail https://github.com/JuliaDiff/ChainRules.jl/issues/475
|
|
0 commit comments