Skip to content

Commit c6e84c1

Browse files
authored
Merge pull request JuliaDiff#87 from JuliaDiff/ox/testmore
Add FiniteDifferences based scalar rule tests
2 parents 4aa2c8b + be46c03 commit c6e84c1

File tree

6 files changed

+188
-91
lines changed

6 files changed

+188
-91
lines changed

src/rulesets/Base/base.jl

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,54 +6,70 @@
66
@scalar_rule(log2(x), inv(x) / log(oftype(x, 2)))
77
@scalar_rule(log1p(x), inv(x + 1))
88
@scalar_rule(expm1(x), exp(x))
9+
910
@scalar_rule(sin(x), cos(x))
1011
@scalar_rule(cos(x), -sin(x))
1112
@scalar_rule(sinpi(x), π * cospi(x))
1213
@scalar_rule(cospi(x), -π * sinpi(x))
1314
@scalar_rule(sind(x), (π / oftype(x, 180)) * cosd(x))
1415
@scalar_rule(cosd(x), -/ oftype(x, 180)) * sind(x))
16+
1517
@scalar_rule(asin(x), inv(sqrt(1 - x^2)))
1618
@scalar_rule(acos(x), -inv(sqrt(1 - x^2)))
1719
@scalar_rule(atan(x), inv(1 + x^2))
18-
@scalar_rule(asec(x), inv(abs(x) * sqrt(x^2 - 1)))
19-
@scalar_rule(acsc(x), -inv(abs(x) * sqrt(x^2 - 1)))
20+
@scalar_rule(asec(x::Real), inv(abs(x) * sqrt(x^2 - 1)))
21+
@scalar_rule(asec(x), inv(x^2 * sqrt(1 - x^-2)))
22+
@scalar_rule(acsc(x::Real), -inv(abs(x) * sqrt(x^2 - 1)))
23+
@scalar_rule(acsc(x), -inv(x^2 * sqrt(1 - x^-2)))
2024
@scalar_rule(acot(x), -inv(1 + x^2))
25+
2126
@scalar_rule(asind(x), oftype(x, 180) / π / sqrt(1 - x^2))
2227
@scalar_rule(acosd(x), -oftype(x, 180) / π / sqrt(1 - x^2))
2328
@scalar_rule(atand(x), oftype(x, 180) / π / (1 + x^2))
24-
@scalar_rule(asecd(x), oftype(x, 180) / π / abs(x) / sqrt(x^2 - 1))
25-
@scalar_rule(acscd(x), -oftype(x, 180) / π / abs(x) / sqrt(x^2 - 1))
29+
@scalar_rule(asecd(x::Real), oftype(x, 180) / π / abs(x) / sqrt(x^2 - 1))
30+
@scalar_rule(asecd(x), oftype(x, 180) / π / x^2 / sqrt(1 - x^-2))
31+
@scalar_rule(acscd(x::Real), -oftype(x, 180) / π / abs(x) / sqrt(x^2 - 1))
32+
@scalar_rule(acscd(x), -oftype(x, 180) / π / x^2 / sqrt(1 - x^-2))
2633
@scalar_rule(acotd(x), -oftype(x, 180) / π / (1 + x^2))
34+
2735
@scalar_rule(sinh(x), cosh(x))
2836
@scalar_rule(cosh(x), sinh(x))
2937
@scalar_rule(tanh(x), sech(x)^2)
3038
@scalar_rule(coth(x), -(csch(x)^2))
39+
3140
@scalar_rule(asinh(x), inv(sqrt(x^2 + 1)))
3241
@scalar_rule(acosh(x), inv(sqrt(x^2 - 1)))
3342
@scalar_rule(atanh(x), inv(1 - x^2))
3443
@scalar_rule(asech(x), -inv(x * sqrt(1 - x^2)))
35-
@scalar_rule(acsch(x), -inv(abs(x) * sqrt(1 + x^2)))
44+
@scalar_rule(acsch(x::Real), -inv(abs(x) * sqrt(1 + x^2)))
45+
@scalar_rule(acsch(x), -inv(x^2 * sqrt(1 + x^-2)))
3646
@scalar_rule(acoth(x), inv(1 - x^2))
47+
3748
@scalar_rule(deg2rad(x), π / oftype(x, 180))
3849
@scalar_rule(rad2deg(x), oftype(x, 180) / π)
50+
3951
@scalar_rule(conj(x), Wirtinger(Zero(), One()))
4052
@scalar_rule(adjoint(x), Wirtinger(Zero(), One()))
4153
@scalar_rule(transpose(x), One())
54+
4255
@scalar_rule(abs(x), sign(x))
4356
@scalar_rule(rem2pi(x, r::RoundingMode), (One(), DNE()))
57+
4458
@scalar_rule(+(x), One())
4559
@scalar_rule(-(x), -1)
4660
@scalar_rule(+(x, y), (One(), One()))
4761
@scalar_rule(-(x, y), (One(), -1))
4862
@scalar_rule(/(x, y), (inv(y), -(x / y / y)))
4963
@scalar_rule(\(x, y), (-(y / x / x), inv(x)))
5064
@scalar_rule(^(x, y), (y * x^(y - 1), Ω * log(x)))
51-
@scalar_rule(inv(x), -abs2(Ω))
65+
66+
@scalar_rule(inv(x), -Ω^2)
5267
@scalar_rule(sqrt(x), inv(2 * Ω))
5368
@scalar_rule(cbrt(x), inv(3 * Ω^2))
5469
@scalar_rule(exp(x), Ω)
5570
@scalar_rule(exp2(x), Ω * log(oftype(x, 2)))
5671
@scalar_rule(exp10(x), Ω * log(oftype(x, 10)))
72+
5773
@scalar_rule(tan(x), 1 + Ω^2)
5874
@scalar_rule(sec(x), Ω * tan(x))
5975
@scalar_rule(csc(x), -Ω * cot(x))
@@ -64,9 +80,11 @@
6480
@scalar_rule(cotd(x), -/ oftype(x, 180)) * (1 + Ω^2))
6581
@scalar_rule(sech(x), -tanh(x) * Ω)
6682
@scalar_rule(csch(x), -coth(x) * Ω)
83+
6784
@scalar_rule(hypot(x, y), (x / Ω, y / Ω))
6885
@scalar_rule(sincos(x), @setup((sinx, cosx) = Ω), cosx, -sinx)
6986
@scalar_rule(atan(x, y), @setup(u = x^2 + y^2), (y / u, -x / u))
87+
7088
@scalar_rule(max(x, y), @setup(gt = x > y), (gt, !gt))
7189
@scalar_rule(min(x, y), @setup(gt = x > y), (!gt, gt))
7290
@scalar_rule(mod(x, y), @setup((u, nan) = promote(x / y, NaN16)),

test/rulesets/Base/base.jl

Lines changed: 91 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -1,66 +1,55 @@
1-
function test_scalar(f, f′, xs...)
2-
for r = (rrule, frule)
3-
rr = r(f, xs...)
4-
@test rr !== nothing
5-
fx, ∂x = rr
6-
@test fx == f(xs...)
7-
@test ∂x(1) f′(xs...) atol=1e-5
8-
end
9-
end
10-
111
@testset "base" begin
122
@testset "Trig" begin
13-
@testset "Basics" for x = (Float64(π), Complex(π, π/2))
14-
test_scalar(sin, cos, x)
15-
test_scalar(cos, x -> -sin(x), x)
16-
test_scalar(tan, x -> 1 + tan(x)^2, x)
17-
test_scalar(sec, x -> sec(x) * tan(x), x)
18-
test_scalar(csc, x -> -csc(x) * cot(x), x)
19-
test_scalar(cot, x -> -1 - cot(x)^2, x)
20-
test_scalar(sinpi, x -> π * cospi(x), x)
21-
test_scalar(cospi, x -> -π * sinpi(x), x)
3+
@testset "Basics" for x = (Float64(π)-0.01, Complex(π, π/2))
4+
test_scalar(sin, x)
5+
test_scalar(cos, x)
6+
test_scalar(tan, x)
7+
test_scalar(sec, x)
8+
test_scalar(csc, x)
9+
test_scalar(cot, x)
10+
test_scalar(sinpi, x)
11+
test_scalar(cospi, x)
2212
end
23-
@testset "Hyperbolic" for x = (Float64(π), Complex(π, π/2))
24-
test_scalar(sinh, cosh, x)
25-
test_scalar(cosh, sinh, x)
26-
test_scalar(tanh, x -> sech(x)^2, x)
27-
test_scalar(sech, x -> -tanh(x) * sech(x), x)
28-
test_scalar(csch, x -> -coth(x) * csch(x), x)
29-
test_scalar(coth, x -> -csch(x)^2, x)
13+
@testset "Hyperbolic" for x = (Float64(π)-0.01, Complex-0.01, π/2))
14+
test_scalar(sinh, x)
15+
test_scalar(cosh, x)
16+
test_scalar(tanh, x)
17+
test_scalar(sech, x)
18+
test_scalar(csch, x)
19+
test_scalar(coth, x)
3020
end
3121
@testset "Degrees" begin
3222
x = 45.0
33-
test_scalar(sind, x ->/ 180) * cosd(x), x)
34-
test_scalar(cosd, x -> (-π / 180) * sind(x), x)
35-
test_scalar(tand, x ->/ 180) * (1 + tand(x)^2), x)
36-
test_scalar(secd, x ->/ 180) * secd(x) * tand(x), x)
37-
test_scalar(cscd, x -> (-π / 180) * cscd(x) * cotd(x), x)
38-
test_scalar(cotd, x -> (-π / 180) * (1 + cotd(x)^2), x)
23+
test_scalar(sind, x)
24+
test_scalar(cosd, x)
25+
test_scalar(tand, x)
26+
test_scalar(secd, x)
27+
test_scalar(cscd, x)
28+
test_scalar(cotd, x)
3929
end
40-
@testset "Inverses" for x = (1.0, Complex(1.0, 0.25))
41-
test_scalar(asin, x -> 1 / sqrt(1 - x^2), x)
42-
test_scalar(acos, x -> -1 / sqrt(1 - x^2), x)
43-
test_scalar(atan, x -> 1 / (1 + x^2), x)
44-
test_scalar(asec, x -> 1 / (abs(x) * sqrt(x^2 - 1)), x)
45-
test_scalar(acsc, x -> -1 / (abs(x) * sqrt(x^2 - 1)), x)
46-
test_scalar(acot, x -> -1 / (1 + x^2), x)
30+
@testset "Inverses" for x = (0.5, Complex(0.5, 0.25))
31+
test_scalar(asin, x)
32+
test_scalar(acos, x)
33+
test_scalar(atan, x)
34+
test_scalar(asec, 1/x)
35+
test_scalar(acsc, 1/x)
36+
test_scalar(acot, 1/x)
4737
end
48-
@testset "Inverse hyperbolic" for x = (0.0, Complex(0.0, 0.25))
49-
test_scalar(asinh, x -> 1 / sqrt(x^2 + 1), x)
50-
test_scalar(acosh, x -> 1 / sqrt(x^2 - 1), x + 1) # +1 accounts for domain
51-
test_scalar(atanh, x -> 1 / (1 - x^2), x)
52-
test_scalar(asech, x -> -1 / x / sqrt(1 - x^2), x)
53-
test_scalar(acsch, x -> -1 / abs(x) / sqrt(1 + x^2), x)
54-
test_scalar(acoth, x -> 1 / (1 - x^2), x + 1)
38+
@testset "Inverse hyperbolic" for x = (0.5, Complex(0.5, 0.25))
39+
test_scalar(asinh, x)
40+
test_scalar(acosh, x + 1) # +1 accounts for domain
41+
test_scalar(atanh, x)
42+
test_scalar(asech, x)
43+
test_scalar(acsch, x)
44+
test_scalar(acoth, x + 1)
5545
end
56-
@testset "Inverse degrees" begin
57-
x = 1.0
58-
test_scalar(asind, x -> 180 / π / sqrt(1 - x^2), x)
59-
test_scalar(acosd, x -> -180 / π / sqrt(1 - x^2), x)
60-
test_scalar(atand, x -> 180 / π / (1 + x^2), x)
61-
test_scalar(asecd, x -> 180 / π / abs(x) / sqrt(x^2 - 1), x)
62-
test_scalar(acscd, x -> -180 / π / abs(x) / sqrt(x^2 - 1), x)
63-
test_scalar(acotd, x -> -180 / π / (1 + x^2), x)
46+
@testset "Inverse degrees" for x = (0.5, Complex(0.5, 0.25))
47+
test_scalar(asind, x)
48+
test_scalar(acosd, x)
49+
test_scalar(atand, x)
50+
test_scalar(asecd, 1/x)
51+
test_scalar(acscd, 1/x)
52+
test_scalar(acotd, 1/x)
6453
end
6554
@testset "Multivariate" begin
6655
x, y = rand(2)
@@ -83,40 +72,63 @@ end
8372
@test r === rsincos
8473
@test df(1, 2) === dsincos
8574
end
86-
end
87-
@testset "Misc. Tests" begin
88-
@testset "*(x, y)" begin
89-
x, y = rand(3, 2), rand(2, 5)
90-
z, (dx, dy) = rrule(*, x, y)
75+
end # Trig
9176

