Skip to content

Commit ab07ced

Browse files
authored
Merge branch 'master' into mz/optout
2 parents ef4bfed + c637457 commit ab07ced

File tree

4 files changed

+160
-48
lines changed

4 files changed

+160
-48
lines changed

src/rulesets/Base/base.jl

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -179,14 +179,24 @@ end
179179
@scalar_rule floor(x) zero(x)
180180
@scalar_rule ceil(x) zero(x)
181181

182-
# note: rules for ^ are defined in the fastmath_able.jl
183-
function frule((_, _, Δx, _), ::typeof(Base.literal_pow), ::typeof(^), x::Real, pv::Val{p}) where p
184-
y = Base.literal_pow(^, x, pv)
185-
return y, (p * y / x * Δx)
182+
# `literal_pow`
183+
# This is mostly handled by AD; it's a micro-optimisation to provide a gradient for x*x*x
184+
# Note that rules for `^` are defined in the fastmath_able.jl
185+
186+
function frule((_, _, Δx, _), ::typeof(Base.literal_pow), ::typeof(^), x::Real, ::Val{2})
187+
return x * x, 2 * x * Δx
188+
end
189+
function frule((_, _, Δx, _), ::typeof(Base.literal_pow), ::typeof(^), x::Real, ::Val{3})
190+
x2 = x * x
191+
return x2 * x, 3 * x2 * Δx
186192
end
187193

188-
function rrule(::typeof(Base.literal_pow), ::typeof(^), x::Real, pv::Val{p}) where p
189-
y = Base.literal_pow(^, x, pv)
190-
literal_pow_pullback(dy) = NoTangent(), NoTangent(), (p * y / x * dy), NoTangent()
191-
return y, literal_pow_pullback
194+
function rrule(::typeof(Base.literal_pow), ::typeof(^), x::Real, ::Val{2})
195+
square_pullback(dy) = (NoTangent(), NoTangent(), ProjectTo(x)(2 * x * dy), NoTangent())
196+
return x * x, square_pullback
197+
end
198+
function rrule(::typeof(Base.literal_pow), ::typeof(^), x::Real, ::Val{3})
199+
x2 = x * x
200+
cube_pullback(dy) = (NoTangent(), NoTangent(), ProjectTo(x)(3 * x2 * dy), NoTangent())
201+
return x2 * x, cube_pullback
192202
end

src/rulesets/Base/fastmath_able.jl

Lines changed: 66 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ let
5252
# exponents
5353
@scalar_rule cbrt(x) inv(3 * Ω ^ 2)
5454
@scalar_rule inv(x) -(Ω ^ 2)
55-
@scalar_rule sqrt(x) inv(2Ω)
55+
@scalar_rule sqrt(x) inv(2Ω) # gradient +Inf at x==0
5656
@scalar_rule exp(x) Ω
5757
@scalar_rule exp10(x) Ω * log(oftype(x, 10))
5858
@scalar_rule exp2(x) Ω * log(oftype(x, 2))
@@ -137,8 +137,7 @@ let
137137

138138
# Binary functions
139139

