Skip to content

Commit 95dfb23

Browse files
committed
tidy up
1 parent 504e8dc commit 95dfb23

File tree

2 files changed

+43
-124
lines changed

2 files changed

+43
-124
lines changed

src/rulesets/Base/fastmath_able.jl

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -167,15 +167,13 @@ let
167167
# literal_pow is in base.jl
168168
function frule((_, Δx, Δp), ::typeof(^), x::Number, p::Number)
169169
y = x ^ p
170-
thegrad = (p * y / x)
171-
thelog = Δp isa AbstractZero ? Δp : log(oftype(y, x))
172-
return y, muladd(y * thelog, Δp, thegrad * Δx)
173-
end
174-
function frule((_, Δx, Δp), ::typeof(^), x::Real, p::Real)
175-
y = x ^ p
176-
thegrad = ifelse(!iszero(x) | (p<0), (p * y / x),
177-
ifelse(isone(p), one(y),
178-
ifelse(0<p<1, oftype(y, Inf), zero(y) )))
170+
# thegrad = (p * y / x)
171+
# thelog = Δp isa AbstractZero ? Δp : log(oftype(y, x))
172+
# return y, muladd(y * thelog, Δp, thegrad * Δx)
173+
# end
174+
# function frule((_, Δx, Δp), ::typeof(^), x::Real, p::Real)
175+
# y = x ^ p
176+
thegrad = _pow_grad_x(x, p, y)
179177
thelog = if Δp isa AbstractZero
180178
# Then don't waste time computing log
181179
Δp
@@ -204,19 +202,18 @@ julia> frule((0,0,1), ^, 4, 3.0), unthunk.(rrule(^, 4, 3.0)[2](1))
204202
y = x^p
205203
project_x, project_p = ProjectTo(x), ProjectTo(p)
206204
# Pullback for `y = x^p`: maps the output cotangent `dy` to cotangents for
# (`^`, `x`, `p`). Closes over `x`, `p`, `y`, `project_x` and `project_p`
# from the enclosing rule.
@inline function power_pullback(dy)
    # Cotangent for `x`; `_pow_grad_x` picks the formula appropriate to the
    # argument types, including the special cases at x == 0 for real inputs.
    ∂y_∂x = _pow_grad_x(x, p, y)
    x̄ = project_x(conj(∂y_∂x) * dy)
    # Cotangent for `p` is wrapped in a thunk so the `log` is only computed
    # if this cotangent is actually requested.
    p̄ = @thunk project_p(conj(y * log(complex(x))) * dy)
    return (NoTangent(), x̄, p̄)
end
218209
return y, power_pullback
219210
end
211+
# Derivative of `x^p` with respect to `x`, given the precomputed `y = x^p`.

# Generic fallback (used when `x` and `p` are not both `Real`): plain formula.
_pow_grad_x(x, p, y) = (p * y / x)

# Real `x` and `p`: use the plain formula when `x` is nonzero or `p < 0`;
# otherwise (x == 0 and p >= 0, where `p * y / x` can come out as 0/0) return
# 1 for p == 1, Inf for 0 < p < 1, and 0 for every other such `p`.
# `ifelse` is kept throughout, so every candidate value is still evaluated.
function _pow_grad_x(x::Real, p::Real, y)
    value_at_zero = ifelse(isone(p), one(y),
        ifelse(0 < p < 1, oftype(y, Inf), zero(y)))
    use_formula = !iszero(x) | (p < 0)
    return ifelse(use_formula, p * y / x, value_at_zero)
end
220217

