Skip to content

Commit daff86a

Browse files
authored
Update rules for SpecialFunctions (#407)
1 parent 647c440 commit daff86a

File tree

3 files changed

+138
-42
lines changed

3 files changed

+138
-42
lines changed

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "0.7.61"
3+
version = "0.7.62"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -12,8 +12,8 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1212
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1313

1414
[compat]
15-
ChainRulesCore = "0.9.29"
16-
ChainRulesTestUtils = "0.6.6"
15+
ChainRulesCore = "0.9.40"
16+
ChainRulesTestUtils = "0.6.8"
1717
Compat = "3"
1818
FiniteDifferences = "0.11, 0.12"
1919
Reexport = "0.2, 1"

src/rulesets/packages/SpecialFunctions.jl

Lines changed: 52 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
const BESSEL_ORDER_INFO = """
2+
derivatives of Bessel functions with respect to the order are not implemented currently:
3+
https://github.com/JuliaMath/SpecialFunctions.jl/issues/160
4+
"""
5+
16
@scalar_rule(SpecialFunctions.airyai(x), SpecialFunctions.airyaiprime(x))
27
@scalar_rule(SpecialFunctions.airyaiprime(x), x * SpecialFunctions.airyai(x))
38
@scalar_rule(SpecialFunctions.airybi(x), SpecialFunctions.airybiprime(x))
@@ -30,43 +35,60 @@
3035
# binary
3136
@scalar_rule(
3237
SpecialFunctions.besselj(ν, x),
33-
(NaN, (SpecialFunctions.besselj- 1, x) - SpecialFunctions.besselj+ 1, x)) / 2),
38+
(
39+
@not_implemented(BESSEL_ORDER_INFO),
40+
(SpecialFunctions.besselj- 1, x) - SpecialFunctions.besselj+ 1, x)) / 2
41+
),
3442
)
3543
@scalar_rule(
3644
SpecialFunctions.besseli(ν, x),
37-
(NaN, (SpecialFunctions.besseli- 1, x) + SpecialFunctions.besseli+ 1, x)) / 2),
45+
(
46+
@not_implemented(BESSEL_ORDER_INFO),
47+
(SpecialFunctions.besseli- 1, x) + SpecialFunctions.besseli+ 1, x)) / 2,
48+
),
3849
)
3950
@scalar_rule(
4051
SpecialFunctions.bessely(ν, x),
41-
(NaN, (SpecialFunctions.bessely- 1, x) - SpecialFunctions.bessely+ 1, x)) / 2),
52+
(
53+
@not_implemented(BESSEL_ORDER_INFO),
54+
(SpecialFunctions.bessely- 1, x) - SpecialFunctions.bessely+ 1, x)) / 2,
55+
),
4256
)
4357
@scalar_rule(
4458
SpecialFunctions.besselk(ν, x),
45-
(NaN, -(SpecialFunctions.besselk- 1, x) + SpecialFunctions.besselk+ 1, x)) / 2),
59+
(
60+
@not_implemented(BESSEL_ORDER_INFO),
61+
-(SpecialFunctions.besselk- 1, x) + SpecialFunctions.besselk+ 1, x)) / 2,
62+
),
4663
)
4764
@scalar_rule(
4865
SpecialFunctions.hankelh1(ν, x),
49-
(NaN, (SpecialFunctions.hankelh1- 1, x) - SpecialFunctions.hankelh1+ 1, x)) / 2),
66+
(
67+
@not_implemented(BESSEL_ORDER_INFO),
68+
(SpecialFunctions.hankelh1- 1, x) - SpecialFunctions.hankelh1+ 1, x)) / 2,
69+
),
5070
)
5171
@scalar_rule(
5272
SpecialFunctions.hankelh2(ν, x),
53-
(NaN, (SpecialFunctions.hankelh2- 1, x) - SpecialFunctions.hankelh2+ 1, x)) / 2),
73+
(
74+
@not_implemented(BESSEL_ORDER_INFO),
75+
(SpecialFunctions.hankelh2- 1, x) - SpecialFunctions.hankelh2+ 1, x)) / 2,
76+
),
5477
)
5578
@scalar_rule(
5679
SpecialFunctions.polygamma(m, x),
57-
(NaN, SpecialFunctions.polygamma(m + 1, x))
80+
(
81+
DoesNotExist(),
82+
SpecialFunctions.polygamma(m + 1, x),
83+
),
5884
)
5985
# todo: setup for common expr
6086
@scalar_rule(
6187
SpecialFunctions.beta(a, b),
6288
*(SpecialFunctions.digamma(a) - SpecialFunctions.digamma(a + b)),
6389
Ω*(SpecialFunctions.digamma(b) - SpecialFunctions.digamma(a + b)),)
6490
)
65-
@scalar_rule(
66-
SpecialFunctions.lbeta(a, b),
67-
(SpecialFunctions.digamma(a) - SpecialFunctions.digamma(a + b),
68-
SpecialFunctions.digamma(b) - SpecialFunctions.digamma(a + b),)
69-
)
91+
7092
# Changes between SpecialFunctions 0.7 and 0.8
7193
if isdefined(SpecialFunctions, :lgamma)
7294
# actually is the absolute value of the logorithm of gamma
@@ -81,3 +103,21 @@ end
81103
if isdefined(SpecialFunctions, :loggamma)
82104
@scalar_rule(SpecialFunctions.loggamma(x), SpecialFunctions.digamma(x))
83105
end
106+
107+
if isdefined(SpecialFunctions, :lbeta)
108+
# todo: setup for common expr
109+
@scalar_rule(
110+
SpecialFunctions.lbeta(a, b),
111+
(SpecialFunctions.digamma(a) - SpecialFunctions.digamma(a + b),
112+
SpecialFunctions.digamma(b) - SpecialFunctions.digamma(a + b),)
113+
)
114+
end
115+
116+
if isdefined(SpecialFunctions, :logbeta)
117+
# todo: setup for common expr
118+
@scalar_rule(
119+
SpecialFunctions.logbeta(a, b),
120+
(SpecialFunctions.digamma(a) - SpecialFunctions.digamma(a + b),
121+
SpecialFunctions.digamma(b) - SpecialFunctions.digamma(a + b),)
122+
)
123+
end
Lines changed: 83 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,53 +1,109 @@
1-
@testset "SpecialFunctions" for x in (1.0, -1.0, 0.0, 0.5, 10.0, -17.1, 1.5 + 0.7im)
2-
test_scalar(SpecialFunctions.erf, x)
3-
test_scalar(SpecialFunctions.erfc, x)
4-
test_scalar(SpecialFunctions.erfi, x)
1+
@testset "general: single input" begin
2+
for x in (1.0, -1.0, 0.0, 0.5, 10.0, -17.1, 1.5 + 0.7im)
3+
test_scalar(SpecialFunctions.erf, x)
4+
test_scalar(SpecialFunctions.erfc, x)
5+
test_scalar(SpecialFunctions.erfi, x)
56

