Skip to content

Commit e7ad852

Browse files
authored
Merge pull request #121 from JuliaDiff/ox/specfunup3
Support SpecialFunction 0.8 + other improvements
2 parents 4136420 + 3189653 commit e7ad852

File tree

6 files changed

+98
-32
lines changed

6 files changed

+98
-32
lines changed

src/rulesets/Base/base.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
@scalar_rule(one(x), Zero())
22
@scalar_rule(zero(x), Zero())
3+
@scalar_rule(sign(x), Zero())
4+
35
@scalar_rule(abs2(x), Wirtinger(x', x))
46
@scalar_rule(log(x), inv(x))
57
@scalar_rule(log10(x), inv(x) / log(oftype(x, 10)))

src/rulesets/packages/SpecialFunctions.jl

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ using ChainRulesCore
33
using ..SpecialFunctions
44

55

6-
@scalar_rule(SpecialFunctions.lgamma(x), SpecialFunctions.digamma(x))
76
@scalar_rule(SpecialFunctions.erf(x), (2 / sqrt(π)) * exp(-x * x))
87
@scalar_rule(SpecialFunctions.erfc(x), -(2 / sqrt(π)) * exp(-x * x))
98
@scalar_rule(SpecialFunctions.erfi(x), (2 / sqrt(π)) * exp(x * x))
@@ -24,4 +23,19 @@ using ..SpecialFunctions
2423
@scalar_rule(SpecialFunctions.erfcx(x), (2 * x * Ω) - (2 / sqrt(π)))
2524
@scalar_rule(SpecialFunctions.dawson(x), 1 - (2 * x * Ω))
2625

26+
# Changes between SpecialFunctions 0.7 and 0.8
27+
if isdefined(SpecialFunctions, :lgamma)
28+
# actually is the absolute value of the logorithm of gamma
29+
@scalar_rule(SpecialFunctions.lgamma(x), SpecialFunctions.digamma(x))
30+
end
31+
32+
if isdefined(SpecialFunctions, :logabsgamma)
33+
# actually is the absolute value of the logorithm of gamma, paired with sign gamma
34+
@scalar_rule(SpecialFunctions.logabsgamma(x), SpecialFunctions.digamma(x), Zero())
35+
end
36+
37+
if isdefined(SpecialFunctions, :loggamma)
38+
@scalar_rule(SpecialFunctions.loggamma(x), SpecialFunctions.digamma(x))
39+
end
40+
2741
end #module

test/rulesets/Base/base.jl

Lines changed: 37 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,9 @@
5252
test_scalar(acotd, 1/x)
5353
end
5454
@testset "Multivariate" begin
55-
x, y = rand(2)
5655
@testset "atan2" begin
5756
# https://en.wikipedia.org/wiki/Atan2
57+
x, y = rand(2)
5858
ratan = atan(x, y)
5959
u = x^2 + y^2
6060
datan = y/u - 2x/u
@@ -71,19 +71,11 @@
7171
end
7272

7373
@testset "sincos" begin
74-
rsincos = sincos(x)
75-
dsincos = cos(x) - 2sin(x)
76-
77-
r, pushforward = frule(sincos, x)
78-
@test r === rsincos
79-
df1, df2 = pushforward(NamedTuple(), 1)
80-
@test df1 + 2df2 === dsincos
81-
82-
r, pullback = rrule(sincos, x)
83-
@test r === rsincos
84-
ds, df = pullback(1, 2)
85-
@test df === dsincos
86-
@test ds === NO_FIELDS
74+
x, Δx, x̄ = randn(3)
75+
Δz = (randn(), randn())
76+
77+
frule_test(sincos, (x, Δx))
78+
rrule_test(sincos, Δz, (x, x̄))
8779
end
8880
end
8981
end # Trig
@@ -114,17 +106,16 @@
114106
end
115107

116108
@testset "Unary complex functions" begin
117-
for x in (-6, rand.((Float32, Float64, Complex{Float32}, Complex{Float64}))...)
118-
rtol = x isa Complex{Float32} ? 1e-6 : 1e-9
119-
test_scalar(real, x; rtol=rtol)
120-
test_scalar(imag, x; rtol=rtol)
109+
for x in (-4.1, 6.4, 1.0+0.5im, -10.0+1.5im)
110+
test_scalar(real, x)
111+
test_scalar(imag, x)
121112

122-
test_scalar(abs, x; rtol=rtol)
123-
test_scalar(hypot, x; rtol=rtol)
113+
test_scalar(abs, x)
114+
test_scalar(hypot, x)
124115

125-
test_scalar(angle, x; rtol=rtol)
126-
test_scalar(abs2, x; rtol=rtol)
127-
test_scalar(conj, x; rtol=rtol)
116+
test_scalar(angle, x)
117+
test_scalar(abs2, x)
118+
test_scalar(conj, x)
128119
end
129120
end
130121

@@ -146,14 +137,14 @@
146137
test_accumulation(rand(2, 5), dy)
147138
end
148139

149-
@testset "hypot(x, y)" begin
140+
@testset "binary trig ($f)" for f in (hypot, atan)
150141
rng = MersenneTwister(123456)
151-
x, Δx, x̄ = randn(rng, 3)
142+
x, Δx, x̄ = 10randn(rng, 3)
152143
y, Δy, ȳ = randn(rng, 3)
153144
Δz = randn(rng)
154145

155-
frule_test(hypot, (x, Δx), (y, Δy))
156-
rrule_test(hypot, Δz, (x, x̄), (y, ȳ))
146+
frule_test(f, (x, Δx), (y, Δy))
147+
rrule_test(f, Δz, (x, x̄), (y, ȳ))
157148
end
158149

159150
@testset "identity" begin
@@ -166,4 +157,23 @@
166157
test_scalar(one, x)
167158
test_scalar(zero, x)
168159
end
160+
161+
@testset "sign" begin
162+
@testset "at points" for x in (-1.1, -1.1, 0.5, 100)
163+
test_scalar(sign, x)
164+
end
165+
166+
@testset "Zero over the point discontinuity" begin
167+
# Can't do finite differencing because we are lying
168+
# following the subgradient convention.
169+
170+
_, pb = rrule(sign, 0.0)
171+
_, x̄ = pb(10.5)
172+
@test extern(x̄) == 0
173+
174+
_, pf = frule(sign, 0.0)
175+
= pf(NamedTuple(), 10.5)
176+
@test extern(ẏ) == 0
177+
end
178+
end
169179
end

test/rulesets/packages/SpecialFunctions.jl

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,29 @@ using SpecialFunctions
3131
test_scalar(SpecialFunctions.gamma, x)
3232
test_scalar(SpecialFunctions.digamma, x)
3333
test_scalar(SpecialFunctions.trigamma, x)
34-
test_scalar(SpecialFunctions.lgamma, x)
34+
end
35+
end
36+
37+
# SpecialFunctions 0.7->0.8 changes:
38+
@testset "log gamma and co" begin
39+
#It is important that we have negative numbers with both odd and even integer parts
40+
for x in (1.5, 2.5, 10.5, -0.6, -2.6, -3.3, 1.6+1.6im, 1.6-1.6im, -4.6+1.6im)
41+
if isdefined(SpecialFunctions, :lgamma)
42+
test_scalar(SpecialFunctions.lgamma, x)
43+
end
44+
if isdefined(SpecialFunctions, :loggamma)
45+
isreal(x) && x < 0 && continue
46+
test_scalar(SpecialFunctions.loggamma, x)
47+
end
48+
49+
if isdefined(SpecialFunctions, :logabsgamma)
50+
isreal(x) || continue
51+
52+
Δx, x̄ = randn(2)
53+
Δz = (randn(), randn())
54+
55+
frule_test(SpecialFunctions.logabsgamma, (x, Δx))
56+
rrule_test(SpecialFunctions.logabsgamma, Δz, (x, x̄))
57+
end
3558
end
3659
end

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,15 @@ using ChainRulesCore: extern, accumulate, accumulate!, store!, @scalar_rule,
1414
Wirtinger, wirtinger_primal, wirtinger_conjugate,
1515
Zero, One, DNE, Thunk, AbstractDifferential
1616

17+
Random.seed!(1) # Set seed that all testsets should reset to.
18+
1719
include("test_util.jl")
1820

1921
println("Testing ChainRules.jl")
2022
@testset "ChainRules" begin
2123
include("helper_functions.jl")
2224
@testset "rulesets" begin
25+
2326
@testset "Base" begin
2427
include(joinpath("rulesets", "Base", "base.jl"))
2528
include(joinpath("rulesets", "Base", "array.jl"))

test/test_util.jl

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,13 @@ function frule_test(f, xẋs::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm=_fdm
9898

9999
# Correctness testing via finite differencing.
100100
dΩ_fd = jvp(fdm, xs->f(xs...), (xs, ẋs))
101-
@test isapprox(dΩ_ad, dΩ_fd; rtol=rtol, atol=atol, kwargs...)
101+
@test isapprox(
102+
collect(dΩ_ad), # Use collect so can use vector equality
103+
collect(dΩ_fd);
104+
rtol=rtol,
105+
atol=atol,
106+
kwargs...
107+
)
102108
end
103109

104110

@@ -108,6 +114,7 @@ end
108114
# Arguments
109115
- `f`: Function to which rule should be applied.
110116
- `ȳ`: adjoint w.r.t. output of `f` (should generally be set randomly).
117+
Should be same structure as `f(x)` (so if multiple returns should be a tuple)
111118
- `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain).
112119
- `x̄`: currently accumulated adjoint (should generally be set randomly).
113120
@@ -118,8 +125,15 @@ function rrule_test(f, ȳ, (x, x̄)::Tuple{Any, Any}; rtol=1e-9, atol=1e-9, fdm
118125

119126
# Check correctness of evaluation.
120127
fx, pullback = ChainRules.rrule(f, x)
121-
@test fx f(x)
122-
(∂self, x̄_ad) = pullback(ȳ)
128+
@test collect(fx) collect(f(x)) # use collect so can do vector equality
129+
(∂self, x̄_ad) = if fx isa Tuple
130+
# If the function returned multiple values,
131+
# then it must have multiple seeds for propagating backwards
132+
pullback(ȳ...)
133+
else
134+
pullback(ȳ)
135+
end
136+
123137
@test ∂self === NO_FIELDS # No internal fields
124138
# Correctness testing via finite differencing.
125139
x̄_fd = j′vp(fdm, f, ȳ, x)

0 commit comments

Comments
 (0)