@@ -166,11 +166,62 @@ const FASTABLE_AST = quote
166
166
# then x must be positive to avoid a DomainError
167
167
test_frule (^ , rand (T) + 3 , rand (T) + 3 )
168
168
test_rrule (^ , rand (T) + 3 , rand (T) + 3 )
169
+
170
+ T <: Real && S <: Real && continue
171
+ test_frule (^ , randn (T), rand (T))
172
+ test_rrule (^ , rand (T), rand (T))
169
173
end
174
+
170
175
# @testset "^(x::$T, $p::Int)" for T in (Float64, ComplexF64), p in -2:2
171
- # x = rand(T) .+ 3
176
+ # test_frule(^, randn(T) + 3, p ⊢ NoTangent()) # this doesn't just skip p's tangent
177
+ # test_rrule(^, randn(T) + 3, p ⊢ NoTangent())
172
178
# end
173
179
180
+ @testset " ^(x::Float64, p::$S ) near x=0, p=1,0,-1,-2" for S in (Int, Float64)
181
+ p = S (+ 2 )
182
+ @test frule ((1 ,1 ,1 ), ^ , 0.0 , p)[1 ] == 0
183
+ @test_broken frule ((1 ,1 ,1 ), ^ , 0.0 , p)[2 ] == 0
184
+ @test rrule (^ , 0.0 , p)[1 ] == 0
185
+ @test unthunk (rrule (^ , 0.0 , p)[2 ](1.0 )[2 ]) == 0
186
+
187
+ # Identity function x^1, at zero
188
+ p = S (+ 1 )
189
+ @test frule ((1 ,1 ,1 ), ^ , 0.0 , p)[1 ] == 0
190
+ @test_broken frule ((1 ,1 ,1 ), ^ , 0.0 , p)[2 ] == 1
191
+ @test rrule (^ , 0.0 , p)[1 ] == 0
192
+ @test unthunk (rrule (^ , 0.0 , p)[2 ](1.0 )[2 ]) == 1
193
+
194
+ # Trivial singularity: 0^0 == 1 in Julia
195
+ p = S (0 )
196
+ @test_skip frule ((1 ,1 ,1 ), ^ , 0.0 , p)[1 ] == (0.0 )^ 0
197
+ @test_broken frule ((1 ,1 ,1 ), ^ , 0.0 , p)[2 ] == 0
198
+ @test_broken unthunk (rrule (^ , 0.0 , p)[2 ](1.0 )[3 ]) == 0.0
199
+
200
+ # Odd power, 1/x
201
+ p = S (- 1 )
202
+ @test_skip frule ((1 ,1 ,1 ), ^ , 0.0 , p)[1 ] == (0.0 )^- 1
203
+ @test_broken frule ((1 ,1 ,1 ), ^ , 0.0 , p)[2 ] == - Inf
204
+ @test_skip rrule (^ , 0.0 , p)[1 ] == (0.0 )^- 1 == Inf
205
+ @test unthunk (rrule (^ , 0.0 , p)[2 ](1.0 )[2 ]) == - Inf
206
+
207
+ @test_skip frule ((1 ,1 ,1 ), ^ , - 0.0 , p)[1 ] == (- 0.0 )^- 1
208
+ @test_broken frule ((1 ,1 ,1 ), ^ , - 0.0 , p)[2 ] == - Inf
209
+ @test_skip rrule (^ , - 0.0 , p)[1 ] == (- 0.0 )^- 1 == - Inf
210
+ @test unthunk (rrule (^ , - 0.0 , p)[2 ](1.0 )[2 ]) == - Inf
211
+
212
+ # Even power, 1/x^2
213
+ p = S (- 2 )
214
+ @test_skip frule ((1 ,1 ,1 ), ^ , 0.0 , p)[1 ] == (0.0 )^- 2
215
+ @test_broken frule ((1 ,1 ,1 ), ^ , 0.0 , p)[2 ] == - Inf
216
+ @test_skip rrule (^ , 0.0 , p)[1 ] == (0.0 )^- 2 == Inf
217
+ @test unthunk (rrule (^ , 0.0 , p)[2 ](1.0 )[2 ]) == - Inf
218
+
219
+ @test_skip frule ((1 ,1 ,1 ), ^ , - 0.0 , p)[1 ] == (- 0.0 )^- 2
220
+ @test_broken frule ((1 ,1 ,1 ), ^ , - 0.0 , p)[2 ] == + Inf
221
+ @test_skip rrule (^ , - 0.0 , p)[1 ] == (- 0.0 )^- 2 == Inf
222
+ @test unthunk (rrule (^ , - 0.0 , p)[2 ](1.0 )[2 ]) == + Inf
223
+ end
224
+
174
225
# T <: Real && @testset "discontinuity for ^(x::Real, n::Int) when x ≤ 0" begin
175
226
# # finite differences doesn't work for x < 0, so we check manually
176
227
# x = -rand(T) .- 3
@@ -189,15 +240,6 @@ const FASTABLE_AST = quote
189
240
# @test ∂y ≈ 0
190
241
# end
191
242
# end
192
-
193
- # @testset "edge cases with ^" begin
194
- # # FIXME
195
- # @test_skip test_frule(^, 0.0, rand() + 3 ⊢ NoTangent(); fdm=forward_fdm(5,1))
196
- # test_rrule(^, 0.0, rand() + 3; fdm=forward_fdm(5,1))
197
-
198
- # test_frule(^, 0.0, 1.0 ⊢ NoTangent(); fdm=forward_fdm(5,1))
199
- # test_rrule(^, 0.0, 1.0; fdm=forward_fdm(5,1))
200
- # end
201
243
end
202
244
203
245
@testset " sign" begin
0 commit comments