@@ -199,11 +199,11 @@ julia> frule((0,0,1), ^, 4, 3.0), unthunk.(rrule(^, 4, 3.0)[2](1))
199
199
end
200
200
201
201
function rrule (:: typeof (^ ), x:: Number , p:: Number )
202
- y = x^ p
202
+ y = float ( x^ p)
203
203
project_x, project_p = ProjectTo (x), ProjectTo (p)
204
204
@inline function power_pullback (dy)
205
205
dx = project_x (conj (_pow_grad_x (x,p,y)) * dy)
206
- dp = @thunk project_p (conj (y * log ( complex (x) )) * dy)
206
+ dp = @thunk project_p (conj (_pow_grad_p (x,p,y )) * dy)
207
207
return (NoTangent (), dx, dp)
208
208
end
209
209
return y, power_pullback
@@ -214,6 +214,12 @@ julia> frule((0,0,1), ^, 4, 3.0), unthunk.(rrule(^, 4, 3.0)[2](1))
214
214
ifelse (isone (p), one (y),
215
215
ifelse (0 < p< 1 , oftype (y, Inf ), zero (y) )))
216
216
end
217
+ _pow_grad_p (x, p, y) = y * log (complex (x))
218
+ function _pow_grad_p (x:: Real , p:: Real , y)
219
+ return ifelse (! iszero (x), y * real (log (complex (x))),
220
+ ifelse (p> 0 , zero (y), oftype (y, NaN ) ))
221
+ end
222
+
217
223
218
224
@scalar_rule (
219
225
rem (x, y),
0 commit comments