@@ -138,8 +138,6 @@ const FASTABLE_AST = quote
138
138
end
139
139
140
140
@testset " $f (x::$T , y::$T ) type check" for f in (/ , + , - ,\ , hypot), T in (Float32, Float64)
141
- # ^ removed for now!
142
-
143
141
x, Δx, x̄ = 10 rand (T, 3 )
144
142
y, Δy, ȳ = rand (T, 3 )
145
143
@assert T == typeof (f (x, y))
@@ -162,12 +160,14 @@ const FASTABLE_AST = quote
162
160
end
163
161
164
162
@testset " ^(x::$T , p::$S )" for T in (Float64, ComplexF64), S in (Float64, ComplexF64)
165
- # When both x & p are Real, and !(isinteger(p)),
166
- # then x must be positive to avoid a DomainError
167
163
test_frule (^ , rand (T) + 3 , rand (T) + 3 )
168
164
test_rrule (^ , rand (T) + 3 , rand (T) + 3 )
169
-
165
+
166
+ # When both x & p are Real, and !(isinteger(p)),
167
+ # then x must be positive to avoid a DomainError
170
168
T <: Real && S <: Real && continue
169
+ # In other cases, we can test values near zero:
170
+
171
171
test_frule (^ , randn (T), rand (T))
172
172
test_rrule (^ , rand (T), rand (T))
173
173
end
@@ -177,77 +177,13 @@ const FASTABLE_AST = quote
177
177
# test_rrule(^, randn(T) + 3, p ⊢ NoTangent())
178
178
# end
179
179
180
- # @testset "^(x::Float64, p::$S) near x=0, p=1,0,-1,-2" for S in (Int, Float64)
181
- # # x^2. Easy to get NaN here by mistake.
182
- # p = S(+2)
183
- # @test frule((1,1,1), ^, 0.0, p)[1] == 0 # value
184
- # @test_broken frule((1,1,1), ^, 0.0, p)[2] == 0 # gradient, forwards
185
- # @test rrule(^, 0.0, p)[1] == 0 # value
186
- # @test unthunk(rrule(^, 0.0, p)[2](1.0)[2]) == 0 # gradient, reverse
187
-
188
- # # Identity function x^1, at zero
189
- # p = S(+1)
190
- # @test frule((1,1,1), ^, 0.0, p)[1] == 0
191
- # @test_broken frule((1,1,1), ^, 0.0, p)[2] == 1
192
- # @test rrule(^, 0.0, p)[1] == 0
193
- # @test unthunk(rrule(^, 0.0, p)[2](1.0)[2]) == 1
194
-
195
- # # Trivial singularity: 0^0 == 1 in Julia
196
- # p = S(0)
197
- # @test_skip frule((1,1,1), ^, 0.0, p)[1] == (0.0)^0
198
- # @test_broken frule((1,1,1), ^, 0.0, p)[2] == 0
199
- # @test_broken unthunk(rrule(^, 0.0, p)[2](1.0)[3]) == 0.0
200
-
201
- # # Odd power, 1/x
202
- # p = S(-1)
203
- # @test_skip frule((1,1,1), ^, 0.0, p)[1] == (0.0)^-1
204
- # @test_broken frule((1,1,1), ^, 0.0, p)[2] == -Inf
205
- # @test_skip rrule(^, 0.0, p)[1] == (0.0)^-1 == Inf
206
- # @test unthunk(rrule(^, 0.0, p)[2](1.0)[2]) == -Inf
207
-
208
- # @test_skip frule((1,1,1), ^, -0.0, p)[1] == (-0.0)^-1
209
- # @test_broken frule((1,1,1), ^, -0.0, p)[2] == -Inf
210
- # @test_skip rrule(^, -0.0, p)[1] == (-0.0)^-1 == -Inf
211
- # @test unthunk(rrule(^, -0.0, p)[2](1.0)[2]) == -Inf
212
-
213
- # # Even power, 1/x^2
214
- # p = S(-2)
215
- # @test_skip frule((1,1,1), ^, 0.0, p)[1] == (0.0)^-2
216
- # @test_broken frule((1,1,1), ^, 0.0, p)[2] == -Inf
217
- # @test_skip rrule(^, 0.0, p)[1] == (0.0)^-2 == Inf
218
- # @test unthunk(rrule(^, 0.0, p)[2](1.0)[2]) == -Inf
219
-
220
- # @test_skip frule((1,1,1), ^, -0.0, p)[1] == (-0.0)^-2
221
- # @test_broken frule((1,1,1), ^, -0.0, p)[2] == +Inf
222
- # @test_skip rrule(^, -0.0, p)[1] == (-0.0)^-2 == Inf
223
- # @test unthunk(rrule(^, -0.0, p)[2](1.0)[2]) == +Inf
224
- # end
225
-
226
- # T <: Real && @testset "discontinuity for ^(x::Real, n::Int) when x ≤ 0" begin
227
- # # finite differences doesn't work for x < 0, so we check manually
228
- # x = -rand(T) .- 3
229
- # y = 3
230
- # Δx = randn(T)
231
- # Δy = randn(T)
232
- # Δz = randn(T)
233
-
234
- # @test frule((ZeroTangent(), Δx, Δy), ^, x, y)[2] ≈ Δx * y * x^(y - 1)
235
- # @test frule((ZeroTangent(), Δx, Δy), ^, zero(x), y)[2] ≈ 0
236
- # _, ∂x, ∂y = rrule(^, x, y)[2](Δz)
237
- # @test ∂x ≈ Δz * y * x^(y - 1)
238
- # @test ∂y ≈ 0
239
- # _, ∂x, ∂y = rrule(^, zero(x), y)[2](Δz)
240
- # @test ∂x ≈ 0
241
- # @test ∂y ≈ 0
242
- # end
243
- # end
244
- end
180
+ # Tests for power functions, at values near to zero.
245
181
246
182
POWERGRADS = [ # (x,p) => (dx,dp)
247
- # some regular points, sanity checks
183
+ # Some regular points, sanity checks
248
184
(1.0 , 2 ) => (2.0 , 0.0 ),
249
185
(2.0 , 2 ) => (4.0 , 2.772588722239781 ),
250
- # at x=0, gradients for x seem clear,
186
+ # At x=0, gradients for x seem clear,
251
187
# for p I've just written here what it gives
252
188
(0.0 , 2 ) => (0.0 , NaN ),
253
189
(- 0.0 , 2 ) => (- 0.0 , NaN ),
@@ -259,74 +195,60 @@ POWERGRADS = [ # (x,p) => (dx,dp)
259
195
(- 0.0 , - 1 ) => (- Inf , Inf ),
260
196
(0.0 , - 2 ) => (- Inf , - Inf ),
261
197
(- 0.0 , - 2 ) => (Inf , - Inf ),
262
- # non -integer powers
198
+ # Non -integer powers:
263
199
(0.0 , 0.5 ) => (Inf , NaN ),
264
200
(0.0 , 3.5 ) => (0.0 , NaN ),
265
-
266
201
]
267
- for ((x,p), (gx, gp)) in POWERGRADS
202
+
203
+ for ((x,p), (gx, gp)) in POWERGRADS # power ^
268
204
y = x^ p
269
205
206
+ # Forward
270
207
y_f = frule ((1 ,1 ,1 ), ^ , x, p)[1 ]
271
208
isequal (y, y_f) || println (" ^ forward value for $x ^$p : got $y_f , expected $y " )
272
209
273
- y_r = rrule (^ , x, p)[1 ]
274
- isequal (y, y_r) || println (" ^ reverse value for $x ^$p : got $y_r , expected $y " )
275
-
276
210
gx_f = frule ((0 ,1 ,0 ), ^ , x, p)[1 ]
277
211
gp_f = frule ((0 ,0 ,1 ), ^ , x, p)[2 ]
278
212
# isequal(gx, gx_f) || println("^ forward `x` gradient for $x^$p: got $gx_f, expected $gx, maybe")
279
213
# isequal(gp, gp_f) || println("^ forward `p` gradient for $x^$p: got $gp_f, expected $gp, maybe")
280
214
215
+ # Reverse
216
+ y_r = rrule (^ , x, p)[1 ]
217
+ isequal (y, y_r) || println (" ^ reverse value for $x ^$p : got $y_r , expected $y " )
218
+
281
219
gx_r, gp_r = unthunk .(rrule (^ , x, p)[2 ](1 ))[2 : 3 ]
282
- isequal (gx, gx_r) || println (" ^ reverse `x` gradient for $x ^$p : got $gx_r , expected $gx " )
220
+ if x === - 0.0 && p === 2
221
+ @test 0.0 == gx_r # POWERGRADS says -0.0
222
+ else
223
+ isequal (gx, gx_r) || println (" ^ reverse `x` gradient for $x ^$p : got $gx_r , expected $gx " )
224
+ end
283
225
isequal (gp, gp_r) || println (" ^ reverse `p` gradient for $x ^$p : got $gp_r , expected $gp " )
284
-
285
226
end
286
- for ((x,p), (gx, gp)) in POWERGRADS
227
+
228
+ for ((x,p), (gx, gp)) in POWERGRADS # literal_pow
287
229
p isa Int || continue
288
230
x isa Real || continue
289
231
290
232
y = x^ p
291
233
234
+ # Forward
292
235
y_f = frule ((1 ,1 ,1 ,1 ), Base. literal_pow, ^ , x, Val (p))[1 ]
293
236
isequal (y, y_f) || println (" literal_pow forward value for $x ^$p : got $y_f , expected $y " )
294
237
238
+ gx_f = frule ((0 ,0 ,1 ,0 ), Base. literal_pow, ^ , x, Val (p))[1 ]
239
+ # isequal(gx, gx_f) || println("literal_pow forward `x` gradient for $x^$p: got $gx_f, expected $gx, maybe, y=$y")
240
+
241
+ # Reverse
295
242
y_r = rrule (Base. literal_pow, ^ , x, Val (p))[1 ]
296
243
isequal (y, y_r) || println (" literal_pow reverse value for $x ^$p : got $y_r , expected $y " )
297
244
298
245
gx_r = unthunk (rrule (Base. literal_pow, ^ , x, Val (p))[2 ](1 ))[3 ]
299
246
isequal (gx, gx_r) || println (" literal_pow `x` gradient for $x ^$p : got $gx_r , expected $gx " )
300
247
301
- gx_f = frule ((0 ,0 ,1 ,0 ), Base. literal_pow, ^ , x, Val (p))[1 ]
302
- # isequal(gx, gx_f) || println("literal_pow forward `x` gradient for $x^$p: got $gx_f, expected $gx, maybe")
303
- end
304
-
305
-
306
- for x in Any[0.0 , - 0.0 , 0.0 + 0im ], p in Any[2 , 1.5 , 1 , 0.5 , 0 , - 0.5 , - 1 , - 1.5 , - 2 ]
307
-
308
- y = x^ p
309
- yr = rrule (^ , x, p)[1 ]
310
- # isequal(y, yr) || printstyled("runtime $x^$p = $y, but rrule gives $yr \n", color=:red)
311
-
312
- gx, gp = unthunk .(rrule (^ , x, p)[2 ](1 )[2 : 3 ])
313
- println (" runtime $x ^$p gradient from rrule: $gx , $gp " )
314
-
315
- p isa Int || continue # e.g. Meta.@lower x^5.0
316
- x isa Real || continue # limitation of methods here?
317
- y = Base. literal_pow (^ , x, Val (p))
318
-
319
- # yr = rrule(Base.literal_pow, ^, x, Val(p))[1]
320
- # isequal(y, yr) || printstyled("literal $x^$p = $y, but rrule gives $yr\n", color=:red)
321
-
322
- # gx = unthunk(rrule(Base.literal_pow, ^, x, Val(p))[2](1))[3]
323
- # println("literal $x^$p gradient from rrule: $gx")
324
-
325
- # gg[(x,p)] = (gx, nothing)
248
+ # @info "all" x y p gx_f gx_r
326
249
end
327
250
328
251
329
-
330
252
@testset " sign" begin
331
253
@testset " real" begin
332
254
@testset " at $x " for x in (- 1.1 , - 1.1 , 0.5 , 100.0 )
0 commit comments