92-
@test z == x * y
77+
@testset "math" begin
78+
for x in (-0.1, 6.4, 1.0+0.5im, -10.0+0im)
79+
test_scalar(deg2rad, x)
80+
test_scalar(rad2deg, x)
9381

94-
= rand(3, 5)
82+
test_scalar(inv, x)
9583

96-
@test dx(z̄) == extern(accumulate(zeros(3, 2), dx, z̄))
97-
@test dy(z̄) == extern(accumulate(zeros(2, 5), dy, z̄))
84+
test_scalar(exp, x)
85+
test_scalar(exp2, x)
86+
test_scalar(exp10, x)
9887

99-
test_accumulation(rand(3, 2), dx, z̄, z̄ * y')
100-
test_accumulation(rand(2, 5), dy, z̄, x' * z̄)
88+
x isa Real && test_scalar(cbrt, x)
89+
if (x isa Real && x >= 0) || x isa Complex
90+
test_scalar(sqrt, x)
91+
test_scalar(log, x)
92+
test_scalar(log2, x)
93+
test_scalar(log10, x)
94+
test_scalar(log1p, x)
95+
end
10196
end
102-
@testset "hypot(x, y)" begin
103-
x, y = rand(2)
104-
h, dxy = frule(hypot, x, y)
97+
end
10598

106-
@test extern(dxy(One(), Zero())) === x / h
107-
@test extern(dxy(Zero(), One())) === y / h
99+
@testset "*(x, y)" begin
100+
x, y = rand(3, 2), rand(2, 5)
101+
z, (dx, dy) = rrule(*, x, y)
108102

109-
cx, cy = cast((One(), Zero())), cast((Zero(), One()))
110-
dx, dy = extern(dxy(cx, cy))
111-
@test dx === x / h
112-
@test dy === y / h
103+
@test z == x * y
113104

114-
cx, cy = cast((rand(), Zero())), cast((Zero(), rand()))
115-
dx, dy = extern(dxy(cx, cy))
116-
@test dx === x / h * cx.value[1]
117-
@test dy === y / h * cy.value[2]
118-
end
105+
= rand(3, 5)
106+
107+
@test dx(z̄) == extern(accumulate(zeros(3, 2), dx, z̄))
108+
@test dy(z̄) == extern(accumulate(zeros(2, 5), dy, z̄))
109+
110+
test_accumulation(rand(3, 2), dx, z̄, z̄ * y')
111+
test_accumulation(rand(2, 5), dy, z̄, x' * z̄)
119112
end
113+
114+
@testset "hypot(x, y)" begin
115+
x, y = rand(2)
116+
h, dxy = frule(hypot, x, y)
117+
118+
@test extern(dxy(One(), Zero())) === x / h
119+
@test extern(dxy(Zero(), One())) === y / h
120+
121+
cx, cy = cast((One(), Zero())), cast((Zero(), One()))
122+
dx, dy = extern(dxy(cx, cy))
123+
@test dx === x / h
124+
@test dy === y / h
125+
126+
cx, cy = cast((rand(), Zero())), cast((Zero(), rand()))
127+
dx, dy = extern(dxy(cx, cy))
128+
@test dx === x / h * cx.value[1]
129+
@test dy === y / h * cy.value[2]
130+
end
131+
120132
@testset "identity" begin
121133
rng = MersenneTwister(1)
122134
n = 4

test/rulesets/Base/broadcast.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,13 @@
88
@test extern(dx(One())) == cos.(x)
99

1010
x̄, ȳ = rand(), rand()
11-
@test extern(accumulate(x̄, dx, ȳ)) ==.+.* cos.(x)
11+
@test isequal(
12+
extern(ChainRules.accumulate(x̄, dx, ȳ)),
13+
.+.* cos.(x)
14+
)
1215

1316
x̄, ȳ = Zero(), rand(3, 3)
1417
@test extern(accumulate(x̄, dx, ȳ)) ==.* cos.(x)
15-
16-
x̄, ȳ = Zero(), cast(rand(3, 3))
17-
@test extern(accumulate(x̄, dx, ȳ)) == extern(ȳ) .* cos.(x)
1818
end
1919
end
2020
end
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
using SpecialFunctions
2+
3+
@testset "SpecialFunctions" for x in (1, -1, 0, 0.5, 10, -17.1, 1.5 + 0.7im)
4+
test_scalar(SpecialFunctions.erf, x)
5+
test_scalar(SpecialFunctions.erfc, x)
6+
test_scalar(SpecialFunctions.erfi, x)
7+
8+
test_scalar(SpecialFunctions.airyai, x)
9+
test_scalar(SpecialFunctions.airyaiprime, x)
10+
test_scalar(SpecialFunctions.airybi, x)
11+
test_scalar(SpecialFunctions.airybiprime, x)
12+
13+
test_scalar(SpecialFunctions.besselj0, x)
14+
test_scalar(SpecialFunctions.besselj1, x)
15+
16+
test_scalar(SpecialFunctions.erfcx, x)
17+
test_scalar(SpecialFunctions.dawson, x)
18+
19+
if x isa Real
20+
test_scalar(SpecialFunctions.invdigamma, x)
21+
end
22+
23+
if x isa Real && 0 < x < 1
24+
test_scalar(SpecialFunctions.erfinv, x)
25+
test_scalar(SpecialFunctions.erfcinv, x)
26+
end
27+
28+
if x isa Real && x > 0 || x isa Complex
29+
test_scalar(SpecialFunctions.bessely0, x)
30+
test_scalar(SpecialFunctions.bessely1, x)
31+
test_scalar(SpecialFunctions.gamma, x)
32+
test_scalar(SpecialFunctions.digamma, x)
33+
test_scalar(SpecialFunctions.trigamma, x)
34+
test_scalar(SpecialFunctions.lgamma, x)
35+
end
36+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ using ChainRulesCore: add, cast, extern, accumulate, accumulate!, store!, @scala
1717

1818
include("test_util.jl")
1919

20+
println("Testing ChainRules.jl")
2021
@testset "ChainRules" begin
2122
include("helper_functions.jl")
2223
@testset "rulesets" begin

0 commit comments

Comments
 (0)