@@ -167,16 +167,11 @@ let
167
167
# literal_pow is in base.jl
168
168
function frule ((_, Δx, Δp), :: typeof (^ ), x:: Number , p:: Number )
169
169
y = x ^ p
170
- thegrad = _pow_grad_x (x, p, float (y))
171
- thelog = if Δp isa AbstractZero
172
- # Then don't waste time computing log
173
- Δp
174
- else
175
- # When x < 0 && isinteger(p), could decide p is non-differentiable,
176
- # isolated points, or could match what the rrule with ProjectTo gives:
177
- _pow_grad_p (x, p, float (y))
178
- end
179
- return y, muladd (thelog, Δp, thegrad * Δx)
170
+ dx = _pow_grad_x (x, p, float (y))
171
+ # When x < 0 && isinteger(p), could decide p is non-differentiable, isolated
172
+ # points, but chose to match what the rrule with ProjectTo gives, real(log(...)):
173
+ dp = Δp isa AbstractZero ? Δp : _pow_grad_p (x, p, float (y))
174
+ return y, muladd (dp, Δp, dx * Δx)
180
175
end
181
176
182
177
function rrule (:: typeof (^ ), x:: Number , p:: Number )
190
185
return y, power_pullback
191
186
end
192
187
193
- _pow_grad_x (x, p, y) = (p * y / x)
194
- function _pow_grad_x (x:: Real , p:: Real , y)
195
- return ifelse (! iszero (x) | (p< 0 ), (p * y / x),
196
- ifelse (isone (p), one (y),
197
- ifelse (0 < p< 1 , oftype (y, Inf ), zero (y) )))
198
- end
199
- _pow_grad_p (x, p, y) = y * log (complex (x))
200
- function _pow_grad_p (x:: Real , p:: Real , y)
201
- return ifelse (! iszero (x), y * real (log (complex (x))),
202
- ifelse (p> 0 , zero (y), oftype (y, NaN ) ))
203
- end
204
-
205
188
@scalar_rule (
206
189
rem (x, y),
207
190
@setup ((u, nan) = promote (x / y, NaN16 ), isint = isinteger (x / y)),
@@ -263,14 +246,51 @@ let
263
246
non_transformed_definitions = intersect (fastable_ast. args, fast_ast. args)
264
247
filter! (expr-> ! (expr isa LineNumberNode), non_transformed_definitions)
265
248
if ! isempty (non_transformed_definitions)
266
- @warn (
249
+ @error (
267
250
" Non-FastMath compatible rules defined in fastmath_able.jl." , # \n Definitions:\n" *
268
251
# join(non_transformed_definitions, "\n")
269
252
non_transformed_definitions
270
253
)
254
+ # This is @error not error() because that doesn't play well with Revise, locally
271
255
end
272
256
273
257
eval (fast_ast)
274
258
eval (fastable_ast) # Get original definitions
275
259
# we do this second so it overwrites anything we included by mistake in the fastable
276
260
end
261
+
262
+ # # power
263
+ # Thes functions need to be defined outside the eval() block.
264
+ # The special cases they aim to hit are in POWERGRADS in tests.
265
+ _pow_grad_x (x, p, y) = (p * y / x)
266
+ # function _pow_grad_x(x::Real, p::Real, y)
267
+ # return ifelse(!iszero(x) | (p<0), (p * y / x),
268
+ # ifelse(isone(p), one(y),
269
+ # ifelse((0<p) | (p<1), oftype(y, Inf), zero(y) )))
270
+ # end
271
+ function _pow_grad_x (x:: Real , p:: Real , y)
272
+ return if ! iszero (x) || p < 0
273
+ p * y / x
274
+ elseif isone (p)
275
+ one (y)
276
+ elseif iszero (p) || p > 1
277
+ zero (y)
278
+ else
279
+ oftype (y, Inf )
280
+ end
281
+ end
282
+
283
+ _pow_grad_p (x, p, y) = y * log (complex (x))
284
+ # function _pow_grad_p(x::Real, p::Real, y)
285
+ # return ifelse(!iszero(x), y * real(log(complex(x))),
286
+ # ifelse(p>0, zero(y), oftype(y, NaN) ))
287
+ # end
288
+ function _pow_grad_p (x:: Real , p:: Real , y)
289
+ return if ! iszero (x)
290
+ y * real (log (complex (x)))
291
+ elseif p > 0
292
+ zero (y)
293
+ else
294
+ oftype (y, NaN )
295
+ end
296
+ end
0 commit comments