Skip to content

Commit cb59302

Browse files
committed
Change existing scalar rule tests to use FiniteDifferences and correct the incorrect acshc asec acsc, asecd acscd rules
1 parent 7f8af88 commit cb59302

File tree

3 files changed

+98
-60
lines changed

3 files changed

+98
-60
lines changed

src/rulesets/Base/base.jl

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,54 +4,70 @@
44
@scalar_rule(log2(x), inv(x) / log(oftype(x, 2)))
55
@scalar_rule(log1p(x), inv(x + 1))
66
@scalar_rule(expm1(x), exp(x))
7+
78
@scalar_rule(sin(x), cos(x))
89
@scalar_rule(cos(x), -sin(x))
910
@scalar_rule(sinpi(x), π * cospi(x))
1011
@scalar_rule(cospi(x), -π * sinpi(x))
1112
@scalar_rule(sind(x), (π / oftype(x, 180)) * cosd(x))
1213
@scalar_rule(cosd(x), -/ oftype(x, 180)) * sind(x))
14+
1315
@scalar_rule(asin(x), inv(sqrt(1 - x^2)))
1416
@scalar_rule(acos(x), -inv(sqrt(1 - x^2)))
1517
@scalar_rule(atan(x), inv(1 + x^2))
16-
@scalar_rule(asec(x), inv(abs(x) * sqrt(x^2 - 1)))
17-
@scalar_rule(acsc(x), -inv(abs(x) * sqrt(x^2 - 1)))
18+
@scalar_rule(asec(x::Real), inv(abs(x) * sqrt(x^2 - 1)))
19+
@scalar_rule(asec(x), inv(x^2 * sqrt(1 - x^-2)))
20+
@scalar_rule(acsc(x::Real), -inv(abs(x) * sqrt(x^2 - 1)))
21+
@scalar_rule(acsc(x), -inv(x^2 * sqrt(1 - x^-2)))
1822
@scalar_rule(acot(x), -inv(1 + x^2))
23+
1924
@scalar_rule(asind(x), oftype(x, 180) / π / sqrt(1 - x^2))
2025
@scalar_rule(acosd(x), -oftype(x, 180) / π / sqrt(1 - x^2))
2126
@scalar_rule(atand(x), oftype(x, 180) / π / (1 + x^2))
22-
@scalar_rule(asecd(x), oftype(x, 180) / π / abs(x) / sqrt(x^2 - 1))
23-
@scalar_rule(acscd(x), -oftype(x, 180) / π / abs(x) / sqrt(x^2 - 1))
27+
@scalar_rule(asecd(x::Real), oftype(x, 180) / π / abs(x) / sqrt(x^2 - 1))
28+
@scalar_rule(asecd(x), oftype(x, 180) / π / x^2 / sqrt(1 - x^-2))
29+
@scalar_rule(acscd(x::Real), -oftype(x, 180) / π / abs(x) / sqrt(x^2 - 1))
30+
@scalar_rule(acscd(x), -oftype(x, 180) / π / x^2 / sqrt(1 - x^-2))
2431
@scalar_rule(acotd(x), -oftype(x, 180) / π / (1 + x^2))
32+
2533
@scalar_rule(sinh(x), cosh(x))
2634
@scalar_rule(cosh(x), sinh(x))
2735
@scalar_rule(tanh(x), sech(x)^2)
2836
@scalar_rule(coth(x), -(csch(x)^2))
37+
2938
@scalar_rule(asinh(x), inv(sqrt(x^2 + 1)))
3039
@scalar_rule(acosh(x), inv(sqrt(x^2 - 1)))
3140
@scalar_rule(atanh(x), inv(1 - x^2))
3241
@scalar_rule(asech(x), -inv(x * sqrt(1 - x^2)))
33-
@scalar_rule(acsch(x), -inv(abs(x) * sqrt(1 + x^2)))
42+
@scalar_rule(acsch(x::Real), -inv(abs(x) * sqrt(1 + x^2)))
43+
@scalar_rule(acsch(x), -inv(x^2 * sqrt(1 + x^-2)))
3444
@scalar_rule(acoth(x), inv(1 - x^2))
45+
3546
@scalar_rule(deg2rad(x), π / oftype(x, 180))
3647
@scalar_rule(rad2deg(x), oftype(x, 180) / π)
48+
3749
@scalar_rule(conj(x), Wirtinger(Zero(), One()))
3850
@scalar_rule(adjoint(x), Wirtinger(Zero(), One()))
3951
@scalar_rule(transpose(x), One())
52+
4053
@scalar_rule(abs(x), sign(x))
4154
@scalar_rule(rem2pi(x, r::RoundingMode), (One(), DNE()))
55+
4256
@scalar_rule(+(x), One())
4357
@scalar_rule(-(x), -1)
4458
@scalar_rule(+(x, y), (One(), One()))
4559
@scalar_rule(-(x, y), (One(), -1))
4660
@scalar_rule(/(x, y), (inv(y), -(x / y / y)))
4761
@scalar_rule(\(x, y), (-(y / x / x), inv(x)))
4862
@scalar_rule(^(x, y), (y * x^(y - 1), Ω * log(x)))
63+
4964
@scalar_rule(inv(x), -abs2(Ω))
5065
@scalar_rule(sqrt(x), inv(2 * Ω))
5166
@scalar_rule(cbrt(x), inv(3 * Ω^2))
5267
@scalar_rule(exp(x), Ω)
5368
@scalar_rule(exp2(x), Ω * log(oftype(x, 2)))
5469
@scalar_rule(exp10(x), Ω * log(oftype(x, 10)))
70+
5571
@scalar_rule(tan(x), 1 + Ω^2)
5672
@scalar_rule(sec(x), Ω * tan(x))
5773
@scalar_rule(csc(x), -Ω * cot(x))
@@ -62,9 +78,11 @@
6278
@scalar_rule(cotd(x), -/ oftype(x, 180)) * (1 + Ω^2))
6379
@scalar_rule(sech(x), -tanh(x) * Ω)
6480
@scalar_rule(csch(x), -coth(x) * Ω)
81+
6582
@scalar_rule(hypot(x, y), (x / Ω, y / Ω))
6683
@scalar_rule(sincos(x), @setup((sinx, cosx) = Ω), cosx, -sinx)
6784
@scalar_rule(atan(x, y), @setup(u = x^2 + y^2), (y / u, -x / u))
85+
6886
@scalar_rule(max(x, y), @setup(gt = x > y), (gt, !gt))
6987
@scalar_rule(min(x, y), @setup(gt = x > y), (!gt, gt))
7088
@scalar_rule(mod(x, y), @setup((u, nan) = promote(x / y, NaN16)),

test/rulesets/Base/base.jl

Lines changed: 43 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,66 +1,56 @@
1-
function test_scalar(f, f′, xs...)
2-
for r = (rrule, frule)
3-
rr = r(f, xs...)
4-
@test rr !== nothing
5-
fx, ∂x = rr
6-
@test fx == f(xs...)
7-
@test ∂x(1) f′(xs...) atol=1e-5
8-
end
9-
end
10-
111
@testset "base" begin
122
@testset "Trig" begin
13-
@testset "Basics" for x = (Float64(π), Complex(π, π/2))
14-
test_scalar(sin, cos, x)
15-
test_scalar(cos, x -> -sin(x), x)
16-
test_scalar(tan, x -> 1 + tan(x)^2, x)
17-
test_scalar(sec, x -> sec(x) * tan(x), x)
18-
test_scalar(csc, x -> -csc(x) * cot(x), x)
19-
test_scalar(cot, x -> -1 - cot(x)^2, x)
20-
test_scalar(sinpi, x -> π * cospi(x), x)
21-
test_scalar(cospi, x -> -π * sinpi(x), x)
3+
@testset "Basics" for x = (Float64(π)-0.01, Complex(π, π/2))
4+
test_scalar(sin, x)
5+
test_scalar(cos, x)
6+
test_scalar(tan, x)
7+
test_scalar(sec, x)
8+
test_scalar(csc, x)
9+
test_scalar(cot, x)
10+
test_scalar(sinpi, x)
11+
test_scalar(cospi, x)
2212
end
23-
@testset "Hyperbolic" for x = (Float64(π), Complex(π, π/2))
24-
test_scalar(sinh, cosh, x)
25-
test_scalar(cosh, sinh, x)
26-
test_scalar(tanh, x -> sech(x)^2, x)
27-
test_scalar(sech, x -> -tanh(x) * sech(x), x)
28-
test_scalar(csch, x -> -coth(x) * csch(x), x)
29-
test_scalar(coth, x -> -csch(x)^2, x)
13+
@testset "Hyperbolic" for x = (Float64(π)-0.01, Complex-0.01, π/2))
14+
test_scalar(sinh, x)
15+
test_scalar(cosh, x)
16+
test_scalar(tanh, x)
17+
test_scalar(sech, x)
18+
test_scalar(csch, x)
19+
test_scalar(coth, x)
3020
end
3121
@testset "Degrees" begin
3222
x = 45.0
33-
test_scalar(sind, x ->/ 180) * cosd(x), x)
34-
test_scalar(cosd, x -> (-π / 180) * sind(x), x)
35-
test_scalar(tand, x ->/ 180) * (1 + tand(x)^2), x)
36-
test_scalar(secd, x ->/ 180) * secd(x) * tand(x), x)
37-
test_scalar(cscd, x -> (-π / 180) * cscd(x) * cotd(x), x)
38-
test_scalar(cotd, x -> (-π / 180) * (1 + cotd(x)^2), x)
23+
test_scalar(sind, x)
24+
test_scalar(cosd, x)
25+
test_scalar(tand, x)
26+
test_scalar(secd, x)
27+
test_scalar(cscd, x)
28+
test_scalar(cotd, x)
3929
end
40-
@testset "Inverses" for x = (1.0, Complex(1.0, 0.25))
41-
test_scalar(asin, x -> 1 / sqrt(1 - x^2), x)
42-
test_scalar(acos, x -> -1 / sqrt(1 - x^2), x)
43-
test_scalar(atan, x -> 1 / (1 + x^2), x)
44-
test_scalar(asec, x -> 1 / (abs(x) * sqrt(x^2 - 1)), x)
45-
test_scalar(acsc, x -> -1 / (abs(x) * sqrt(x^2 - 1)), x)
46-
test_scalar(acot, x -> -1 / (1 + x^2), x)
30+
@testset "Inverses" for x = (0.5, Complex(0.5, 0.25))
31+
test_scalar(asin, x)
32+
test_scalar(acos, x)
33+
test_scalar(atan, x)
34+
test_scalar(asec, 1/x)
35+
test_scalar(acsc, 1/x)
36+
test_scalar(acot, 1/x)
4737
end
48-
@testset "Inverse hyperbolic" for x = (0.0, Complex(0.0, 0.25))
49-
test_scalar(asinh, x -> 1 / sqrt(x^2 + 1), x)
50-
test_scalar(acosh, x -> 1 / sqrt(x^2 - 1), x + 1) # +1 accounts for domain
51-
test_scalar(atanh, x -> 1 / (1 - x^2), x)
52-
test_scalar(asech, x -> -1 / x / sqrt(1 - x^2), x)
53-
test_scalar(acsch, x -> -1 / abs(x) / sqrt(1 + x^2), x)
54-
test_scalar(acoth, x -> 1 / (1 - x^2), x + 1)
38+
@testset "Inverse hyperbolic" for x = (0.5, Complex(0.5, 0.25))
39+
test_scalar(asinh, x)
40+
test_scalar(acosh, x + 1) # +1 accounts for domain
41+
test_scalar(atanh, x)
42+
test_scalar(asech, x)
43+
test_scalar(acsch, x)
44+
test_scalar(acoth, x + 1)
5545
end
5646
@testset "Inverse degrees" begin
57-
x = 1.0
58-
test_scalar(asind, x -> 180 / π / sqrt(1 - x^2), x)
59-
test_scalar(acosd, x -> -180 / π / sqrt(1 - x^2), x)
60-
test_scalar(atand, x -> 180 / π / (1 + x^2), x)
61-
test_scalar(asecd, x -> 180 / π / abs(x) / sqrt(x^2 - 1), x)
62-
test_scalar(acscd, x -> -180 / π / abs(x) / sqrt(x^2 - 1), x)
63-
test_scalar(acotd, x -> -180 / π / (1 + x^2), x)
47+
x = 0.5
48+
test_scalar(asind, x)
49+
test_scalar(acosd, x)
50+
test_scalar(atand, x)
51+
test_scalar(asecd, 1/x)
52+
test_scalar(acscd, 1/x)
53+
test_scalar(acotd, 1/x)
6454
end
6555
@testset "Multivariate" begin
6656
x, y = rand(2)