140-
# `hypot`
141-
140+
## `hypot`
142141
function frule(
143142
(_, Δx, Δy),
144143
::typeof(hypot),
@@ -163,29 +162,53 @@ let
163162
@scalar_rule x + y (true, true)
164163
@scalar_rule x - y (true, -1)
165164
@scalar_rule x / y (one(x) / y, -(Ω / y))
166-
#log(complex(x)) is required so it gives correct complex answer for x<0
167-
@scalar_rule(x ^ y,
168-
(ifelse(iszero(x), zero(Ω), y * Ω / x), Ω * log(complex(x))),
169-
)
170-
# x^y for x < 0 errors when y is not an integer, but then derivative wrt y
171-
# is undefined, so we adopt subgradient convention and set derivative to 0.
172-
@scalar_rule(x::Real ^ y::Real,
173-
(ifelse(iszero(x), zero(Ω), y * Ω / x), Ω * log(oftype(Ω, ifelse(x ≤ 0, one(x), x)))),
174-
)
165+
166+
## power
167+
# literal_pow is in base.jl
168+
function frule((_, Δx, Δp), ::typeof(^), x::Number, p::Number)
169+
y = x ^ p
170+
_dx = _pow_grad_x(x, p, float(y))
171+
if iszero(Δp)
172+
# Treat this as a strong zero, to avoid NaN, and save the cost of log
173+
return y, _dx * Δx
174+
else
175+
# This may do real(log(complex(...))) which matches ProjectTo in rrule
176+
_dp = _pow_grad_p(x, p, float(y))
177+
return y, muladd(_dp, Δp, _dx * Δx)
178+
end
179+
end
180+
181+
function rrule(::typeof(^), x::Number, p::Number)
182+
y = x^p
183+
project_x = ProjectTo(x)
184+
project_p = ProjectTo(p)
185+
function power_pullback(dy)
186+
_dx = _pow_grad_x(x, p, float(y))
187+
return (
188+
NoTangent(),
189+
project_x(conj(_dx) * dy),
190+
# _pow_grad_p contains log, perhaps worth thunking:
191+
@thunk project_p(conj(_pow_grad_p(x, p, float(y))) * dy)
192+
)
193+
end
194+
return y, power_pullback
195+
end
196+
197+
## `rem`
175198
@scalar_rule(
176199
rem(x, y),
177200
@setup((u, nan) = promote(x / y, NaN16), isint = isinteger(x / y)),
178201
(ifelse(isint, nan, one(u)), ifelse(isint, nan, -trunc(u))),
179202
)
203+
## `min`, `max`
180204
@scalar_rule max(x, y) @setup(gt = x > y) (gt, !gt)
181205
@scalar_rule min(x, y) @setup(gt = x > y) (!gt, gt)
182206

183207
# Unary functions
184208
@scalar_rule +x true
185209
@scalar_rule -x -1
186210

187-
# `sign`
188-
211+
## `sign`
189212
function frule((_, Δx), ::typeof(sign), x)
190213
n = ifelse(iszero(x), one(real(x)), abs(x))
191214
Ω = x isa Real ? sign(x) : x / n
@@ -237,9 +260,38 @@ let
237260
"Non-FastMath compatible rules defined in fastmath_able.jl. \n Definitions:\n" *
238261
join(non_transformed_definitions, "\n")
239262
)
263+
# This error() may not play well with Revise. But a warning @error does:
264+
# @error "Non-FastMath compatible rules defined in fastmath_able.jl." non_transformed_definitions
240265
end
241266

242267
eval(fast_ast)
243268
eval(fastable_ast) # Get original definitions
244269
# we do this second so it overwrites anything we included by mistake in the fastable
245270
end
271+
272+
## power
273+
# These functions need to be defined outside the eval() block.
274+
# The special cases they aim to hit are in POWERGRADS in tests.
275+
_pow_grad_x(x, p, y) = (p * y / x)
276+
function _pow_grad_x(x::Real, p::Real, y)
277+
return if !iszero(x) || p < 0
278+
p * y / x
279+
elseif isone(p)
280+
one(y)
281+
elseif iszero(p) || p > 1
282+
zero(y)
283+
else
284+
oftype(y, Inf)
285+
end
286+
end
287+
288+
_pow_grad_p(x, p, y) = y * log(complex(x))
289+
function _pow_grad_p(x::Real, p::Real, y)
290+
return if !iszero(x)
291+
y * real(log(complex(x)))
292+
elseif p > 0
293+
zero(y)
294+
else
295+
oftype(y, NaN)
296+
end
297+
end

test/rulesets/Base/base.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -188,10 +188,10 @@
188188
@test rrule(Base.depwarn, "message", :f) !== nothing
189189
end
190190

191-
@testset "literal_pow" begin
192-
# for real x and n, x must be >0
193-
test_frule(Base.literal_pow, ^, 3.5, Val(3))
194-
test_rrule(Base.literal_pow, ^, 3.5, Val(3))
191+
@testset "literal_pow: $x^$p" for x in [-1.5, 0.0, 3.5], p in [2, 3]
192+
x == 0 && p < 0 && continue
193+
test_frule(Base.literal_pow, ^, x, Val(p))
194+
test_rrule(Base.literal_pow, ^, x, Val(p))
195195
end
196196

197197
@testset "Float conversions" begin

test/rulesets/Base/fastmath_able.jl

Lines changed: 72 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ const FASTABLE_AST = quote
137137
test_rrule(f, 10rand(T), rand(T))
138138
end
139139

140-
@testset "$f(x::$T, y::$T) type check" for f in (/, +, -,\, hypot, ^), T in (Float32, Float64)
140+
@testset "$f(x::$T, y::$T) type check" for f in (/, +, -,\, hypot), T in (Float32, Float64)
141141
x, Δx, x̄ = 10rand(T, 3)
142142
y, Δy, ȳ = rand(T, 3)
143143
@assert T == typeof(f(x, y))
@@ -159,28 +159,78 @@ const FASTABLE_AST = quote
159159
end
160160
end
161161

