Skip to content

Commit 8c718df

Browse files
committed
allow integer x
1 parent 29b0839 commit 8c718df

File tree

2 files changed

+25
-18
lines changed

2 files changed

+25
-18
lines changed

src/rulesets/Base/fastmath_able.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ let
173173
# end
174174
# function frule((_, Δx, Δp), ::typeof(^), x::Real, p::Real)
175175
# y = x ^ p
176-
thegrad = _pow_grad_x(x, p, y)
176+
thegrad = _pow_grad_x(x, p, float(y))
177177
thelog = if Δp isa AbstractZero
178178
# Then don't waste time computing log
179179
Δp
@@ -199,11 +199,11 @@ julia> frule((0,0,1), ^, 4, 3.0), unthunk.(rrule(^, 4, 3.0)[2](1))
199199
end
200200

201201
function rrule(::typeof(^), x::Number, p::Number)
202-
y = float(x^p)
202+
y = x^p
203203
project_x, project_p = ProjectTo(x), ProjectTo(p)
204204
@inline function power_pullback(dy)
205-
dx = project_x(conj(_pow_grad_x(x,p,y)) * dy)
206-
dp = @thunk project_p(conj(_pow_grad_p(x,p,y)) * dy)
205+
dx = project_x(conj(_pow_grad_x(x,p,float(y))) * dy)
206+
dp = @thunk project_p(conj(_pow_grad_p(x,p,float(y))) * dy)
207207
return (NoTangent(), dx, dp)
208208
end
209209
return y, power_pullback

test/rulesets/Base/fastmath_able.jl

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -195,27 +195,37 @@ const FASTABLE_AST = quote
195195
(-0.0, -1) => (-Inf, NaN),
196196
(0.0, -2) => (-Inf, NaN),
197197
(-0.0, -2) => (Inf, NaN),
198+
# Integer x & p, check no InexactErrors
199+
(0, 2) => (0.0, 0.0),
200+
(0, 1) => (1.0, 0.0),
201+
(0, 0) => (0.0, NaN),
202+
(0, -1) => (-Inf, NaN),
203+
(0, -2) => (-Inf, NaN),
198204
# Non-integer powers:
199205
(0.0, 0.5) => (Inf, 0.0),
200206
(0.0, 3.5) => (0.0, 0.0),
201207
(0.0, -1.5) => (-Inf, NaN),
202208
]
203209

204210
@testset "$x ^ $p" for ((x,p), (∂x, ∂p)) in POWERGRADS
211+
if x isa Integer && p isa Integer && p < 0
212+
@test_throws DomainError x^p
213+
continue
214+
end
205215
y = x^p
206216

207217
# Forward
208-
y_f = frule((1,1,1), ^, x, p)[1]
209-
@test isequal(y, y_f) # || println("^ forward value for $x^$p: got $y_f, expected $y")
218+
y_fwd = frule((1,1,1), ^, x, p)[1]
219+
@test y === y_fwd # || println("^ forward value for $x^$p: got $y_fwd, expected $y")
210220

211-
∂x_fwd = frule((0,1,0), ^, x, p)[1]
212-
∂p_fwd = frule((0,0,1), ^, x, p)[2]
221+
# ∂x_fwd = frule((0,1,0), ^, x, p)[1]
222+
# ∂p_fwd = frule((0,0,1), ^, x, p)[2]
213223
# isequal(∂x, ∂x_fwd) || println("^ forward `x` gradient for $y = $x^$p: got $∂x_fwd, expected $∂x, maybe!")
214224
# isequal(∂p, ∂p_fwd) || println("^ forward `p` gradient for $x^$p: got $∂p_fwd, expected $∂p, maybe")
215225

216226
# Reverse
217-
y_r = rrule(^, x, p)[1]
218-
@test isequal(y, y_r) # || println("^ reverse value for $x^$p: got $y_r, expected $y")
227+
y_rev = rrule(^, x, p)[1]
228+
@test y === y_rev # || println("^ reverse value for $x^$p: got $y_rev, expected $y")
219229

220230
∂x_rev, ∂p_rev = unthunk.(rrule(^, x, p)[2](1))[2:3]
221231
if ∂x === -0.0 # happens at at x === -0.0 && p === 2, ignore the sign
@@ -227,21 +237,18 @@ const FASTABLE_AST = quote
227237
end
228238

229239
@testset "literal_pow $x ^ $p" for ((x,p), (∂x, ∂p)) in POWERGRADS
230-
# p isa Int || continue
231-
# x isa Real || continue
232-
233-
y = x^p
240+
y = Base.literal_pow(^, x, Val(p))
234241

235242
# Forward
236-
y_f = frule((1,1,1,1), Base.literal_pow, ^, x, Val(p))[1]
237-
@test isequal(y, y_f) # || println("literal_pow forward value for $x^$p: got $y_f, expected $y")
243+
y_fwd = frule((1,1,1,1), Base.literal_pow, ^, x, Val(p))[1]
244+
@test y === y_fwd # || println("literal_pow forward value for $x^$p: got $y_fwd, expected $y")
238245

239246
∂x_fwd = frule((0,0,1,0), Base.literal_pow, ^, x, Val(p))[1]
240247
# isequal(∂x, ∂x_fwd) || println("literal_pow forward `x` gradient for $x^$p: got $∂x_fwd, expected $∂x, maybe, y=$y")
241248

242249
# Reverse
243-
y_r = rrule(Base.literal_pow, ^, x, Val(p))[1]
244-
@test isequal(y, y_r) # || println("literal_pow reverse value for $x^$p: got $y_r, expected $y")
250+
y_rev = rrule(Base.literal_pow, ^, x, Val(p))[1]
251+
@test y === y_rev # || println("literal_pow reverse value for $x^$p: got $y_rev, expected $y")
245252

246253
∂x_rev = unthunk(rrule(Base.literal_pow, ^, x, Val(p))[2](1))[3]
247254
@test isequal(∂x, ∂x_rev) # || println("literal_pow `x` gradient for $x^$p: got $∂x_rev, expected $∂x")

0 commit comments

Comments
 (0)