6-
test_scalar(SpecialFunctions.airyai, x)
7-
test_scalar(SpecialFunctions.airyaiprime, x)
8-
test_scalar(SpecialFunctions.airybi, x)
9-
test_scalar(SpecialFunctions.airybiprime, x)
7+
test_scalar(SpecialFunctions.airyai, x)
8+
test_scalar(SpecialFunctions.airyaiprime, x)
9+
test_scalar(SpecialFunctions.airybi, x)
10+
test_scalar(SpecialFunctions.airybiprime, x)
1011

11-
test_scalar(SpecialFunctions.besselj0, x)
12-
test_scalar(SpecialFunctions.besselj1, x)
12+
test_scalar(SpecialFunctions.erfcx, x)
13+
test_scalar(SpecialFunctions.dawson, x)
1314

14-
test_scalar(SpecialFunctions.erfcx, x)
15-
test_scalar(SpecialFunctions.dawson, x)
15+
if x isa Real
16+
test_scalar(SpecialFunctions.invdigamma, x)
17+
end
1618

17-
if x isa Real
18-
test_scalar(SpecialFunctions.invdigamma, x)
19-
end
19+
if x isa Real && 0 < x < 1
20+
test_scalar(SpecialFunctions.erfinv, x)
21+
test_scalar(SpecialFunctions.erfcinv, x)
22+
end
2023

21-
if x isa Real && 0 < x < 1
22-
test_scalar(SpecialFunctions.erfinv, x)
23-
test_scalar(SpecialFunctions.erfcinv, x)
24+
if x isa Real && x > 0 || x isa Complex
25+
test_scalar(SpecialFunctions.gamma, x)
26+
test_scalar(SpecialFunctions.digamma, x)
27+
test_scalar(SpecialFunctions.trigamma, x)
28+
end
2429
end
30+
end
31+
32+
@testset "Bessel functions" begin
33+
for x in (1.5, 2.5, 10.5, -0.6, -2.6, -3.3, 1.6 + 1.6im, 1.6 - 1.6im, -4.6 + 1.6im)
34+
test_scalar(SpecialFunctions.besselj0, x)
35+
test_scalar(SpecialFunctions.besselj1, x)
36+
37+
isreal(x) && x < 0 && continue
2538

