Skip to content

Commit fe14fa2

Browse files
committed
add tests from ForwardDiff
1 parent 9928324 commit fe14fa2

File tree

3 files changed

+222
-0
lines changed

3 files changed

+222
-0
lines changed

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
33
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
44
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
5+
DiffTests = "de460e47-3fe3-5279-bb4a-814414816d5d"
56
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
67
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
78
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
@@ -19,6 +20,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1920
ChainRules = "1.5"
2021
ChainRulesCore = "1.2"
2122
Combinatorics = "1"
23+
DiffTests = "0.1.1"
2224
StaticArrays = "1"
2325
StatsBase = "0.33"
2426
StructArrays = "0.6"

test/forwarddiff.jl

Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
2+
# This file contains tests adapted from FowardDiff.jl, many of them using DiffTests.jl
3+
# Organised here file-by-file in alphabetical order!
4+
5+
#####
6+
##### setup
7+
#####
8+
9+
using Test, LinearAlgebra
10+
using ForwardDiff, DiffTests
11+
using Diffractor, ChainRulesCore
12+
13+
fwd_derivative(f, x::Number) = Diffractor.PrimeDerivativeFwd(f)(x) |> unthunk
14+
function rev_derivative(f, x::Number)
15+
y = f(x)
16+
if y isa Number
17+
Diffractor.PrimeDerivativeBack(f)(x)
18+
elseif y isa AbstractArray
19+
map(CartesianIndices(y)) do I
20+
Diffractor.PrimeDerivativeBack(x -> f(x)[I])(x) |> unthunk
21+
end
22+
else
23+
throw("can't handle f(x)::$(typeof(y))")
24+
end
25+
end
26+
27+
@test ForwardDiff.derivative(abs2, 3) == 6
28+
@test fwd_derivative(abs2, 3) == 6
29+
@test rev_derivative(abs2, 3) == 6
30+
31+
@test ForwardDiff.derivative(x -> fill(x,2,3), 7) == [1 1 1; 1 1 1]
32+
@test fwd_derivative(x -> fill(x,2,3), 7) == [1 1 1; 1 1 1]
33+
@test rev_derivative(x -> fill(x,2,3), 7) == [1 1 1; 1 1 1]
34+
35+
DERIVATIVES = (ForwardDiff.derivative, fwd_derivative, rev_derivative)
36+
37+
function fwd_gradient(f, x::AbstractVector)
38+
map(eachindex(x)) do i
39+
fwd_derivative-> f(vcat(x[begin:i-1], ξ, x[i+1:end])), x[i])
40+
end
41+
end
42+
fwd_gradient(f, x::AbstractArray) = reshape(fwd_gradient(v -> f(reshape(v, size(x))), vec(x)), size(x))
43+
rev_gradient(f, x::AbstractArray) = ChainRulesCore.unthunk(Diffractor.PrimeDerivativeBack(f)(float(x)))
44+
45+
@test ForwardDiff.gradient(prod, [1,2,3]) == [6,3,2]
46+
@test_broken fwd_gradient(prod, [1,2,3]) == [6,3,2] # ERROR: MethodError: no method matching arrayset(::Bool, ::Vector{Int64}, ::Int64, ::Int64)
47+
@test rev_gradient(prod, [1,2,3]) == [6,3,2]
48+
49+
@test fwd_gradient(sum, [1,2]) == [1,1]
50+
@test_broken fwd_gradient(first, [1,1]) == [1,0]
51+
52+
GRADIENTS = (ForwardDiff.gradient, rev_gradient)
53+
54+
fwd_jacobian(f, x::AbstractArray) = hcat(vec.(fwd_gradient(f, x))...)
55+
function rev_jacobian(f, x::AbstractArray)
56+
y = f(x)
57+
slices = map(LinearIndices(y)) do i # fails if y isa Number, just like ForwardDiff.jacobian
58+
vec(rev_gradient(x -> f(x)[i], x))
59+
end
60+
vcat(transpose(slices)...)
61+
# permutedims(hcat(slices...))
62+
end
63+
64+
@test ForwardDiff.jacobian(x -> x[1:2], [1,2,3]) == [1 0 0; 0 1 0]
65+
@test_broken fwd_jacobian(x -> x[1:2], [1,2,3]) == [1 0 0; 0 1 0]
66+
@test rev_jacobian(x -> x[1:2], [1,2,3]) == [1 0 0; 0 1 0]
67+
68+
JACOBIANS = (ForwardDiff.jacobian, rev_jacobian)
69+
70+
fwd_hessian(f, x::AbstractArray) = fwd_jacobian(y -> fwd_gradient(f, y), x)
71+
rev_hessian(f, x::AbstractArray) = rev_jacobian(y -> rev_gradient(f, y), x)
72+
fwd_rev_hessian(f, x::AbstractArray) = fwd_jacobian(y -> rev_gradient(f, y), x)
73+
rev_fwd_hessian(f, x::AbstractArray) = rev_jacobian(y -> fwd_gradient(f, y), x)
74+
75+
@test ForwardDiff.hessian(x -> -log(x[1]), [2,3]) == [0.25 0; 0 0]
76+
@test rev_hessian(x -> -log(x[1]), [2,3]) == [0.25 0; 0 0]
77+
78+
HESSIANS = (ForwardDiff.hessian, rev_hessian)
79+
80+
81+
# const XLEN = 13
82+
# const YLEN = 7
83+
# const X, Y = rand(XLEN), rand(YLEN)
84+
# const CHUNK_SIZES = (1, div(DEFAULT_CHUNK_THRESHOLD, 3), div(DEFAULT_CHUNK_THRESHOLD, 2), DEFAULT_CHUNK_THRESHOLD, DEFAULT_CHUNK_THRESHOLD + 1)
85+
# const HESSIAN_CHUNK_SIZES = (1, 2, 3)
86+
# const FINITEDIFF_ERROR = 3e-5
87+
X, Y = rand(13), rand(7)
88+
FINITEDIFF_ERROR = 3e-5
89+
90+
91+
#####
92+
##### ConfusionTest
93+
#####
94+
95+
96+
#####
97+
##### DerivativeTest
98+
#####
99+
100+
@testset verbose=true "DerivativeTest" begin
101+
102+
x = 1
103+
104+
@testset "scalar derivative of DiffTests.$f" for f in DiffTests.NUMBER_TO_NUMBER_FUNCS
105+
v = f(x)
106+
d = ForwardDiff.derivative(f, x)
107+
# @test isapprox(d, Calculus.derivative(f, x), atol=FINITEDIFF_ERROR)
108+
109+
@test d fwd_derivative(f, x) broken=(f==DiffTests.num2num_4)
110+
@test d rev_derivative(f, x) broken=(f==DiffTests.num2num_4)
111+
end
112+
113+
@testset "array derivative of DiffTests.$f" for f in DiffTests.NUMBER_TO_ARRAY_FUNCS
114+
v = f(x)
115+
d = ForwardDiff.derivative(f, x)
116+
# @test isapprox(d, Calculus.derivative(f, x), atol=FINITEDIFF_ERROR)
117+
118+
@test d fwd_derivative(f, x)
119+
@test d rev_derivative(f, x)
120+
end
121+
122+
@testset "exponential function at base zero" for derivative in DERIVATIVES
123+
@test (x -> derivative(y -> x^y, -0.5))(0.0) === -Inf
124+
@test (x -> derivative(y -> x^y, 0.0))(0.0) === -Inf
125+
@test (x -> derivative(y -> x^y, 0.5))(0.0) === 0.0
126+
@test (x -> derivative(y -> x^y, 1.5))(0.0) === 0.0
127+
end
128+
129+
end
130+
131+
#####
132+
##### GradientTest
133+
#####
134+
135+
@testset verbose=true "GradientTest" begin
136+
137+
@testset "hardcoded rosenbrock gradient" begin
138+
f = DiffTests.rosenbrock_1
139+
x = [0.1, 0.2, 0.3]
140+
v = f(x)
141+
g = [-9.4, 15.6, 52.0]
142+
143+
@test g ForwardDiff.gradient(f, x)
144+
@test_broken g fwd_gradient(f, x)
145+
@test g rev_gradient(f, x)
146+
end
147+
148+
@testset "gradient of DiffTests.$f" for f in DiffTests.VECTOR_TO_NUMBER_FUNCS
149+
v = f(X)
150+
g = ForwardDiff.gradient(f, X)
151+
# @test isapprox(g, Calculus.gradient(f, X), atol=FINITEDIFF_ERROR)
152+
153+
@test g fwd_gradient(f, X)
154+
@test g rev_gradient(f, X)
155+
end
156+
157+
@testset "exponential function at base zero: $gradient" for gradient in GRADIENTS
158+
@test isequal(gradient(t -> t[1]^t[2], [0.0, -0.5]), [NaN, NaN])
159+
@test isequal(gradient(t -> t[1]^t[2], [0.0, 0.0]), [NaN, NaN])
160+
@test isequal(gradient(t -> t[1]^t[2], [0.0, 0.5]), [Inf, NaN])
161+
@test isequal(gradient(t -> t[1]^t[2], [0.0, 1.5]), [0.0, 0.0])
162+
end
163+
164+
@testset "chunk size zero - issue 399: $gradient" for gradient in GRADIENTS
165+
f_const(x) = 1.0
166+
g_grad_const = x -> gradient(f_const, x)
167+
@test g_grad_const([1.0]) == [0.0]
168+
@test isempty(g_grad_const(zeros(Float64, 0)))
169+
end
170+
171+
@testset "ArithmeticStyle: $gradient" for gradient in GRADIENTS
172+
function f(p)
173+
sum(collect(0.0:p[1]:p[2]))
174+
end
175+
@test gradient(f, [0.2,25.0]) == [7875.0, 0.0]
176+
end
177+
178+
end
179+
180+
#####
181+
##### HessianTest
182+
#####
183+
184+
@testset verbose=true "HessianTest" begin
185+
186+
@testset "hardcoded rosenbrock hessian" begin
187+
188+
f = DiffTests.rosenbrock_1
189+
x = [0.1, 0.2, 0.3]
190+
v = f(x)
191+
g = [-9.4, 15.6, 52.0]
192+
h = [-66.0 -40.0 0.0;
193+
-40.0 130.0 -80.0;
194+
0.0 -80.0 200.0]
195+
196+
@test isapprox(h, ForwardDiff.hessian(f, x))
197+
198+
@test_skip h fwd_hessian(f, x)
199+
@test_broken h rev_hessian(f, x) # Control flow support not fully implemented yet for higher-order reverse mode
200+
@test_skip h rev_fwd_hessian(f, x)
201+
@test_skip h fwd_rev_hessian(f, x)
202+
end
203+
204+
@testset "hessians for DiffTests.$f" for f in DiffTests.VECTOR_TO_NUMBER_FUNCS
205+
v = f(X)
206+
g = ForwardDiff.gradient(f, X)
207+
h = ForwardDiff.hessian(f, X)
208+
209+
@test_broken g rev_gradient(f, x)
210+
@test_broken h rev_hessian(f, x)
211+
end
212+
213+
end
214+
215+
#####
216+
##### JacobianTest
217+
#####

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
using Diffractor
22
using Test
33

4+
@testset verbose=true "from ForwardDiff.jl" begin
5+
include("forwarddiff.jl")
6+
end
47
@testset verbose=true "ChainRules integration.jl" begin
58
include("chainrules.jl")
69
end

0 commit comments

Comments
 (0)