test/test_util.jl

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,43 @@
1+
using FiniteDifferences, Test
12
using FiniteDifferences: jvp, j′vp
3+
using ChainRules
24

35
const _fdm = central_fdm(5, 1)
46

7+
"""
8+
test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), kwargs...)
9+
10+
Given a function `f` with scalar input an scalar output, perform finite differencing checks,
11+
at input point `x` to confirm that there are correct ChainRules provided.
12+
13+
# Arguments
14+
- `f`: Function for which the `frule` and `rrule` should be tested.
15+
- `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain).
16+
17+
All keyword arguments except for `fdm` are passed to `isapprox`.
18+
"""
19+
function test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...)
20+
@testset "$f at $x, $(nameof(rule))" for rule in (rrule, frule)
21+
res = rule(f, x)
22+
@test res !== nothing # Check the rule was defined
23+
fx, ∂x = res
24+
@test fx == f(x) # Check we still get the normal value, right
25+
26+
# Check that we get the derivative right:
27+
@test isapprox(
28+
∂x(1), fdm(f, x);
29+
rtol=rtol, atol=atol, kwargs...
30+
)
31+
end
32+
end
33+
34+
535
"""
636
frule_test(f, (x, ẋ)...; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), kwargs...)
737
838
# Arguments
939
- `f`: Function for which the `frule` should be tested.
10-
- `x`: input at which to evaluate `f` (should generally be set randomly).
40+
- `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain).
1141
- `ẋ`: differential w.r.t. `x` (should generally be set randomly).
1242
1343
All keyword arguments except for `fdm` are passed to `isapprox`.
@@ -31,7 +61,7 @@ end
3161
# Arguments
3262
- `f`: Function to which rule should be applied.
3363
- `ȳ`: adjoint w.r.t. output of `f` (should generally be set randomly).
34-
- `x`: input at which to evaluate `f` (should generally be set randomly).
64+
- `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain).
3565
- `x̄`: currently accumulated adjoint (should generally be set randomly).
3666
3767
All keyword arguments except for `fdm` are passed to `isapprox`.

0 commit comments

Comments
 (0)