|
98 | 98 | @testset "($x) * ($y)" for
|
99 | 99 | x in test_points, y in test_points
|
100 | 100 |
|
101 |
| - # ensure all complex if any complex for FiniteDifferences |
102 |
| - x, y = Base.promote(x, y) |
| 101 | + # all complex if any complex, was a limitation of FiniteDifferences? |
| 102 | + xx, yy = Base.promote(x, y) |
| 103 | + test_frule(*, xx, yy) |
| 104 | + test_rrule(*, xx, yy) |
103 | 105 |
|
| 106 | + # explicitly allow mixed types |
104 | 107 | test_frule(*, x, y)
|
105 | 108 | test_rrule(*, x, y)
|
| 109 | + rrule(*, x, y)[2](1)[2] isa typeof(x) |
| 110 | + rrule(*, x, y)[2](1)[3] isa typeof(y) |
106 | 111 | end
|
107 | 112 | end
|
108 | 113 |
|
|
136 | 141 | test_rrule(identity, Tuple(randn(T, 3)))
|
137 | 142 | end
|
138 | 143 |
|
139 |
| - @testset "Constants" for x in (-0.1, 6.4, 1.0+0.5im, -10.0+0im, 0.0+200im) |
| 144 | + @testset "one(::Number), zero(::Number)" for x in (-0.1, 6.4, 1.0+0.5im, -10.0+0im, 0.0+200im) |
140 | 145 | test_scalar(one, x)
|
141 | 146 | test_scalar(zero, x)
|
| 147 | + |
| 148 | + rrule(one, x)[2](1) === (NoTangent(), zero(x)) |
| 149 | + rrule(zero, x)[2](1) === (NoTangent(), zero(x)) |
142 | 150 | end
|
143 | 151 |
|
144 | 152 | @testset "muladd(x::$T, y::$T, z::$T)" for T in (Float64, ComplexF64)
|
145 |
| - test_frule(muladd, 10randn(), randn(), randn()) |
146 |
| - test_rrule(muladd, 10randn(), randn(), randn()) |
| 153 | + test_frule(muladd, 10randn(T), randn(T), randn(T)) |
| 154 | + test_rrule(muladd, 10randn(T), randn(T), randn(T)) |
147 | 155 | end
|
148 | 156 |
|
149 | 157 | @testset "fma" begin
|
|
163 | 171 | # to right
|
164 | 172 | test_frule(clamp, 4., 2., 3.)
|
165 | 173 | test_rrule(clamp, 4., 2., 3.)
|
| 174 | + |
| 175 | + # nonzero gradient at the boundaries |
| 176 | + @test frule((0,1,0,0), clamp, 2, 2, 3) == (2, 1) |
| 177 | + @test rrule(clamp, 2.0, 2, 3)[2](1)[2] == 1.0 |
| 178 | + |
| 179 | + @test frule((0,1,0,0), clamp, 3, 2, 3) == (3, 1) |
| 180 | + @test rrule(clamp, 3, 2, 3)[2](1)[2] == 1.0 |
166 | 181 | end
|
167 | 182 |
|
168 | 183 | @testset "rounding" begin
|
|
0 commit comments