|
42 | 42 | test_frule(sum, abs2, x; fkwargs=(;dims=dims))
|
43 | 43 | test_rrule(sum, abs2, x; fkwargs=(;dims=dims))
|
44 | 44 | end
|
45 |
| - end |
46 | 45 |
|
47 |
| - # Boolean -- via @non_differentiable, test that this isn't ambiguous |
48 |
| - test_rrule(sum, abs2, randn(5) .> 0; fkwargs=(;dims=dims)) |
| 46 | + # Boolean -- via @non_differentiable, test that this isn't ambiguous |
| 47 | + test_rrule(sum, abs2, randn(5) .> 0; fkwargs=(;dims=dims)) |
| 48 | + end |
49 | 49 | end # sum abs2
|
50 | 50 |
|
51 | 51 | @testset "sum(f, xs)" begin
|
52 | 52 | # This calls back into AD
|
53 | 53 | test_rrule(sum, abs, [-4.0, 2.0, 2.0])
|
| 54 | + test_rrule(sum, cbrt, randn(5)) |
54 | 55 | test_rrule(sum, Multiplier(2.0), [2.0, 4.0, 8.0])
|
55 | 56 |
|
| 57 | + # Complex numbers |
| 58 | + test_rrule(sum, sqrt, rand(ComplexF64, 5)) |
| 59 | + test_rrule(sum, abs, rand(ComplexF64, 3, 4)) # complex -> real |
| 60 | + |
56 | 61 | # inference fails for array of arrays
|
57 | 62 | test_rrule(sum, sum, [[2.0, 4.0], [4.0,1.9]]; check_inferred=false)
|
58 | 63 |
|
59 | 64 | # dims kwarg
|
60 | 65 | test_rrule(sum, abs, [-2.0 4.0; 5.0 1.9]; fkwargs=(;dims=1))
|
61 | 66 | test_rrule(sum, abs, [-2.0 4.0; 5.0 1.9]; fkwargs=(;dims=2))
|
| 67 | + test_rrule(sum, sqrt, rand(ComplexF64, 3, 4); fkwargs=(;dims=(1,))) |
62 | 68 |
|
63 | 69 | test_rrule(sum, abs, @SVector[1.0, -3.0])
|
64 | 70 |
|
|
78 | 84 | @test pb(1.0)[3] isa Diagonal
|
79 | 85 |
|
80 | 86 | # Boolean -- via @non_differentiable, test that this isn't ambiguous
|
81 |
| - @test_skip test_rrule(sum, sqrt, randn(5) .> 0; fkwargs=(;dims=dims)) # MethodError: no method matching real(::NoTangent) |
| 87 | + test_rrule(sum, sqrt, randn(5) .> 0) |
| 88 | + test_rrule(sum, sqrt, randn(5,5) .> 0; fkwargs=(;dims=1)) |
| 89 | + # ... and Bool produced by function |
| 90 | + @test_skip test_rrule(sum, iszero, randn(5)) # DimensionMismatch("second dimension of A, 1, does not match length of x, 0") |
82 | 91 | end
|
83 | 92 |
|
84 | 93 | @testset "prod" begin
|
|
0 commit comments