Skip to content

Commit 48a9d14

Browse files
committed
mostly delete literal_pow
1 parent 1b389c0 commit 48a9d14

File tree

3 files changed

+18
-81
lines changed

3 files changed

+18
-81
lines changed

src/rulesets/Base/base.jl

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -180,26 +180,23 @@ end
180180
@scalar_rule ceil(x) zero(x)
181181

182182
# `literal_pow`
183+
# This is mostly handled by AD; it's a micro-optimisation to provide a gradient for x*x*x
183184
# Note that rules for `^` are defined in the fastmath_able.jl
184185

185-
function frule((_, _, Δx, _), ::typeof(Base.literal_pow), ::typeof(^), x::Real, ::Val{p}) where p
186-
y = Base.literal_pow(^, x, Val(p))
187-
yox = Base.literal_pow(^, x, Val(p-1))
188-
return y, p * yox * Δx
186+
function frule((_, _, Δx, _), ::typeof(Base.literal_pow), ::typeof(^), x::Real, ::Val{2})
187+
return x * x, 2 * x * Δx
188+
end
189+
function frule((_, _, Δx, _), ::typeof(Base.literal_pow), ::typeof(^), x::Real, ::Val{3})
190+
x2 = x * x
191+
return x2 * x, 3 * x2 * Δx
189192
end
190-
frule((_, _, Δx, _), ::typeof(Base.literal_pow), ::typeof(^), x::Real, ::Val{0}) = x^0, zero(Δx)
191193

192-
function rrule(::typeof(Base.literal_pow), ::typeof(^), x::Real, ::Val{p}) where p
193-
y = Base.literal_pow(^, x, Val(p))
194-
@inline function literal_pow_pullback(dy)
195-
# Calling literal_pow a 2nd time is the easy way to get all the edge cases right.
196-
# It should be cheap up to p=4, which is the main use of literal powers, right?
197-
yox = Base.literal_pow(^, x, Val(p-1))
198-
return (NoTangent(), NoTangent(), ProjectTo(x)(p * yox * dy), NoTangent())
199-
end
200-
return y, literal_pow_pullback
194+
function rrule(::typeof(Base.literal_pow), ::typeof(^), x::Real, ::Val{2})
195+
@inline pow2_pullback(dy) = (NoTangent(), NoTangent(), ProjectTo(x)(2 * x * dy), NoTangent())
196+
return x * x, pow2_pullback
201197
end
202-
function rrule(::typeof(Base.literal_pow), ::typeof(^), x::Real, ::Val{0})
203-
literal_pow_zero_pullback(dy) = (NoTangent(), NoTangent(), ProjectTo(x)(zero(dy)), NoTangent())
204-
return x^0, literal_pow_zero_pullback
198+
function rrule(::typeof(Base.literal_pow), ::typeof(^), x::Real, ::Val{3})
199+
x2 = x * x
200+
@inline pow3_pullback(dy) = (NoTangent(), NoTangent(), ProjectTo(x)(3 * x2 * dy), NoTangent())
201+
return x2 * x, pow3_pullback
205202
end

test/rulesets/Base/base.jl

Lines changed: 4 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -188,46 +188,10 @@
188188
@test rrule(Base.depwarn, "message", :f) !== nothing
189189
end
190190

