Skip to content

Commit 415c8fb

Browse files
simeonschaubmcabbott
authored andcommitted
fix some issues with rules for ^
should fix JuliaDiff/Diffractor.jl#26
1 parent a130b8f commit 415c8fb

File tree

4 files changed

+35
-8
lines changed

4 files changed

+35
-8
lines changed

src/rulesets/Base/base.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -182,11 +182,18 @@ end
182182
# note: rules for ^ are defined in the fastmath_able.jl
183183
function frule((_, _, Δx, _), ::typeof(Base.literal_pow), ::typeof(^), x::Real, pv::Val{p}) where p
184184
y = Base.literal_pow(^, x, pv)
185-
return y, (p * y / x * Δx)
185+
return y, ifelse(iszero(x), zero(y), p * y / x * Δx)
186186
end
187+
frule((_, _, Δx, _), ::typeof(Base.literal_pow), ::typeof(^), x::Real, ::Val{1}) = x^1, Δx
187188

188189
function rrule(::typeof(Base.literal_pow), ::typeof(^), x::Real, pv::Val{p}) where p
189190
y = Base.literal_pow(^, x, pv)
190-
literal_pow_pullback(dy) = NoTangent(), NoTangent(), (p * y / x * dy), NoTangent()
191+
function literal_pow_pullback(dy)
192+
return NoTangent(), NoTangent(), ifelse(iszero(x), zero(y), p * y / x * dy), NoTangent()
193+
end
191194
return y, literal_pow_pullback
192195
end
196+
function rrule(::typeof(Base.literal_pow), ::typeof(^), x::Real, pv::Val{1})
197+
literal_pow_one_pullback(dy) = NoTangent(), NoTangent(), dy, NoTangent()
198+
return x^1, literal_pow_one_pullback
199+
end

src/rulesets/Base/fastmath_able.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -164,14 +164,16 @@ let
164164
@scalar_rule x - y (true, -1)
165165
@scalar_rule x / y (one(x) / y, -/ y))
166166
#log(complex(x)) is required so it gives correct complex answer for x<0
167-
@scalar_rule(x ^ y,
168-
(ifelse(iszero(x), zero(Ω), y * Ω / x), Ω * log(complex(x))),
169-
)
167+
@scalar_rule(x ^ y, (
168+
ifelse(iszero(x), ifelse(isone(y), one(Ω), zero(Ω)), y * Ω / x),
169+
Ω * log(complex(x)),
170+
))
170171
# x^y for x < 0 errors when y is not an integer, but then derivative wrt y
171172
# is undefined, so we adopt subgradient convention and set derivative to 0.
172-
@scalar_rule(x::Real ^ y::Real,
173-
(ifelse(iszero(x), zero(Ω), y * Ω / x), Ω * log(oftype(Ω, ifelse(x 0, one(x), x)))),
174-
)
173+
@scalar_rule(x::Real ^ y::Real, (
174+
ifelse(iszero(x), ifelse(isone(y), one(Ω), zero(Ω)), y * Ω / x),
175+
Ω * log(oftype(Ω, ifelse(x 0, one(x), x))),
176+
))
175177
@scalar_rule(
176178
rem(x, y),
177179
@setup((u, nan) = promote(x / y, NaN16), isint = isinteger(x / y)),

test/rulesets/Base/base.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,15 @@
192192
# for real x and n, x must be >0
193193
test_frule(Base.literal_pow, ^, 3.5, Val(3))
194194
test_rrule(Base.literal_pow, ^, 3.5, Val(3))
195+
196+
test_frule(Base.literal_pow, ^, 0.0, Val(3))
197+
test_rrule(Base.literal_pow, ^, 0.0, Val(3))
198+
199+
test_frule(Base.literal_pow, ^, 3.5, Val(1))
200+
test_rrule(Base.literal_pow, ^, 3.5, Val(1))
201+
202+
test_frule(Base.literal_pow, ^, 0.0, Val(1))
203+
test_rrule(Base.literal_pow, ^, 0.0, Val(1))
195204
end
196205

197206
@testset "Float conversions" begin

test/rulesets/Base/fastmath_able.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,15 @@ const FASTABLE_AST = quote
182182
@test ∂y 0
183183
end
184184
end
185+
186+
@testset "edge cases with ^" begin
187+
# FIXME
188+
@test_skip test_frule(^, 0.0, rand() + 3 NoTangent(); fdm=forward_fdm(5,1))
189+
test_rrule(^, 0.0, rand() + 3; fdm=forward_fdm(5,1))
190+
191+
test_frule(^, 0.0, 1.0 NoTangent(); fdm=forward_fdm(5,1))
192+
test_rrule(^, 0.0, 1.0; fdm=forward_fdm(5,1))
193+
end
185194
end
186195

187196
@testset "sign" begin

0 commit comments

Comments
 (0)