162-
@testset "^(x::$T, n::$T)" for T in (Float64, ComplexF64)
163-
# for real x and n, x must be >0
164-
test_frule(^, rand(T) + 3, rand(T) + 3)
165-
test_rrule(^, rand(T) + 3, rand(T) + 3)
166-
167-
T <: Real && @testset "discontinuity for ^(x::Real, n::Int) when x ≤ 0" begin
168-
# finite differences doesn't work for x < 0, so we check manually
169-
x = -rand(T) .- 3
170-
y = 3
171-
Δx = randn(T)
172-
Δy = randn(T)
173-
Δz = randn(T)
174-
175-
@test frule((ZeroTangent(), Δx, Δy), ^, x, y)[2] ≈ Δx * y * x^(y - 1)
176-
@test frule((ZeroTangent(), Δx, Δy), ^, zero(x), y)[2] ≈ 0
177-
_, ∂x, ∂y = rrule(^, x, y)[2](Δz)
178-
@test ∂x ≈ Δz * y * x^(y - 1)
179-
@test ∂y ≈ 0
180-
_, ∂x, ∂y = rrule(^, zero(x), y)[2](Δz)
181-
@test ∂x ≈ 0
182-
@test ∂y ≈ 0
162+
@testset "^(x::$T, p::$S)" for T in (Float64, ComplexF64), S in (Float64, ComplexF64)
163+
test_frule(^, rand(T) + 3, rand(S) + 3)
164+
test_rrule(^, rand(T) + 3, rand(S) + 3)
165+
166+
# When both x & p are Real, and !(isinteger(p)),
167+
# then x must be positive to avoid a DomainError
168+
T <: Real && S <: Real && continue
169+
# In other cases, we can test values near zero:
170+
171+
test_frule(^, randn(T), rand(S))
172+
test_rrule(^, rand(T), rand(S))
173+
end
174+
175+
# Tests for power functions, at values near to zero.
176+
POWERGRADS = [ # (x,p) => (dx,dp)
177+
# Some regular points, as sanity checks:
178+
(1.0, 2) => (2.0, 0.0),
179+
(2.0, 2) => (4.0, 2.772588722239781),
180+
# At x=0, gradients for x seem clear,
181+
# for p less certain what's best.
182+
(0.0, 2) => (0.0, 0.0),
183+
(-0.0, 2) => (0.0, 0.0), # probably (-0.0, 0.0) would be ideal
184+
(0.0, 1) => (1.0, 0.0),
185+
(-0.0, 1) => (1.0, 0.0),
186+
(0.0, 0) => (0.0, NaN),
187+
(-0.0, 0) => (0.0, NaN),
188+
(0.0, -1) => (-Inf, NaN),
189+
(-0.0, -1) => (-Inf, NaN),
190+
(0.0, -2) => (-Inf, NaN),
191+
(-0.0, -2) => (Inf, NaN),
192+
# Integer x & p, check no InexactErrors
193+
(0, 2) => (0.0, 0.0),
194+
(0, 1) => (1.0, 0.0),
195+
(0, 0) => (0.0, NaN),
196+
(0, -1) => (-Inf, NaN),
197+
(0, -2) => (-Inf, NaN),
198+
# Non-integer powers:
199+
(0.0, 0.5) => (Inf, 0.0),
200+
(0.0, 3.5) => (0.0, 0.0),
201+
(0.0, -1.5) => (-Inf, NaN),
202+
]
203+
204+
@testset "$x ^ $p" for ((x,p), (∂x, ∂p)) in POWERGRADS
205+
if x isa Integer && p isa Integer && p < 0
206+
@test_throws DomainError x^p
207+
continue
183208
end
209+
y = x^p
210+
211+
# Forward
212+
y_fwd = frule((1,1,1), ^, x, p)[1]
213+
@test isequal(y, y_fwd)
214+
215+
∂x_fwd = frule((0,1,0), ^, x, p)[2]
216+
∂p_fwd = frule((0,0,1), ^, x, p)[2]
217+
@test isequal(∂x, ∂x_fwd)
218+
if x===0.0 && p===0.5
219+
@test_broken isequal(∂p, ∂p_fwd)
220+
else
221+
@test isequal(∂p, ∂p_fwd)
222+
end
223+
224+
∂x_fwd = frule((0,1,ZeroTangent()), ^, x, p)[2] # easier, strong zero
225+
@test isequal(∂x, ∂x_fwd)
226+
227+
# Reverse
228+
y_rev = rrule(^, x, p)[1]
229+
@test isequal(y, y_rev)
230+
231+
∂x_rev, ∂p_rev = unthunk.(rrule(^, x, p)[2](1))[2:3]
232+
@test isequal(∂x, ∂x_rev)
233+
@test isequal(∂p, ∂p_rev)
184234
end
185235
end
186236

0 commit comments

Comments
 (0)