191-
@testset "literal_pow" begin
192-
test_frule(Base.literal_pow, ^, 3.5, Val(3))
193-
test_rrule(Base.literal_pow, ^, 3.5, Val(3))
194-
195-
@testset "regular: $x^$p" for x in [-1.5, 0.0, 3.5], p in [-3, -1, 0, 1, 3]
196-
x == 0 && p < 0 && continue
197-
test_frule(Base.literal_pow, ^, x, Val(p))
198-
test_rrule(Base.literal_pow, ^, x, Val(p))
199-
end
200-
201-
@testset "singularities: 0^0, 0^-1, 0^-2" begin
202-
# Trivial one: 0^0 == 1 in Julia
203-
@test frule((1,1,1,1), Base.literal_pow, ^, 0.0, Val(0)) == ((0.0)^0, 0)
204-
@test rrule(Base.literal_pow, ^, 0.0, Val(0))[2](1.0)[3] == 0.0
205-
206-
# Odd power, 1/x
207-
@test frule((1,1,1,1), Base.literal_pow, ^, 0.0, Val(-1)) == ((0.0)^-1, -Inf)
208-
@test rrule(Base.literal_pow, ^, 0.0, Val(-1))[1] == (0.0)^-1 == Inf
209-
@test rrule(Base.literal_pow, ^, 0.0, Val(-1))[2](1.0)[3] == -Inf
210-
211-
@test frule((1,1,1,1), Base.literal_pow, ^, -0.0, Val(-1)) == ((-0.0)^-1, -Inf)
212-
@test rrule(Base.literal_pow, ^, -0.0, Val(-1))[1] == (-0.0)^-1 == -Inf
213-
@test rrule(Base.literal_pow, ^, -0.0, Val(-1))[2](1.0)[3] == -Inf
214-
215-
# Even power, 1/x^2
216-
@test frule((1,1,1,1), Base.literal_pow, ^, 0.0, Val(-2)) == ((0.0)^-2, -Inf)
217-
@test rrule(Base.literal_pow, ^, 0.0, Val(-2))[1] == (0.0)^-2 == Inf
218-
@test rrule(Base.literal_pow, ^, 0.0, Val(-2))[2](1.0)[3] == -Inf
219-
220-
@test frule((1,1,1,1), Base.literal_pow, ^, -0.0, Val(-2)) == ((-0.0)^-2, +Inf)
221-
@test rrule(Base.literal_pow, ^, -0.0, Val(-2))[1] == (-0.0)^-2 == Inf
222-
@test rrule(Base.literal_pow, ^, -0.0, Val(-2))[2](1.0)[3] == +Inf
223-
224-
# Not singluar, but ^ messed these up: x^1 and x^2
225-
@test frule((1,1,1,1), Base.literal_pow, ^, 0.0, Val(2)) == (0.0, 0)
226-
@test rrule(Base.literal_pow, ^, 0.0, Val(2))[2](1.0)[3] == 0.0
227-
228-
@test frule((1,1,1,1), Base.literal_pow, ^, 0.0, Val(1)) == (0.0, 1)
229-
@test rrule(Base.literal_pow, ^, 0.0, Val(1))[2](1.0)[3] == 1.0
230-
end
191+
@testset "literal_pow: $x^$p" for x in [-1.5, 0.0, 3.5], p in [2, 3]
192+
x == 0 && p < 0 && continue
193+
test_frule(Base.literal_pow, ^, x, Val(p))
194+
test_rrule(Base.literal_pow, ^, x, Val(p))
231195
end
232196

233197
@testset "Float conversions" begin

test/rulesets/Base/fastmath_able.jl

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -172,13 +172,7 @@ const FASTABLE_AST = quote
172172
test_rrule(^, rand(T), rand(T))
173173
end
174174

175-
# @testset "^(x::$T, $p::Int)" for T in (Float64, ComplexF64), p in -2:2
176-
# test_frule(^, randn(T) + 3, p ⊢ NoTangent()) # this doesn't just skip p's tangent
177-
# test_rrule(^, randn(T) + 3, p ⊢ NoTangent())
178-
# end
179-
180175
# Tests for power functions, at values near to zero.
181-
182176
POWERGRADS = [ # (x,p) => (dx,dp)
183177
# Some regular points, as sanity checks:
184178
(1.0, 2) => (2.0, 0.0),
@@ -235,24 +229,6 @@ const FASTABLE_AST = quote
235229
end
236230
@test isequal(∂p, ∂p_rev) # || println("^ reverse `p` gradient for $x^$p: got $∂p_rev, expected $∂p")
237231
end
238-
239-
@testset "literal_pow $x ^ $p" for ((x,p), (∂x, ∂p)) in POWERGRADS
240-
y = Base.literal_pow(^, x, Val(p))
241-
242-
# Forward
243-
y_fwd = frule((1,1,1,1), Base.literal_pow, ^, x, Val(p))[1]
244-
@test y === y_fwd # || println("literal_pow forward value for $x^$p: got $y_fwd, expected $y")
245-
246-
∂x_fwd = frule((0,0,1,0), Base.literal_pow, ^, x, Val(p))[1]
247-
# isequal(∂x, ∂x_fwd) || println("literal_pow forward `x` gradient for $x^$p: got $∂x_fwd, expected $∂x, maybe, y=$y")
248-
249-
# Reverse
250-
y_rev = rrule(Base.literal_pow, ^, x, Val(p))[1]
251-
@test y === y_rev # || println("literal_pow reverse value for $x^$p: got $y_rev, expected $y")
252-
253-
∂x_rev = unthunk(rrule(Base.literal_pow, ^, x, Val(p))[2](1))[3]
254-
@test isequal(∂x, ∂x_rev) # || println("literal_pow `x` gradient for $x^$p: got $∂x_rev, expected $∂x")
255-
end
256232
end
257233

258234
@testset "sign" begin

0 commit comments

Comments
 (0)