@@ -167,35 +167,16 @@ let
167
167
# literal_pow is in base.jl
168
168
function frule ((_, Δx, Δp), :: typeof (^ ), x:: Number , p:: Number )
169
169
y = x ^ p
170
- # thegrad = (p * y / x)
171
- # thelog = Δp isa AbstractZero ? Δp : log(oftype(y, x))
172
- # return y, muladd(y * thelog, Δp, thegrad * Δx)
173
- # end
174
- # function frule((_, Δx, Δp), ::typeof(^), x::Real, p::Real)
175
- # y = x ^ p
176
170
thegrad = _pow_grad_x (x, p, float (y))
177
171
thelog = if Δp isa AbstractZero
178
172
# Then don't waste time computing log
179
173
Δp
180
- else # if x isa Real && p isa Real
181
- # For positive x we'd like a real answer, including any Δp.
182
- # For negative x, this is a DomainError unless isinteger(p)...
183
-
184
- # could decide that implues that p is non-differentiable:
185
- # log(ifelse(x<0, one(x), x))
186
-
187
- # or we could match what the rrule with ProjectTo gives:
188
- real (log (complex (x)))
189
- #=
190
-
191
- julia> frule((0,0,1), ^, -4, 3.0), unthunk.(rrule(^, -4, 3.0)[2](1))
192
- ((-64.0, 0.0), (NoTangent(), 48.0, -88.722839111673))
193
-
194
- julia> frule((0,0,1), ^, 4, 3.0), unthunk.(rrule(^, 4, 3.0)[2](1))
195
- ((64.0, 88.722839111673), (NoTangent(), 48.0, 88.722839111673))
196
- =#
174
+ else
175
+ # When x < 0 && isinteger(p), could decide p is non-differentiable,
176
+ # isolated points, or could match what the rrule with ProjectTo gives:
177
+ _pow_grad_p (x, p, float (y))
197
178
end
198
- return y, muladd (y * thelog, Δp, thegrad * Δx)
179
+ return y, muladd (thelog, Δp, thegrad * Δx)
199
180
end
200
181
201
182
function rrule (:: typeof (^ ), x:: Number , p:: Number )
@@ -208,6 +189,7 @@ julia> frule((0,0,1), ^, 4, 3.0), unthunk.(rrule(^, 4, 3.0)[2](1))
208
189
end
209
190
return y, power_pullback
210
191
end
192
+
211
193
_pow_grad_x (x, p, y) = (p * y / x)
212
194
function _pow_grad_x (x:: Real , p:: Real , y)
213
195
return ifelse (! iszero (x) | (p< 0 ), (p * y / x),
@@ -220,7 +202,6 @@ julia> frule((0,0,1), ^, 4, 3.0), unthunk.(rrule(^, 4, 3.0)[2](1))
220
202
ifelse (p> 0 , zero (y), oftype (y, NaN ) ))
221
203
end
222
204
223
-
224
205
@scalar_rule (
225
206
rem (x, y),
226
207
@setup ((u, nan) = promote (x / y, NaN16 ), isint = isinteger (x / y)),
0 commit comments