Skip to content

Commit 8405b68

Browse files
committed
change the p grad
1 parent 95dfb23 commit 8405b68

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

src/rulesets/Base/fastmath_able.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -199,11 +199,11 @@ julia> frule((0,0,1), ^, 4, 3.0), unthunk.(rrule(^, 4, 3.0)[2](1))
199199
end
200200

201201
function rrule(::typeof(^), x::Number, p::Number)
202-
y = x^p
202+
y = float(x^p)
203203
project_x, project_p = ProjectTo(x), ProjectTo(p)
204204
@inline function power_pullback(dy)
205205
dx = project_x(conj(_pow_grad_x(x,p,y)) * dy)
206-
dp = @thunk project_p(conj(y * log(complex(x))) * dy)
206+
dp = @thunk project_p(conj(_pow_grad_p(x,p,y)) * dy)
207207
return (NoTangent(), dx, dp)
208208
end
209209
return y, power_pullback
@@ -214,6 +214,12 @@ julia> frule((0,0,1), ^, 4, 3.0), unthunk.(rrule(^, 4, 3.0)[2](1))
214214
ifelse(isone(p), one(y),
215215
ifelse(0<p<1, oftype(y, Inf), zero(y) )))
216216
end
217+
_pow_grad_p(x, p, y) = y * log(complex(x))
218+
function _pow_grad_p(x::Real, p::Real, y)
219+
return ifelse(!iszero(x), y * real(log(complex(x))),
220+
ifelse(p>0, zero(y), oftype(y, NaN) ))
221+
end
222+
217223

218224
@scalar_rule(
219225
rem(x, y),

0 commit comments

Comments
 (0)