Skip to content

Commit ca4c9c4

Browse files
committed
Support SpecialFunctions 0.8
also improves tests. Also adds gradient for sign.
1 parent 115adfa commit ca4c9c4

File tree

4 files changed

+49
-39
lines changed

4 files changed

+49
-39
lines changed

test/rulesets/Base/base.jl

Lines changed: 19 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,9 @@
5252
test_scalar(acotd, 1/x)
5353
end
5454
@testset "Multivariate" begin
55-
x, y = rand(2)
5655
@testset "atan2" begin
5756
# https://en.wikipedia.org/wiki/Atan2
57+
x, y = rand(2)
5858
ratan = atan(x, y)
5959
u = x^2 + y^2
6060
datan = y/u - 2x/u
@@ -71,19 +71,11 @@
7171
end
7272

7373
@testset "sincos" begin
74-
rsincos = sincos(x)
75-
dsincos = cos(x) - 2sin(x)
76-
77-
r, pushforward = frule(sincos, x)
78-
@test r === rsincos
79-
df1, df2 = pushforward(NamedTuple(), 1)
80-
@test df1 + 2df2 === dsincos
81-
82-
r, pullback = rrule(sincos, x)
83-
@test r === rsincos
84-
ds, df = pullback(1, 2)
85-
@test df === dsincos
86-
@test ds === NO_FIELDS
74+
x, Δx, x̄ = randn(3)
75+
Δz = (randn(), randn())
76+
77+
frule_test(sincos, (x, Δx))
78+
rrule_test(sincos, Δz, (x, x̄))
8779
end
8880
end
8981
end # Trig
@@ -114,17 +106,16 @@
114106
end
115107

116108
@testset "Unary complex functions" begin
117-
for x in (-6, rand.((Float32, Float64, Complex{Float32}, Complex{Float64}))...)
118-
rtol = x isa Complex{Float32} ? 1e-6 : 1e-9
119-
test_scalar(real, x; rtol=rtol)
120-
test_scalar(imag, x; rtol=rtol)
109+
for x in (-4.1, 6.4, 1.0+0.5im, -10.0+1.5im)
110+
test_scalar(real, x)
111+
test_scalar(imag, x)
121112

122-
test_scalar(abs, x; rtol=rtol)
123-
test_scalar(hypot, x; rtol=rtol)
113+
test_scalar(abs, x)
114+
test_scalar(hypot, x)
124115

125-
test_scalar(angle, x; rtol=rtol)
126-
test_scalar(abs2, x; rtol=rtol)
127-
test_scalar(conj, x; rtol=rtol)
116+
test_scalar(angle, x)
117+
test_scalar(abs2, x)
118+
test_scalar(conj, x)
128119
end
129120
end
130121

@@ -146,14 +137,14 @@
146137
test_accumulation(rand(2, 5), dy)
147138
end
148139

149-
@testset "hypot(x, y)" begin
140+
@testset "binary trig ($f)" for f in (hypot, atan)
150141
rng = MersenneTwister(123456)
151-
x, Δx, x̄ = randn(rng, 3)
142+
x, Δx, x̄ = 10randn(rng, 3)
152143
y, Δy, ȳ = randn(rng, 3)
153144
Δz = randn(rng)
154145

155-
frule_test(hypot, (x, Δx), (y, Δy))
156-
rrule_test(hypot, Δz, (x, x̄), (y, ȳ))
146+
frule_test(f, (x, Δx), (y, Δy))
147+
rrule_test(f, Δz, (x, x̄), (y, ȳ))
157148
end
158149

159150
@testset "identity" begin
@@ -175,7 +166,7 @@
175166
@testset "Zero over the point discontinuity" begin
176167
# Can't do finite differencing because we are lying
177168
# following the subgradient convention.
178-
169+
179170
_, pb = rrule(sign, 0.0)
180171
_, x̄ = pb(10.5)
181172
@test extern(x̄) == 0

