Skip to content

Commit 2e6491c

Browse files
authored
Merge pull request #503 from JuliaDiff/ox/onlyfast
copysign and sincospi are not fastmath-able
2 parents b4dee46 + c3f06a8 commit 2e6491c

File tree

5 files changed

+30
-25
lines changed

5 files changed

+30
-25
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "1.8.0"
3+
version = "1.8.1"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/rulesets/Base/base.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# See also fastmath_able.jl for where rules are defined simple base functions
22
# that also have FastMath versions.
33

4+
@scalar_rule copysign(y, x) (ifelse(signbit(x)!=signbit(y), -one(y), +one(y)), NoTangent())
5+
46
@scalar_rule one(x) zero(x)
57
@scalar_rule zero(x) zero(x)
68
@scalar_rule transpose(x) true
@@ -145,6 +147,11 @@ end
145147

146148
@scalar_rule sinc(x) cosc(x)
147149

150+
# the position of the minus sign below warrants the correct type for π
151+
if VERSION v"1.6"
152+
@scalar_rule sincospi(x) @setup((sinpix, cospix) = Ω) (π * cospix) (π * (-sinpix))
153+
end
154+
148155
@scalar_rule(
149156
clamp(x, low, high),
150157
@setup(

src/rulesets/Base/fastmath_able.jl

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,6 @@ let
4848
# Trig-Multivariate
4949
@scalar_rule atan(y, x) @setup(u = x ^ 2 + y ^ 2) (x / u, -y / u)
5050
@scalar_rule sincos(x) @setup((sinx, cosx) = Ω) cosx -sinx
51-
# the position of the minus sign below warrants the correct type for π
52-
if VERSION v"1.6"
53-
@scalar_rule sincospi(x) @setup((sinpix, cospix) = Ω) (π * cospix) (π * (-sinpix))
54-
end
5551

5652
# exponents
5753
@scalar_rule cbrt(x) inv(3 * Ω ^ 2)
@@ -184,8 +180,6 @@ let
184180
@scalar_rule max(x, y) @setup(gt = x > y) (gt, !gt)
185181
@scalar_rule min(x, y) @setup(gt = x > y) (!gt, gt)
186182

187-
@scalar_rule copysign(y, x) (ifelse(signbit(x)!=signbit(y), -one(y), +one(y)), NoTangent())
188-
189183
# Unary functions
190184
@scalar_rule +x true
191185
@scalar_rule -x -1
@@ -233,6 +227,9 @@ let
233227
fast_ast = Base.FastMath.make_fastmath(fastable_ast)
234228

235229
# Guard against mistakenly defining something as fast-able when it isn't.
230+
# NOTE: this check is not infallible, it will be tricked if a function itself is not
231+
# fastmath_able but it's pullback uses something that is. So manual check should also be
232+
# done.
236233
non_transformed_definitions = intersect(fastable_ast.args, fast_ast.args)
237234
filter!(expr->!(expr isa LineNumberNode), non_transformed_definitions)
238235
if !isempty(non_transformed_definitions)

test/rulesets/Base/base.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,14 @@
11
@testset "base" begin
2+
@testset "copysign" begin
3+
# don't go too close to zero as the numerics may jump over it yielding wrong results
4+
@testset "at $y" for y in (-1.1, 0.1, 100.0)
5+
@testset "at $x" for x in (-1.1, -0.1, 33.0)
6+
test_frule(copysign, y, x)
7+
test_rrule(copysign, y, x)
8+
end
9+
end
10+
end
11+
212
@testset "Trig" begin
313
@testset "Basics" for x = (Float64(π)-0.01, Complex(π, π/2))
414
test_scalar(sec, x)
@@ -47,6 +57,15 @@
4757
@testset "sinc" for x = (0.0, 0.434, Complex(0.434, 0.25))
4858
test_scalar(sinc, x)
4959
end
60+
61+
if VERSION v"1.6"
62+
@testset "sincospi" for T in (Float64, ComplexF64)
63+
Δz = Tangent{Tuple{T,T}}(randn(T), randn(T))
64+
65+
test_frule(sincospi, randn(T))
66+
test_rrule(sincospi, randn(T); output_tangent=Δz)
67+
end
68+
end
5069
end # Trig
5170

5271
@testset "Angles" begin

test/rulesets/Base/fastmath_able.jl

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -64,14 +64,6 @@ const FASTABLE_AST = quote
6464
test_frule(sincos, randn(T))
6565
test_rrule(sincos, randn(T); output_tangent=Δz)
6666
end
67-
if VERSION v"1.6"
68-
@testset "sincospi(x::$T)" for T in (Float64, ComplexF64)
69-
Δz = Tangent{Tuple{T,T}}(randn(T), randn(T))
70-
71-
test_frule(sincospi, randn(T))
72-
test_rrule(sincospi, randn(T); output_tangent=Δz)
73-
end
74-
end
7567
end
7668
end
7769

@@ -192,16 +184,6 @@ const FASTABLE_AST = quote
192184
end
193185
end
194186

195-
@testset "copysign" begin
196-
# don't go too close to zero as the numerics may jump over it yielding wrong results
197-
@testset "at $y" for y in (-1.1, 0.1, 100.0)
198-
@testset "at $x" for x in (-1.1, -0.1, 33.0)
199-
test_frule(copysign, y, x)
200-
test_rrule(copysign, y, x)
201-
end
202-
end
203-
end
204-
205187
@testset "sign" begin
206188
@testset "real" begin
207189
@testset "at $x" for x in (-1.1, -1.1, 0.5, 100.0)

0 commit comments

Comments
 (0)