Skip to content

Commit 1fe125e

Browse files
committed
tidy
1 parent 936cb21 commit 1fe125e

File tree

2 files changed

+20
-27
lines changed

2 files changed

+20
-27
lines changed

src/rulesets/Base/fastmath_able.jl

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -167,38 +167,44 @@ let
167167
# literal_pow is in base.jl
168168
function frule((_, Δx, Δp), ::typeof(^), x::Number, p::Number)
169169
y = x ^ p
170-
dx = _pow_grad_x(x, p, float(y))
170+
_dx = _pow_grad_x(x, p, float(y))
171171
# When x < 0 && isinteger(p), could decide p is non-differentiable, isolated
172172
# points, but chose to match what the rrule with ProjectTo gives, real(log(...)):
173-
dp = Δp isa AbstractZero ? Δp : _pow_grad_p(x, p, float(y))
174-
return y, muladd(dp, Δp, dx * Δx)
173+
_dp = Δp isa AbstractZero ? Δp : _pow_grad_p(x, p, float(y))
174+
return y, muladd(_dp, Δp, _dx * Δx)
175175
end
176176

177177
function rrule(::typeof(^), x::Number, p::Number)
178178
y = x^p
179-
project_x, project_p = ProjectTo(x), ProjectTo(p)
179+
project_x = ProjectTo(x)
180+
project_p = ProjectTo(p)
180181
@inline function power_pullback(dy)
181-
dx = project_x(conj(_pow_grad_x(x,p,float(y))) * dy)
182-
dp = @thunk project_p(conj(_pow_grad_p(x,p,float(y))) * dy)
183-
return (NoTangent(), dx, dp)
182+
_dx = _pow_grad_x(x, p, float(y))
183+
_dy = _pow_grad_p(x, p, float(y))
184+
return (
185+
NoTangent(),
186+
project_x(conj(_dx) * dy),
187+
@thunk project_p(conj(_dy) * dy)
188+
)
184189
end
185190
return y, power_pullback
186191
end
187192

193+
## `rem`
188194
@scalar_rule(
189195
rem(x, y),
190196
@setup((u, nan) = promote(x / y, NaN16), isint = isinteger(x / y)),
191197
(ifelse(isint, nan, one(u)), ifelse(isint, nan, -trunc(u))),
192198
)
199+
## `min`, `max`
193200
@scalar_rule max(x, y) @setup(gt = x > y) (gt, !gt)
194201
@scalar_rule min(x, y) @setup(gt = x > y) (!gt, gt)
195202

196203
# Unary functions
197204
@scalar_rule +x true
198205
@scalar_rule -x -1
199206

200-
# `sign`
201-
207+
## `sign`
202208
function frule((_, Δx), ::typeof(sign), x)
203209
n = ifelse(iszero(x), one(real(x)), abs(x))
204210
Ω = x isa Real ? sign(x) : x / n
@@ -263,11 +269,6 @@ end
263269
# Thes functions need to be defined outside the eval() block.
264270
# The special cases they aim to hit are in POWERGRADS in tests.
265271
_pow_grad_x(x, p, y) = (p * y / x)
266-
# function _pow_grad_x(x::Real, p::Real, y)
267-
# return ifelse(!iszero(x) | (p<0), (p * y / x),
268-
# ifelse(isone(p), one(y),
269-
# ifelse((0<p) | (p<1), oftype(y, Inf), zero(y) )))
270-
# end
271272
function _pow_grad_x(x::Real, p::Real, y)
272273
return if !iszero(x) || p < 0
273274
p * y / x
@@ -281,10 +282,6 @@ function _pow_grad_x(x::Real, p::Real, y)
281282
end
282283

283284
_pow_grad_p(x, p, y) = y * log(complex(x))
284-
# function _pow_grad_p(x::Real, p::Real, y)
285-
# return ifelse(!iszero(x), y * real(log(complex(x))),
286-
# ifelse(p>0, zero(y), oftype(y, NaN) ))
287-
# end
288285
function _pow_grad_p(x::Real, p::Real, y)
289286
return if !iszero(x)
290287
y * real(log(complex(x)))

test/rulesets/Base/fastmath_able.jl

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ const FASTABLE_AST = quote
180180
# At x=0, gradients for x seem clear,
181181
# for p less certain what's best.
182182
(0.0, 2) => (0.0, 0.0),
183-
(-0.0, 2) => (-0.0, 0.0),
183+
(-0.0, 2) => (0.0, 0.0), # probably (-0.0, 0.0) would be ideal
184184
(0.0, 1) => (1.0, 0.0),
185185
(-0.0, 1) => (1.0, 0.0),
186186
(0.0, 0) => (0.0, NaN),
@@ -210,7 +210,7 @@ const FASTABLE_AST = quote
210210

211211
# Forward
212212
y_fwd = frule((1,1,1), ^, x, p)[1]
213-
@test y === y_fwd # || println("^ forward value for $x^$p: got $y_fwd, expected $y")
213+
@test isequal(y, y_fwd)
214214

215215
# ∂x_fwd = frule((0,1,0), ^, x, p)[1]
216216
# ∂p_fwd = frule((0,0,1), ^, x, p)[2]
@@ -219,15 +219,11 @@ const FASTABLE_AST = quote
219219

220220
# Reverse
221221
y_rev = rrule(^, x, p)[1]
222-
@test y === y_rev # || println("^ reverse value for $x^$p: got $y_rev, expected $y")
222+
@test isequal(y, y_rev)
223223

224224
∂x_rev, ∂p_rev = unthunk.(rrule(^, x, p)[2](1))[2:3]
225-
if ∂x === -0.0 # happens at at x === -0.0 && p === 2, ignore the sign
226-
@test 0.0 == ∂x_rev
227-
else
228-
@test isequal(∂x, ∂x_rev) # || println("^ reverse `x` gradient for $x^$p: got $∂x_rev, expected $∂x")
229-
end
230-
@test isequal(∂p, ∂p_rev) # || println("^ reverse `p` gradient for $x^$p: got $∂p_rev, expected $∂p")
225+
@test isequal(∂x, ∂x_rev)
226+
@test isequal(∂p, ∂p_rev)
231227
end
232228
end
233229

0 commit comments

Comments
 (0)