Skip to content

Commit 74fd16f

Browse files
authored
Add more ChainRules derivatives (#348)
* Add more ChainRules definitions * Bump version * Use irrational constants * Convert irrational manually with `oftype`
1 parent 1b7a377 commit 74fd16f

File tree

2 files changed

+89
-30
lines changed

2 files changed

+89
-30
lines changed

src/chainrules.jl

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@ https://github.com/JuliaMath/SpecialFunctions.jl/issues/321
1616
"""
1717

1818
ChainRulesCore.@scalar_rule(airyai(x), airyaiprime(x))
19+
ChainRulesCore.@scalar_rule(airyaix(x), airyaiprimex(x) + sqrt(x) * Ω)
1920
ChainRulesCore.@scalar_rule(airyaiprime(x), x * airyai(x))
21+
ChainRulesCore.@scalar_rule(airyaiprimex(x), x * airyaix(x) + sqrt(x) * Ω)
2022
ChainRulesCore.@scalar_rule(airybi(x), airybiprime(x))
2123
ChainRulesCore.@scalar_rule(airybiprime(x), x * airybi(x))
2224
ChainRulesCore.@scalar_rule(besselj0(x), -besselj1(x))
@@ -31,12 +33,18 @@ ChainRulesCore.@scalar_rule(
3133
)
3234
ChainRulesCore.@scalar_rule(dawson(x), 1 - (2 * x * Ω))
3335
ChainRulesCore.@scalar_rule(digamma(x), trigamma(x))
34-
ChainRulesCore.@scalar_rule(erf(x), (2 / sqrt(π)) * exp(-x * x))
35-
ChainRulesCore.@scalar_rule(erfc(x), -(2 / sqrt(π)) * exp(-x * x))
36-
ChainRulesCore.@scalar_rule(erfcinv(x), -(sqrt(π) / 2) * exp(Ω^2))
37-
ChainRulesCore.@scalar_rule(erfcx(x), (2 * x * Ω) - (2 / sqrt(π)))
38-
ChainRulesCore.@scalar_rule(erfi(x), (2 / sqrt(π)) * exp(x * x))
39-
ChainRulesCore.@scalar_rule(erfinv(x), (sqrt(π) / 2) * exp(Ω^2))
36+
37+
# TODO: use `invsqrtπ` if it is added to IrrationalConstants
38+
ChainRulesCore.@scalar_rule(erf(x), (2 * exp(-x^2)) / sqrtπ)
39+
ChainRulesCore.@scalar_rule(erf(x, y), (- (2 * exp(-x^2)) / sqrtπ, (2 * exp(-y^2)) / sqrtπ))
40+
ChainRulesCore.@scalar_rule(erfc(x), - (2 * exp(-x^2)) / sqrtπ)
41+
ChainRulesCore.@scalar_rule(logerfc(x), - (2 * exp(-x^2 - Ω)) / sqrtπ)
42+
ChainRulesCore.@scalar_rule(erfcinv(x), - (sqrtπ * (exp(Ω^2) / 2)))
43+
ChainRulesCore.@scalar_rule(erfcx(x), 2 * (x * Ω - inv(oftype(Ω, sqrtπ))))
44+
ChainRulesCore.@scalar_rule(logerfcx(x), 2 * (x - exp(-Ω) / sqrtπ))
45+
ChainRulesCore.@scalar_rule(erfi(x), (2 * exp(x^2)) / sqrtπ)
46+
ChainRulesCore.@scalar_rule(erfinv(x), sqrtπ * (exp(Ω^2) / 2))
47+
4048
ChainRulesCore.@scalar_rule(gamma(x), Ω * digamma(x))
4149
ChainRulesCore.@scalar_rule(
4250
gamma(a, x),
@@ -65,7 +73,7 @@ ChainRulesCore.@scalar_rule(
6573
)
6674
ChainRulesCore.@scalar_rule(trigamma(x), polygamma(2, x))
6775

68-
# binary
76+
# Bessel functions
6977
ChainRulesCore.@scalar_rule(
7078
besselj(ν, x),
7179
(
@@ -94,20 +102,42 @@ ChainRulesCore.@scalar_rule(
94102
-(besselk(ν - 1, x) + besselk(ν + 1, x)) / 2,
95103
),
96104
)
105+
ChainRulesCore.@scalar_rule(
106+
besselkx(ν, x),
107+
(
108+
ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO),
109+
-(besselkx(ν - 1, x) + besselkx(ν + 1, x)) / 2 + Ω,
110+
),
111+
)
97112
ChainRulesCore.@scalar_rule(
98113
hankelh1(ν, x),
99114
(
100115
ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO),
101116
(hankelh1(ν - 1, x) - hankelh1(ν + 1, x)) / 2,
102117
),
103118
)
119+
ChainRulesCore.@scalar_rule(
120+
hankelh1x(ν, x),
121+
(
122+
ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO),
123+
(hankelh1x(ν - 1, x) - hankelh1x(ν + 1, x)) / 2 - im * Ω,
124+
),
125+
)
104126
ChainRulesCore.@scalar_rule(
105127
hankelh2(ν, x),
106128
(
107129
ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO),
108130
(hankelh2(ν - 1, x) - hankelh2(ν + 1, x)) / 2,
109131
),
110132
)
133+
ChainRulesCore.@scalar_rule(
134+
hankelh2x(ν, x),
135+
(
136+
ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO),
137+
(hankelh2x(ν - 1, x) - hankelh2x(ν + 1, x)) / 2 + im * Ω,
138+
),
139+
)
140+
111141
ChainRulesCore.@scalar_rule(
112142
polygamma(m, x),
113143
(
@@ -161,5 +191,5 @@ ChainRulesCore.@scalar_rule(
161191
)
162192
)
163193
ChainRulesCore.@scalar_rule(expinti(x), exp(x) / x)
164-
ChainRulesCore.@scalar_rule(sinint(x), sinc(x / π))
194+
ChainRulesCore.@scalar_rule(sinint(x), sinc(invπ * x))
165195
ChainRulesCore.@scalar_rule(cosint(x), cos(x) / x)

test/chainrules.jl

Lines changed: 51 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,20 @@
55
for x in (1.0, -1.0, 0.0, 0.5, 10.0, -17.1, 1.5 + 0.7im)
66
test_scalar(erf, x)
77
test_scalar(erfc, x)
8+
test_scalar(erfcx, x)
89
test_scalar(erfi, x)
910

1011
test_scalar(airyai, x)
1112
test_scalar(airyaiprime, x)
1213
test_scalar(airybi, x)
1314
test_scalar(airybiprime, x)
1415

15-
test_scalar(erfcx, x)
1616
test_scalar(dawson, x)
1717

1818
if x isa Real
19+
test_scalar(logerfc, x)
20+
test_scalar(logerfcx, x)
21+
1922
test_scalar(invdigamma, x)
2023
end
2124

@@ -28,6 +31,11 @@
2831
test_scalar(gamma, x)
2932
test_scalar(digamma, x)
3033
test_scalar(trigamma, x)
34+
35+
if x isa Real
36+
test_scalar(airyaix, x)
37+
test_scalar(airyaiprimex, x)
38+
end
3139
end
3240
end
3341
end
@@ -51,31 +59,38 @@
5159

5260
test_frule(besselk, nu, x)
5361
test_rrule(besselk, nu, x)
62+
test_frule(besselkx, nu, x)
63+
test_rrule(besselkx, nu, x)
5464

5565
test_frule(bessely, nu, x)
5666
test_rrule(bessely, nu, x)
5767

58-
# use complex numbers in `rrule` for FiniteDifferences
5968
test_frule(hankelh1, nu, x)
60-
test_rrule(hankelh1, nu, complex(x))
69+
test_rrule(hankelh1, nu, x)
70+
test_frule(hankelh1x, nu, x)
71+
test_rrule(hankelh1x, nu, x)
6172

62-
# use complex numbers in `rrule` for FiniteDifferences
6373
test_frule(hankelh2, nu, x)
64-
test_rrule(hankelh2, nu, complex(x))
74+
test_rrule(hankelh2, nu, x)
75+
test_frule(hankelh2x, nu, x)
76+
test_rrule(hankelh2x, nu, x)
6577
end
6678
end
6779
end
6880

69-
@testset "beta and logbeta" begin
81+
@testset "erf, beta, and logbeta" begin
7082
test_points = (1.5, 2.5, 10.5, 1.6 + 1.6im, 1.6 - 1.6im, 4.6 + 1.6im)
71-
for _x in test_points, _y in test_points
72-
# ensure all complex if any complex for FiniteDifferences
73-
x, y = promote(_x, _y)
83+
for x in test_points, y in test_points
7484
test_frule(beta, x, y)
7585
test_rrule(beta, x, y)
7686

7787
test_frule(logbeta, x, y)
7888
test_rrule(logbeta, x, y)
89+
90+
if x isa Real && y isa Real
91+
test_frule(erf, x, y)
92+
test_rrule(erf, x, y)
93+
end
7994
end
8095
end
8196

@@ -91,13 +106,11 @@
91106
isreal(x) && x < 0 && continue
92107
test_scalar(loggamma, x)
93108
for a in test_points
94-
# ensure all complex if any complex for FiniteDifferences
95-
_a, _x = promote(a, x)
96-
test_frule(gamma, _a, _x; rtol=1e-8)
97-
test_rrule(gamma, _a, _x; rtol=1e-8)
109+
test_frule(gamma, a, x; rtol=1e-8)
110+
test_rrule(gamma, a, x; rtol=1e-8)
98111

99-
test_frule(loggamma, _a, _x)
100-
test_rrule(loggamma, _a, _x)
112+
test_frule(loggamma, a, x)
113+
test_rrule(loggamma, a, x)
101114
end
102115

103116
isreal(x) || continue
@@ -117,14 +130,11 @@
117130
test_scalar(expintx, x)
118131

119132
for nu in (-1.5, 2.2, 4.0)
120-
# ensure all complex if any complex for FiniteDifferences
121-
_x, _nu = promote(x, nu)
133+
test_frule(expint, nu, x)
134+
test_rrule(expint, nu, x)
122135

123-
test_frule(expint, _nu, _x)
124-
test_rrule(expint, _nu, _x)
125-
126-
test_frule(expintx, _nu, _x)
127-
test_rrule(expintx, _nu, _x)
136+
test_frule(expintx, nu, x)
137+
test_rrule(expintx, nu, x)
128138
end
129139

130140
isreal(x) || continue
@@ -133,4 +143,23 @@
133143
test_scalar(cosint, x)
134144
end
135145
end
146+
147+
# https://github.com/JuliaMath/SpecialFunctions.jl/issues/307
148+
@testset "promotions" begin
149+
# one argument
150+
for f in (erf, erfc, logerfc, erfcinv, erfcx, logerfcx, erfi, erfinv, sinint)
151+
_, ẏ = frule((NoTangent(), 1f0), f, 1f0)
152+
@test ẏ isa Float32
153+
_, back = rrule(f, 1f0)
154+
_, x̄ = back(1f0)
155+
@test x̄ isa Float32
156+
end
157+
158+
# two arguments
159+
_, ẏ = frule((NoTangent(), 1f0, 1f0), erf, 1f0, 1f0)
160+
@test ẏ isa Float32
161+
_, back = rrule(erf, 1f0, 1f0)
162+
_, x̄ = back(1f0)
163+
@test x̄ isa Float32
164+
end
136165
end

0 commit comments

Comments
 (0)