Skip to content

Commit 936cb21

Browse files
committed
tidy
1 parent 48a9d14 commit 936cb21

File tree

2 files changed

+44
-24
lines changed

2 files changed

+44
-24
lines changed

src/rulesets/Base/fastmath_able.jl

Lines changed: 43 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -167,16 +167,11 @@ let
167167
# literal_pow is in base.jl
168168
function frule((_, Δx, Δp), ::typeof(^), x::Number, p::Number)
169169
y = x ^ p
170-
thegrad = _pow_grad_x(x, p, float(y))
171-
thelog = if Δp isa AbstractZero
172-
# Then don't waste time computing log
173-
Δp
174-
else
175-
# When x < 0 && isinteger(p), could decide p is non-differentiable,
176-
# isolated points, or could match what the rrule with ProjectTo gives:
177-
_pow_grad_p(x, p, float(y))
178-
end
179-
return y, muladd(thelog, Δp, thegrad * Δx)
170+
dx = _pow_grad_x(x, p, float(y))
171+
# When x < 0 && isinteger(p), could decide p is non-differentiable, isolated
172+
# points, but chose to match what the rrule with ProjectTo gives, real(log(...)):
173+
dp = Δp isa AbstractZero ? Δp : _pow_grad_p(x, p, float(y))
174+
return y, muladd(dp, Δp, dx * Δx)
180175
end
181176

182177
function rrule(::typeof(^), x::Number, p::Number)
@@ -190,18 +185,6 @@ let
190185
return y, power_pullback
191186
end
192187

193-
_pow_grad_x(x, p, y) = (p * y / x)
194-
function _pow_grad_x(x::Real, p::Real, y)
195-
return ifelse(!iszero(x) | (p<0), (p * y / x),
196-
ifelse(isone(p), one(y),
197-
ifelse(0<p<1, oftype(y, Inf), zero(y) )))
198-
end
199-
_pow_grad_p(x, p, y) = y * log(complex(x))
200-
function _pow_grad_p(x::Real, p::Real, y)
201-
return ifelse(!iszero(x), y * real(log(complex(x))),
202-
ifelse(p>0, zero(y), oftype(y, NaN) ))
203-
end
204-
205188
@scalar_rule(
206189
rem(x, y),
207190
@setup((u, nan) = promote(x / y, NaN16), isint = isinteger(x / y)),
@@ -263,14 +246,51 @@ let
263246
non_transformed_definitions = intersect(fastable_ast.args, fast_ast.args)
264247
filter!(expr->!(expr isa LineNumberNode), non_transformed_definitions)
265248
if !isempty(non_transformed_definitions)
266-
@warn(
249+
@error(
267250
"Non-FastMath compatible rules defined in fastmath_able.jl.", # \n Definitions:\n" *
268251
# join(non_transformed_definitions, "\n")
269252
non_transformed_definitions
270253
)
254+
# This is @error not error() because that doesn't play well with Revise, locally
271255
end
272256

273257
eval(fast_ast)
274258
eval(fastable_ast) # Get original definitions
275259
# we do this second so it overwrites anything we included by mistake in the fastable
276260
end
261+
262+
## power
263+
# Thes functions need to be defined outside the eval() block.
264+
# The special cases they aim to hit are in POWERGRADS in tests.
265+
_pow_grad_x(x, p, y) = (p * y / x)
266+
# function _pow_grad_x(x::Real, p::Real, y)
267+
# return ifelse(!iszero(x) | (p<0), (p * y / x),
268+
# ifelse(isone(p), one(y),
269+
# ifelse((0<p) | (p<1), oftype(y, Inf), zero(y) )))
270+
# end
271+
function _pow_grad_x(x::Real, p::Real, y)
272+
return if !iszero(x) || p < 0
273+
p * y / x
274+
elseif isone(p)
275+
one(y)
276+
elseif iszero(p) || p > 1
277+
zero(y)
278+
else
279+
oftype(y, Inf)
280+
end
281+
end
282+
283+
_pow_grad_p(x, p, y) = y * log(complex(x))
284+
# function _pow_grad_p(x::Real, p::Real, y)
285+
# return ifelse(!iszero(x), y * real(log(complex(x))),
286+
# ifelse(p>0, zero(y), oftype(y, NaN) ))
287+
# end
288+
function _pow_grad_p(x::Real, p::Real, y)
289+
return if !iszero(x)
290+
y * real(log(complex(x)))
291+
elseif p > 0
292+
zero(y)
293+
else
294+
oftype(y, NaN)
295+
end
296+
end

test/rulesets/Base/fastmath_able.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ const FASTABLE_AST = quote
178178
(1.0, 2) => (2.0, 0.0),
179179
(2.0, 2) => (4.0, 2.772588722239781),
180180
# At x=0, gradients for x seem clear,
181-
# for p less certain but I think 0 or NaN right?
181+
# for p less certain what's best.
182182
(0.0, 2) => (0.0, 0.0),
183183
(-0.0, 2) => (-0.0, 0.0),
184184
(0.0, 1) => (1.0, 0.0),

0 commit comments

Comments
 (0)