221218
@scalar_rule(
222219
rem(x, y),

test/rulesets/Base/fastmath_able.jl

Lines changed: 29 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -138,8 +138,6 @@ const FASTABLE_AST = quote
138138
end
139139

140140
@testset "$f(x::$T, y::$T) type check" for f in (/, +, -,\, hypot), T in (Float32, Float64)
141-
# ^ removed for now!
142-
143141
x, Δx, x̄ = 10rand(T, 3)
144142
y, Δy, ȳ = rand(T, 3)
145143
@assert T == typeof(f(x, y))
@@ -162,12 +160,14 @@ const FASTABLE_AST = quote
162160
end
163161

164162
@testset "^(x::$T, p::$S)" for T in (Float64, ComplexF64), S in (Float64, ComplexF64)
    # The exponent is drawn from `S` (not `T`), so that each (T, S) pair named
    # in the testset title is the combination actually exercised.
    test_frule(^, rand(T) + 3, rand(S) + 3)
    test_rrule(^, rand(T) + 3, rand(S) + 3)

    # When both x & p are Real, and !(isinteger(p)),
    # then x must be positive to avoid a DomainError
    T <: Real && S <: Real && continue
    # In other cases, we can test values near zero:

    test_frule(^, randn(T), rand(S))
    test_rrule(^, rand(T), rand(S))
end
@@ -177,77 +177,13 @@ const FASTABLE_AST = quote
177177
# test_rrule(^, randn(T) + 3, p ⊢ NoTangent())
178178
# end
179179

180-
# @testset "^(x::Float64, p::$S) near x=0, p=1,0,-1,-2" for S in (Int, Float64)
181-
# # x^2. Easy to get NaN here by mistake.
182-
# p = S(+2)
183-
# @test frule((1,1,1), ^, 0.0, p)[1] == 0 # value
184-
# @test_broken frule((1,1,1), ^, 0.0, p)[2] == 0 # gradient, forwards
185-
# @test rrule(^, 0.0, p)[1] == 0 # value
186-
# @test unthunk(rrule(^, 0.0, p)[2](1.0)[2]) == 0 # gradient, reverse
187-
188-
# # Identity function x^1, at zero
189-
# p = S(+1)
190-
# @test frule((1,1,1), ^, 0.0, p)[1] == 0
191-
# @test_broken frule((1,1,1), ^, 0.0, p)[2] == 1
192-
# @test rrule(^, 0.0, p)[1] == 0
193-
# @test unthunk(rrule(^, 0.0, p)[2](1.0)[2]) == 1
194-
195-
# # Trivial singularity: 0^0 == 1 in Julia
196-
# p = S(0)
197-
# @test_skip frule((1,1,1), ^, 0.0, p)[1] == (0.0)^0
198-
# @test_broken frule((1,1,1), ^, 0.0, p)[2] == 0
199-
# @test_broken unthunk(rrule(^, 0.0, p)[2](1.0)[3]) == 0.0
200-
201-
# # Odd power, 1/x
202-
# p = S(-1)
203-
# @test_skip frule((1,1,1), ^, 0.0, p)[1] == (0.0)^-1
204-
# @test_broken frule((1,1,1), ^, 0.0, p)[2] == -Inf
205-
# @test_skip rrule(^, 0.0, p)[1] == (0.0)^-1 == Inf
206-
# @test unthunk(rrule(^, 0.0, p)[2](1.0)[2]) == -Inf
207-
208-
# @test_skip frule((1,1,1), ^, -0.0, p)[1] == (-0.0)^-1
209-
# @test_broken frule((1,1,1), ^, -0.0, p)[2] == -Inf
210-
# @test_skip rrule(^, -0.0, p)[1] == (-0.0)^-1 == -Inf
211-
# @test unthunk(rrule(^, -0.0, p)[2](1.0)[2]) == -Inf
212-
213-
# # Even power, 1/x^2
214-
# p = S(-2)
215-
# @test_skip frule((1,1,1), ^, 0.0, p)[1] == (0.0)^-2
216-
# @test_broken frule((1,1,1), ^, 0.0, p)[2] == -Inf
217-
# @test_skip rrule(^, 0.0, p)[1] == (0.0)^-2 == Inf
218-
# @test unthunk(rrule(^, 0.0, p)[2](1.0)[2]) == -Inf
219-
220-
# @test_skip frule((1,1,1), ^, -0.0, p)[1] == (-0.0)^-2
221-
# @test_broken frule((1,1,1), ^, -0.0, p)[2] == +Inf
222-
# @test_skip rrule(^, -0.0, p)[1] == (-0.0)^-2 == Inf
223-
# @test unthunk(rrule(^, -0.0, p)[2](1.0)[2]) == +Inf
224-
# end
225-
226-
# T <: Real && @testset "discontinuity for ^(x::Real, n::Int) when x ≤ 0" begin
227-
# # finite differences doesn't work for x < 0, so we check manually
228-
# x = -rand(T) .- 3
229-
# y = 3
230-
# Δx = randn(T)
231-
# Δy = randn(T)
232-
# Δz = randn(T)
233-
234-
# @test frule((ZeroTangent(), Δx, Δy), ^, x, y)[2] ≈ Δx * y * x^(y - 1)
235-
# @test frule((ZeroTangent(), Δx, Δy), ^, zero(x), y)[2] ≈ 0
236-
# _, ∂x, ∂y = rrule(^, x, y)[2](Δz)
237-
# @test ∂x ≈ Δz * y * x^(y - 1)
238-
# @test ∂y ≈ 0
239-
# _, ∂x, ∂y = rrule(^, zero(x), y)[2](Δz)
240-
# @test ∂x ≈ 0
241-
# @test ∂y ≈ 0
242-
# end
243-
# end
244-
end
180+
# Tests for power functions, at values near to zero.
245181

246182
POWERGRADS = [ # (x,p) => (dx,dp)
247-
# some regular points, sanity checks
183+
# Some regular points, sanity checks
248184
(1.0, 2) => (2.0, 0.0),
249185
(2.0, 2) => (4.0, 2.772588722239781),
250-
# at x=0, gradients for x seem clear,
186+
# At x=0, gradients for x seem clear,
251187
# for p I've just written here what it gives
252188
(0.0, 2) => (0.0, NaN),
253189
(-0.0, 2) => (-0.0, NaN),
@@ -259,74 +195,60 @@ POWERGRADS = [ # (x,p) => (dx,dp)
259195
(-0.0, -1) => (-Inf, Inf),
260196
(0.0, -2) => (-Inf, -Inf),
261197
(-0.0, -2) => (Inf, -Inf),
262-
# non-integer powers
198+
# Non-integer powers:
263199
(0.0, 0.5) => (Inf, NaN),
264200
(0.0, 3.5) => (0.0, NaN),
265-
266201
]
267-
# Check `frule` / `rrule` of `^` against every entry of POWERGRADS.
# Mismatches are printed rather than failing, except for the one explicit
# `@test` on the sign of zero below.
for ((x,p), (gx, gp)) in POWERGRADS  # power ^
    y = x^p

    # Forward
    y_f = frule((1,1,1), ^, x, p)[1]
    isequal(y, y_f) || println("^ forward value for $x^$p: got $y_f, expected $y")

    # `frule` returns (value, tangent), so the gradients are element [2].
    gx_f = frule((0,1,0), ^, x, p)[2]
    gp_f = frule((0,0,1), ^, x, p)[2]
    # isequal(gx, gx_f) || println("^ forward `x` gradient for $x^$p: got $gx_f, expected $gx, maybe")
    # isequal(gp, gp_f) || println("^ forward `p` gradient for $x^$p: got $gp_f, expected $gp, maybe")

    # Reverse
    y_r = rrule(^, x, p)[1]
    isequal(y, y_r) || println("^ reverse value for $x^$p: got $y_r, expected $y")

    gx_r, gp_r = unthunk.(rrule(^, x, p)[2](1))[2:3]
    if x === -0.0 && p === 2
        @test 0.0 == gx_r  # POWERGRADS says -0.0
    else
        isequal(gx, gx_r) || println("^ reverse `x` gradient for $x^$p: got $gx_r, expected $gx")
    end
    isequal(gp, gp_r) || println("^ reverse `p` gradient for $x^$p: got $gp_r, expected $gp")
end
286-
# Check the `Base.literal_pow` rules against the POWERGRADS entries that have
# an `Int` power and a `Real` base. Mismatches are printed rather than failing.
for ((x,p), (gx, gp)) in POWERGRADS  # literal_pow
    p isa Int || continue
    x isa Real || continue

    y = x^p

    # Forward
    y_f = frule((1,1,1,1), Base.literal_pow, ^, x, Val(p))[1]
    isequal(y, y_f) || println("literal_pow forward value for $x^$p: got $y_f, expected $y")

    # `frule` returns (value, tangent), so the gradient is element [2].
    gx_f = frule((0,0,1,0), Base.literal_pow, ^, x, Val(p))[2]
    # isequal(gx, gx_f) || println("literal_pow forward `x` gradient for $x^$p: got $gx_f, expected $gx, maybe, y=$y")

    # Reverse
    y_r = rrule(Base.literal_pow, ^, x, Val(p))[1]
    isequal(y, y_r) || println("literal_pow reverse value for $x^$p: got $y_r, expected $y")

    gx_r = unthunk(rrule(Base.literal_pow, ^, x, Val(p))[2](1))[3]
    isequal(gx, gx_r) || println("literal_pow `x` gradient for $x^$p: got $gx_r, expected $gx")

    # @info "all" x y p gx_f gx_r
end
327250

328251

329-
330252
@testset "sign" begin
331253
@testset "real" begin
332254
@testset "at $x" for x in (-1.1, -1.1, 0.5, 100.0)

0 commit comments

Comments
 (0)