26-
if x isa Real && x > 0 || x isa Complex
2739
test_scalar(SpecialFunctions.bessely0, x)
2840
test_scalar(SpecialFunctions.bessely1, x)
29-
test_scalar(SpecialFunctions.gamma, x)
30-
test_scalar(SpecialFunctions.digamma, x)
31-
test_scalar(SpecialFunctions.trigamma, x)
41+
42+
for nu in (-1.5, 2.2, 4.0)
43+
test_frule(SpecialFunctions.besseli, nu, x)
44+
test_rrule(SpecialFunctions.besseli, nu, x)
45+
46+
test_frule(SpecialFunctions.besselj, nu, x)
47+
test_rrule(SpecialFunctions.besselj, nu, x)
48+
49+
test_frule(SpecialFunctions.besselk, nu, x)
50+
test_rrule(SpecialFunctions.besselk, nu, x)
51+
52+
test_frule(SpecialFunctions.bessely, nu, x)
53+
test_rrule(SpecialFunctions.bessely, nu, x)
54+
55+
# use complex numbers in `rrule` for FiniteDifferences
56+
test_frule(SpecialFunctions.hankelh1, nu, x)
57+
test_rrule(SpecialFunctions.hankelh1, nu, complex(x))
58+
59+
# use complex numbers in `rrule` for FiniteDifferences
60+
test_frule(SpecialFunctions.hankelh2, nu, x)
61+
test_rrule(SpecialFunctions.hankelh2, nu, complex(x))
62+
end
63+
end
64+
end
65+
66+
@testset "beta and logbeta" begin
67+
test_points = (1.5, 2.5, 10.5, 1.6 + 1.6im, 1.6 - 1.6im, 4.6 + 1.6im)
68+
for _x in test_points, _y in test_points
69+
# ensure all complex if any complex for FiniteDifferences
70+
x, y = promote(_x, _y)
71+
test_frule(SpecialFunctions.beta, x, y)
72+
test_rrule(SpecialFunctions.beta, x, y)
73+
74+
if isdefined(SpecialFunctions, :lbeta)
75+
test_frule(SpecialFunctions.lbeta, x, y)
76+
test_rrule(SpecialFunctions.lbeta, x, y)
77+
end
78+
79+
if isdefined(SpecialFunctions, :logbeta)
80+
test_frule(SpecialFunctions.logbeta, x, y)
81+
test_rrule(SpecialFunctions.logbeta, x, y)
82+
end
3283
end
3384
end
3485

35-
# SpecialFunctions 0.7->0.8 changes:
3686
@testset "log gamma and co" begin
37-
#It is important that we have negative numbers with both odd and even integer parts
38-
for x in (1.5, 2.5, 10.5, -0.6, -2.6, -3.3, 1.6+1.6im, 1.6-1.6im, -4.6+1.6im)
87+
# It is important that we have negative numbers with both odd and even integer parts
88+
for x in (1.5, 2.5, 10.5, -0.6, -2.6, -3.3, 1.6 + 1.6im, 1.6 - 1.6im, -4.6 + 1.6im)
89+
for m in (0, 1, 2, 3)
90+
test_frule(SpecialFunctions.polygamma, m, x)
91+
test_rrule(SpecialFunctions.polygamma, m, x)
92+
end
93+
3994
if isdefined(SpecialFunctions, :lgamma)
4095
test_scalar(SpecialFunctions.lgamma, x)
4196
end
97+
4298
if isdefined(SpecialFunctions, :loggamma)
4399
isreal(x) && x < 0 && continue
44100
test_scalar(SpecialFunctions.loggamma, x)
45101
end
46102

47103
if isdefined(SpecialFunctions, :logabsgamma)
48104
isreal(x) || continue
49-
test_frule(logabsgamma, x)
50-
test_rrule(logabsgamma, x; output_tangent=(randn(), randn()))
105+
test_frule(SpecialFunctions.logabsgamma, x)
106+
test_rrule(SpecialFunctions.logabsgamma, x; output_tangent=(randn(), randn()))
51107
end
52108
end
53109
end

0 commit comments

Comments
 (0)