@@ -195,27 +195,37 @@ const FASTABLE_AST = quote
195
195
(- 0.0 , - 1 ) => (- Inf , NaN ),
196
196
(0.0 , - 2 ) => (- Inf , NaN ),
197
197
(- 0.0 , - 2 ) => (Inf , NaN ),
198
+ # Integer x & p, check no InexactErrors
199
+ (0 , 2 ) => (0.0 , 0.0 ),
200
+ (0 , 1 ) => (1.0 , 0.0 ),
201
+ (0 , 0 ) => (0.0 , NaN ),
202
+ (0 , - 1 ) => (- Inf , NaN ),
203
+ (0 , - 2 ) => (- Inf , NaN ),
198
204
# Non-integer powers:
199
205
(0.0 , 0.5 ) => (Inf , 0.0 ),
200
206
(0.0 , 3.5 ) => (0.0 , 0.0 ),
201
207
(0.0 , - 1.5 ) => (- Inf , NaN ),
202
208
]
203
209
204
210
@testset " $x ^ $p " for ((x,p), (∂x, ∂p)) in POWERGRADS
211
+ if x isa Integer && p isa Integer && p < 0
212
+ @test_throws DomainError x^ p
213
+ continue
214
+ end
205
215
y = x^ p
206
216
207
217
# Forward
208
- y_f = frule ((1 ,1 ,1 ), ^ , x, p)[1 ]
209
- @test isequal (y, y_f) # || println("^ forward value for $x^$p: got $y_f , expected $y")
218
+ y_fwd = frule ((1 ,1 ,1 ), ^ , x, p)[1 ]
219
+ @test y === y_fwd # || println("^ forward value for $x^$p: got $y_fwd , expected $y")
210
220
211
- ∂x_fwd = frule ((0 ,1 ,0 ), ^ , x, p)[1 ]
212
- ∂p_fwd = frule ((0 ,0 ,1 ), ^ , x, p)[2 ]
221
+ # ∂x_fwd = frule((0,1,0), ^, x, p)[1]
222
+ # ∂p_fwd = frule((0,0,1), ^, x, p)[2]
213
223
# isequal(∂x, ∂x_fwd) || println("^ forward `x` gradient for $y = $x^$p: got $∂x_fwd, expected $∂x, maybe!")
214
224
# isequal(∂p, ∂p_fwd) || println("^ forward `p` gradient for $x^$p: got $∂p_fwd, expected $∂p, maybe")
215
225
216
226
# Reverse
217
- y_r = rrule (^ , x, p)[1 ]
218
- @test isequal (y, y_r) # || println("^ reverse value for $x^$p: got $y_r , expected $y")
227
+ y_rev = rrule (^ , x, p)[1 ]
228
+ @test y === y_rev # || println("^ reverse value for $x^$p: got $y_rev , expected $y")
219
229
220
230
∂x_rev, ∂p_rev = unthunk .(rrule (^ , x, p)[2 ](1 ))[2 : 3 ]
221
231
if ∂x === - 0.0 # happens at at x === -0.0 && p === 2, ignore the sign
@@ -227,21 +237,18 @@ const FASTABLE_AST = quote
227
237
end
228
238
229
239
@testset " literal_pow $x ^ $p " for ((x,p), (∂x, ∂p)) in POWERGRADS
230
- # p isa Int || continue
231
- # x isa Real || continue
232
-
233
- y = x^ p
240
+ y = Base. literal_pow (^ , x, Val (p))
234
241
235
242
# Forward
236
- y_f = frule ((1 ,1 ,1 ,1 ), Base. literal_pow, ^ , x, Val (p))[1 ]
237
- @test isequal (y, y_f) # || println("literal_pow forward value for $x^$p: got $y_f , expected $y")
243
+ y_fwd = frule ((1 ,1 ,1 ,1 ), Base. literal_pow, ^ , x, Val (p))[1 ]
244
+ @test y === y_fwd # || println("literal_pow forward value for $x^$p: got $y_fwd , expected $y")
238
245
239
246
∂x_fwd = frule ((0 ,0 ,1 ,0 ), Base. literal_pow, ^ , x, Val (p))[1 ]
240
247
# isequal(∂x, ∂x_fwd) || println("literal_pow forward `x` gradient for $x^$p: got $∂x_fwd, expected $∂x, maybe, y=$y")
241
248
242
249
# Reverse
243
- y_r = rrule (Base. literal_pow, ^ , x, Val (p))[1 ]
244
- @test isequal (y, y_r) # || println("literal_pow reverse value for $x^$p: got $y_r , expected $y")
250
+ y_rev = rrule (Base. literal_pow, ^ , x, Val (p))[1 ]
251
+ @test y === y_rev # || println("literal_pow reverse value for $x^$p: got $y_rev , expected $y")
245
252
246
253
∂x_rev = unthunk (rrule (Base. literal_pow, ^ , x, Val (p))[2 ](1 ))[3 ]
247
254
@test isequal (∂x, ∂x_rev) # || println("literal_pow `x` gradient for $x^$p: got $∂x_rev, expected $∂x")
0 commit comments