Skip to content

Commit 35e1a6b

Browse files
committed
fixup literal_pow, plus tests
1 parent 415c8fb commit 35e1a6b

File tree

2 files changed

+65
-18
lines changed

2 files changed

+65
-18
lines changed

src/rulesets/Base/base.jl

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -179,21 +179,48 @@ end
179179
@scalar_rule floor(x) zero(x)
180180
@scalar_rule ceil(x) zero(x)
181181

182-
# note: rules for ^ are defined in the fastmath_able.jl
183-
function frule((_, _, Δx, _), ::typeof(Base.literal_pow), ::typeof(^), x::Real, pv::Val{p}) where p
184-
y = Base.literal_pow(^, x, pv)
185-
return y, ifelse(iszero(x), zero(y), p * y / x * Δx)
182+
# `literal_pow`
183+
# Note that rules for `^` are defined in the fastmath_able.jl
184+
185+
function frule((_, _, Δx, _), ::typeof(Base.literal_pow), ::typeof(^), x::Real, ::Val{p}) where p
186+
yox = Base.literal_pow(^, x, Val(p-1))
187+
if p < 0 && iseven(p)
188+
# When p<0 and x==0, using yox * x for the primal gives NaN instead of +-Inf
189+
y = ifelse(iszero(x), oftype(yox, Inf), yox * x)
190+
elseif p < 0
191+
y = ifelse(iszero(x), copysign(oftype(yox, Inf), x), yox * x)
192+
else
193+
y = yox * x
194+
end
195+
return y, p * yox * Δx
186196
end
187197
frule((_, _, Δx, _), ::typeof(Base.literal_pow), ::typeof(^), x::Real, ::Val{1}) = x^1, Δx
198+
frule((_, _, Δx, _), ::typeof(Base.literal_pow), ::typeof(^), x::Real, ::Val{0}) = x^0, zero(Δx)
188199

189-
function rrule(::typeof(Base.literal_pow), ::typeof(^), x::Real, pv::Val{p}) where p
190-
y = Base.literal_pow(^, x, pv)
191-
function literal_pow_pullback(dy)
192-
return NoTangent(), NoTangent(), ifelse(iszero(x), zero(y), p * y / x * dy), NoTangent()
200+
function rrule(::typeof(Base.literal_pow), ::typeof(^), x::Real, ::Val{p}) where p
201+
yox = Base.literal_pow(^, x, Val(p-1))
202+
project = ProjectTo(x)
203+
@inline function literal_pow_pullback(dy)
204+
return NoTangent(), NoTangent(), project(p * yox * dy), NoTangent()
205+
end
206+
if p < 0 && iseven(p)
207+
# When p<0 and x==0, using yox * x for the primal gives NaN instead of +-Inf
208+
y = ifelse(iszero(x), oftype(yox, Inf), yox * x)
209+
elseif p < 0
210+
y = ifelse(iszero(x), copysign(oftype(yox, Inf), x), yox * x)
211+
else
212+
y = yox * x
193213
end
194214
return y, literal_pow_pullback
195215
end
196-
function rrule(::typeof(Base.literal_pow), ::typeof(^), x::Real, pv::Val{1})
197-
literal_pow_one_pullback(dy) = NoTangent(), NoTangent(), dy, NoTangent()
216+
function rrule(::typeof(Base.literal_pow), ::typeof(^), x::Real, ::Val{1})
217+
project = ProjectTo(x)
218+
literal_pow_one_pullback(dy) = NoTangent(), NoTangent(), project(dy), NoTangent()
198219
return x^1, literal_pow_one_pullback
199220
end
221+
function rrule(::typeof(Base.literal_pow), ::typeof(^), x::Real, ::Val{0})
222+
# Since 0^0 == 1 == 0.001^0, this gradient should not be NaN at x==0
223+
project = ProjectTo(x)
224+
literal_pow_zero_pullback(dy) = NoTangent(), NoTangent(), project(zero(dy)), NoTangent()
225+
return x^0, literal_pow_zero_pullback
226+
end

test/rulesets/Base/base.jl

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -189,18 +189,38 @@
189189
end
190190

191191
@testset "literal_pow" begin
192-
# for real x and n, x must be >0
193192
test_frule(Base.literal_pow, ^, 3.5, Val(3))
194193
test_rrule(Base.literal_pow, ^, 3.5, Val(3))
195194

196-
test_frule(Base.literal_pow, ^, 0.0, Val(3))
197-
test_rrule(Base.literal_pow, ^, 0.0, Val(3))
198-
199-
test_frule(Base.literal_pow, ^, 3.5, Val(1))
200-
test_rrule(Base.literal_pow, ^, 3.5, Val(1))
195+
@testset "$x^$p" for x in [-1.5, 0.0, 3.5], p in [-3, -1, 0, 1, 3]
196+
x == 0 && p < 0 && continue
197+
test_frule(Base.literal_pow, ^, -1.5, Val(3))
198+
test_rrule(Base.literal_pow, ^, -1.5, Val(3))
199+
end
201200

202-
test_frule(Base.literal_pow, ^, 0.0, Val(1))
203-
test_rrule(Base.literal_pow, ^, 0.0, Val(1))
201+
@testset "singularities" begin
202+
# Trivial one: 0^0 == 1 in Julia
203+
@test frule((1,1,1,1), Base.literal_pow, ^, 0.0, Val(0)) == ((0.0)^0, 0)
204+
@test rrule(Base.literal_pow, ^, 0.0, Val(0))[2](1.0)[3] == 0.0
205+
206+
# Odd power, 1/x
207+
@test frule((1,1,1,1), Base.literal_pow, ^, 0.0, Val(-1)) == ((0.0)^-1, -Inf)
208+
@test rrule(Base.literal_pow, ^, 0.0, Val(-1))[1] == (0.0)^-1 == Inf
209+
@test rrule(Base.literal_pow, ^, 0.0, Val(-1))[2](1.0)[3] == -Inf
210+
211+
@test frule((1,1,1,1), Base.literal_pow, ^, -0.0, Val(-1)) == ((-0.0)^-1, -Inf)
212+
@test rrule(Base.literal_pow, ^, -0.0, Val(-1))[1] == (-0.0)^-1 == -Inf
213+
@test rrule(Base.literal_pow, ^, -0.0, Val(-1))[2](1.0)[3] == -Inf
214+
215+
# Even power, 1/x^2
216+
@test frule((1,1,1,1), Base.literal_pow, ^, 0.0, Val(-2)) == ((0.0)^-2, -Inf)
217+
@test rrule(Base.literal_pow, ^, 0.0, Val(-2))[1] == (0.0)^-2 == Inf
218+
@test rrule(Base.literal_pow, ^, 0.0, Val(-2))[2](1.0)[3] == -Inf
219+
220+
@test frule((1,1,1,1), Base.literal_pow, ^, -0.0, Val(-2)) == ((-0.0)^-2, +Inf)
221+
@test rrule(Base.literal_pow, ^, -0.0, Val(-2))[1] == (-0.0)^-2 == Inf
222+
@test rrule(Base.literal_pow, ^, -0.0, Val(-2))[2](1.0)[3] == +Inf
223+
end
204224
end
205225

206226
@testset "Float conversions" begin

0 commit comments

Comments
 (0)