test/rulesets/packages/SpecialFunctions.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using SpecialFunctions
2-
#=
2+
33
@testset "SpecialFunctions" for x in (1, -1, 0, 0.5, 10, -17.1, 1.5 + 0.7im)
44
test_scalar(SpecialFunctions.erf, x)
55
test_scalar(SpecialFunctions.erfc, x)
@@ -33,7 +33,7 @@ using SpecialFunctions
3333
test_scalar(SpecialFunctions.trigamma, x)
3434
end
3535
end
36-
=#
36+
3737
@testset "log gamma and co" begin
3838
# SpecialFunctions 0.7->0.8 changes:
3939
for x in (1.5, 2.5, 10.5, -0.6, -2.6, 1.6+1.6im, 1.6-1.6im, -4.6+1.6im)
@@ -47,7 +47,12 @@ end
4747

4848
if isdefined(SpecialFunctions, :logabsgamma)
4949
isreal(x) || continue
50-
50+
51+
x, Δx, x̄ = randn(3)
52+
Δz = (randn(), randn())
53+
54+
frule_test(SpecialFunctions.logabsgamma, (x, Δx))
55+
rrule_test(SpecialFunctions.logabsgamma, Δz, (x, x̄))
5156
end
5257
end
5358

test/runtests.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,14 @@ println("Testing ChainRules.jl")
2020
@testset "ChainRules" begin
2121
include("helper_functions.jl")
2222
@testset "rulesets" begin
23+
2324
@testset "Base" begin
2425
include(joinpath("rulesets", "Base", "base.jl"))
25-
#include(joinpath("rulesets", "Base", "array.jl"))
26-
#include(joinpath("rulesets", "Base", "mapreduce.jl"))
27-
#include(joinpath("rulesets", "Base", "broadcast.jl"))
26+
include(joinpath("rulesets", "Base", "array.jl"))
27+
include(joinpath("rulesets", "Base", "mapreduce.jl"))
28+
include(joinpath("rulesets", "Base", "broadcast.jl"))
2829
end
29-
#==
30+
3031
print(" ")
3132

3233
@testset "Statistics" begin
@@ -41,7 +42,6 @@ println("Testing ChainRules.jl")
4142
include(joinpath("rulesets", "LinearAlgebra", "factorization.jl"))
4243
include(joinpath("rulesets", "LinearAlgebra", "blas.jl"))
4344
end
44-
==#
4545

4646
print(" ")
4747

test/test_util.jl

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,13 @@ function frule_test(f, xẋs::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm=_fdm
9898

9999
# Correctness testing via finite differencing.
100100
dΩ_fd = jvp(fdm, xs->f(xs...), (xs, ẋs))
101-
@test isapprox(dΩ_ad, dΩ_fd; rtol=rtol, atol=atol, kwargs...)
101+
@test isapprox(
102+
collect(dΩ_ad), # Use collect so can use vector equality
103+
collect(dΩ_fd);
104+
rtol=rtol,
105+
atol=atol,
106+
kwargs...
107+
)
102108
end
103109

104110

@@ -108,6 +114,7 @@ end
108114
# Arguments
109115
- `f`: Function to which rule should be applied.
110116
- `ȳ`: adjoint w.r.t. output of `f` (should generally be set randomly).
117+
Should be same structure as `f(x)` (so if multiple returns should be a tuple)
111118
- `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain).
112119
- `x̄`: currently accumulated adjoint (should generally be set randomly).
113120
@@ -118,8 +125,15 @@ function rrule_test(f, ȳ, (x, x̄)::Tuple{Any, Any}; rtol=1e-9, atol=1e-9, fdm
118125

119126
# Check correctness of evaluation.
120127
fx, pullback = ChainRules.rrule(f, x)
121-
@test fx f(x)
122-
(∂self, x̄_ad) = pullback(ȳ)
128+
@test collect(fx) collect(f(x)) # use collect so can do vector equality
129+
(∂self, x̄_ad) = if fx isa Tuple
130+
# If the function returned multiple values,
131+
# then it must have multiple seeds for propagating backwards
132+
pullback(ȳ...)
133+
else
134+
pullback(ȳ)
135+
end
136+
123137
@test ∂self === NO_FIELDS # No internal fields
124138
# Correctness testing via finite differencing.
125139
x̄_fd = j′vp(fdm, f, ȳ, x)

0 commit comments

Comments
 (0)