Skip to content

Commit 7b70923

Browse files
committed
fix frule
1 parent 1c08573 commit 7b70923

File tree

2 files changed

+19
-8
lines changed

2 files changed

+19
-8
lines changed

src/rulesets/Base/fastmath_able.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -168,10 +168,14 @@ let
168168
function frule((_, Δx, Δp), ::typeof(^), x::Number, p::Number)
169169
y = x ^ p
170170
_dx = _pow_grad_x(x, p, float(y))
171-
# When x < 0 && isinteger(p), could decide p is non-differentiable, isolated
172-
# points, but chose to match what the rrule with ProjectTo gives, real(log(...)):
173-
_dp = Δp isa AbstractZero ? Δp : _pow_grad_p(x, p, float(y))
174-
return y, muladd(_dp, Δp, _dx * Δx)
171+
if iszero(Δp)
172+
# Treat this as a strong zero, to avoid NaN, and save the cost of log
173+
return y, _dx * Δx
174+
else
175+
# This may do real(log(complex(...))) which matches ProjectTo in rrule
176+
_dp = _pow_grad_p(x, p, float(y))
177+
return y, muladd(_dp, Δp, _dx * Δx)
178+
end
175179
end
176180

177181
function rrule(::typeof(^), x::Number, p::Number)

test/rulesets/Base/fastmath_able.jl

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -212,10 +212,17 @@ const FASTABLE_AST = quote
212212
y_fwd = frule((1,1,1), ^, x, p)[1]
213213
@test isequal(y, y_fwd)
214214

215-
# ∂x_fwd = frule((0,1,0), ^, x, p)[1]
216-
# ∂p_fwd = frule((0,0,1), ^, x, p)[2]
217-
# isequal(∂x, ∂x_fwd) || println("^ forward `x` gradient for $y = $x^$p: got $∂x_fwd, expected $∂x, maybe!")
218-
# isequal(∂p, ∂p_fwd) || println("^ forward `p` gradient for $x^$p: got $∂p_fwd, expected $∂p, maybe")
215+
∂x_fwd = frule((0,1,0), ^, x, p)[2]
216+
∂p_fwd = frule((0,0,1), ^, x, p)[2]
217+
@test isequal(∂x, ∂x_fwd)
218+
if x===0.0 && p===0.5
219+
@test_broken isequal(∂p, ∂p_fwd)
220+
else
221+
@test isequal(∂p, ∂p_fwd)
222+
end
223+
224+
∂x_fwd = frule((0,1,ZeroTangent()), ^, x, p)[2] # easier, strong zero
225+
@test isequal(∂x, ∂x_fwd)
219226

220227
# Reverse
221228
y_rev = rrule(^, x, p)[1]

0 commit comments

Comments
 (0)