Skip to content

Commit fc6b7fc

Browse files
committed
start adding gradcheck tests from zygote
1 parent 6bff01d commit fc6b7fc

File tree

2 files changed

+166
-0
lines changed

2 files changed

+166
-0
lines changed

test/gradcheck.jl

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
# This file contains a selection of tests from Zygote's "gradcheck.jl",
2+
# dealing with Base and standard library functions. Many of these use rules
3+
# which have their own more exhaustive tests in ChainRules.
4+
5+
# Tests for packages (Distances, LogExpFunctions, AbstractFFTs, FillArrays) are not included.
6+
7+
# Ideally this would be extended to take `gradient` both forward and reverse,
8+
# and `jacobicheck` including 2nd derivatives, for every testset. But not yet.
9+
10+
using Test
11+
using ChainRulesCore
12+
using Diffractor
13+
using FiniteDifferences
14+
using LinearAlgebra
15+
16+
#####
17+
##### Zygote/test/gradcheck.jl : setup
18+
#####
19+
20+
# Finite-difference gradients, used as ground truth by `gradcheck` below.
# Scalar input: a 1-tuple holding the 5-point central-difference derivative.
n_grad(f, x::Real) = (central_fdm(5,1)(f,x),)
# Array input: FiniteDifferences' gradient of `f`, on a floating-point copy of `x`.
n_grad(f, x::AbstractArray{<:Real}) = FiniteDifferences.grad(central_fdm(5,1), f, float(x))
# Multiple inputs: differentiate w.r.t. each argument in turn, holding the others fixed.
n_grad(f, xs::Vararg{Any,N}) where {N} = ntuple(N) do i
    # Close over all arguments except the i-th, which becomes the variable `x`.
    n_grad(x -> f(ntuple(j -> j==i ? x : xs[j], N)...), xs[i])[1]
end
25+
26+
# check gradients via finite differencing
# Returns true when the AD gradient of `f` at `xs` agrees with the
# finite-difference gradient from `n_grad`, elementwise, to rtol = atol = 1e-5.
function gradcheck(f, xs::AbstractArray...)
    # `unthunk` forces any lazy (Thunk) tangents returned by the rules.
    gs = unthunk.(gradient(f, xs...))
    all(isapprox.(gs, n_grad(f, xs...), rtol=1e-5, atol=1e-5))
end
# Convenience form: `dims` are sizes; random Float64 inputs are generated.
gradcheck(f, dims...) = gradcheck(f, rand.(Float64, dims)...)
# Smoke tests for the checker itself:
# @test gradcheck(sqrt, 3.14) # given number
@test gradcheck(sum, randn(10)) # given array
@test gradcheck(dot, randn(3), rand(3)) # given multiple vectors
@test gradcheck(dot, 3, 3) # given multiple random vectors
36+
37+
# Extends `gradcheck` to array-valued `f`: the output is reduced to a scalar
# with `sum(sin, ...)`, which mixes every output entry into the result.
# Scalar-valued `f` falls through to plain `gradcheck`.
jacobicheck(f, xs::AbstractArray...) = f(xs...) isa Number ? gradcheck(f, xs...) :
    gradcheck((xs...) -> sum(sin, f(xs...)), xs...)
# Convenience form: `dims` are sizes; random (standard-normal) inputs are generated.
jacobicheck(f, dims...) = jacobicheck(f, randn.(Float64, dims)...)
# Smoke tests for the checker itself:
@test jacobicheck(identity, [1,2,3]) # one given array
@test jacobicheck(sum, [1,2,3]) # fallback to gradcheck
@test jacobicheck(identity, (4,5)) # one random matrix
@test jacobicheck(+, 3, 3) # two random vectors
44+
45+
46+
#####
47+
##### Zygote/test/gradcheck.jl : Base
48+
#####
49+
50+
# Gradients of integer/real/complex powers, including `literal_pow` lowering
# (x^2 with a literal exponent) and non-literal exponents via a captured variable.
@testset "power" begin
    @test gradient(x -> x^2, -2) == (-4,) # literal_pow
    @test gradient(x -> x^10, -1.0) == (-10,)
    _pow = 10
    @test gradient(x -> x^_pow, -1.0) == (-_pow,)
    # FIX: the `≈` operator had been dropped between the gradient and its
    # expected value (the line did not parse); restored as in upstream Zygote.
    @test gradient(p -> real(2^p), 2)[1] ≈ 4*log(2)

    @test gradient(xs ->sum(xs .^ 2), [2, -1]) == ([4, -2],)
    @test gradient(xs ->sum(xs .^ 10), [3, -1]) == ([10*3^9, -10],)
    @test gradient(xs ->sum(xs .^ _pow), [4, -1]) == ([_pow*4^9, -10],)

    @test gradient(x -> real((1+3im) * x^2), 5+7im) == (-32 - 44im,)
    # FIX: likewise, restore the missing `≈` here.
    @test gradient(p -> real((1+3im) * (5+7im)^p), 2)[1] ≈ real((-234 + 2im)*log(5 - 7im))
    # D[(1+3I)x^p, p] /. {x->5+7I, p->2} // Conjugate
end
65+
66+
# Jacobians of affine maps W*x .+ b (with and without a broadcast nonlinearity)
# and of products involving lazy transpose/adjoint wrappers.
@testset "jacobian" begin
    # Vector input x (length 5) and matrix input x (5×3), W is 2×5, b has length 2.
    @test jacobicheck((x, W, b) -> identity.(W*x .+ b), 5, (2,5), 2)
    @test jacobicheck((x, W, b) -> identity.(W*x .+ b), (5,3), (2,5), 2)

    @test jacobicheck((x, W, b) -> tanh.(W*x .+ b), 5, (2,5), 2)
    @test jacobicheck((x, W, b) -> tanh.(W*x .+ b), (5,3), (2,5), 2)

    # Lazy wrappers: `'` (adjoint), `Adjoint`, `transpose`, `Transpose`.
    @test jacobicheck((w, x) -> w'*x, randn(10, 2), randn(10))
    @test jacobicheck((w, x) -> Adjoint(w)*x, randn(10, 2), randn(10))
    @test jacobicheck((w, x) -> transpose(w)*x, randn(5,5), randn(5,5))
    @test jacobicheck((w, x) -> Transpose(w)*x, randn(5,5), randn(5,5))

    # `parent` unwraps the lazy wrapper; the tangent comes back as a structural
    # Tangent{Adjoint,...}, which `isapprox` cannot compare against an array.
    # FIXME: fail with:
    # MethodError: no method matching isapprox(::Tangent{Adjoint{Float64, Matrix{Float64}}, @NamedTuple{parent::Matrix{Float64}}}, ::Adjoint{Float64, Matrix{Float64}}; rtol::Float64, atol::Float64)
    @test_broken jacobicheck((w, x) -> parent(w)*x, randn(5,5)', randn(5,5))
    @test_broken jacobicheck((w, x) -> parent(w)*x, transpose(randn(5,5)), randn(5,5))
end
85+
86+
# Gradients of `sum` and `prod` in their many forms: with a mapped function,
# with generators/closures, with `dims`, over booleans, and over tuples.
@testset "sum, prod" begin
    @test gradcheck(x -> sum(abs2, x), randn(4, 3, 2))
    # Generator and closure forms of the same reduction:
    @test gradcheck(x -> sum(x[i] for i in 1:length(x)), randn(10))
    @test gradcheck(x -> sum(i->x[i], 1:length(x)), randn(10)) # issue #231
    @test gradcheck(x -> sum((i->x[i]).(1:length(x))), randn(10))
    @test gradcheck(X -> sum(x -> x^2, X), randn(10))

    # `dims`-keyword reductions are not yet differentiable here.
    # FIXME: fail with
    # MethodError: no method matching copy(::Nothing)
    @test_broken jacobicheck(x -> sum(x, dims = (2, 3)), (3,4,5))
    @test_broken jacobicheck(x -> sum(abs2, x; dims=1), randn(4, 3, 2))
    @test_broken gradcheck(X -> sum(sum(x -> x^2, X; dims=1)), randn(10)) # issue #681

    # Non-differentiable sum of booleans
    @test gradient(sum, [true, false, true]) == (NoTangent(),)
    @test gradient(x->sum(x .== 0.0), [1.2, 0.2, 0.0, -1.1, 100.0]) == (NoTangent(),)

    # Gradients flow to both the closed-over `x` and the reduced collection `y`.
    # https://github.com/FluxML/Zygote.jl/issues/314
    @test gradient((x,y) -> sum(yi -> yi*x, y), 1, [1,1]) == (2, [1, 1])
    @test gradient((x,y) -> prod(yi -> yi*x, y), 1, [1,1]) == (2, [1, 1])

    # Same reductions but through `map`, which hits a closure restriction.
    # FIXME: fail with
    # AssertionError: Base.issingletontype(typeof(f))
    @test_broken gradient((x,y) -> sum(map(yi -> yi*x, y)), 1, [1,1]) == (2, [1, 1])
    @test_broken gradient((x,y) -> prod(map(yi -> yi*x, y)), 1, [1,1]) == (2, [1, 1])

    @test gradcheck(x -> prod(x), (3,4))
    # `prod` over a tuple: gradient of x1*x2*x3 is (x2*x3, x1*x3, x1*x2).
    @test gradient(x -> prod(x), (1,2,3))[1] == (6,3,2)

    # FIXME: fail with
    # MethodError: no method matching copy(::Nothing)
    @test_broken jacobicheck(x -> prod(x, dims = (2, 3)), (3,4,5))
end
119+
120+
# Gradients of `cumsum`; only the no-`dims` vector form currently works.
@testset "cumsum" begin
    @test jacobicheck(x -> cumsum(x), (4,))

    # FIXME: fail with
    # TypeError: in typeassert, expected Int64, got a value of type Nothing
    @test_broken jacobicheck(x -> cumsum(x, dims=2), (3,4,5))
    @test_broken jacobicheck(x -> cumsum(x, dims=3), (3,4)) # trivial

    # FIXME: fail with
    # MethodError: no method matching copy(::Nothing)
    @test_broken jacobicheck(x -> cumsum(x, dims=1), (3,))

    # `dims` beyond ndims(x) makes cumsum the identity, yet still fails.
    # FIXME: fail with
    # Rewrite reached intrinsic function bitcast. Missing rule?
    @test_broken jacobicheck(x -> cumsum(x, dims=3), (5,)) # trivial
end
136+
137+
138+
# FIXME: complex numbers; put somewhere
# acosh at a point on the real axis left of -1 (on the branch cut); checks
# gradients through both the real and imaginary parts via `reim`.
@test gradcheck((a,b)->sum(reim(acosh(complex(a[1], b[1])))), [-2.0], [1.0])
140+
141+
# FIXME: include those?
142+
# @testset "println, show, string, etc" begin
143+
# function foo(x)
144+
# Base.show(x)
145+
# Base.print(x)
146+
# Base.print(stdout, x)
147+
# Base.println(x)
148+
# Base.println(stdout, x)
149+
# Core.show(x)
150+
# Core.print(x)
151+
# Core.println(x)
152+
# return x
153+
# end
154+
# gradcheck(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
155+
# println("The following printout is from testing that `print` doesn't upset gradients:")
156+
# @test gradcheck(foo, [5.0])
157+
#
158+
# function bar(x)
159+
# string(x)
160+
# repr(x)
161+
# return x
162+
# end
163+
# @test gradcheck(bar, [5.0])
164+
# end
165+

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ const bwd = Diffractor.PrimeDerivativeBack
2121
"reverse.jl",
2222
"regression.jl",
2323
"AbstractDifferentiationTests.jl",
24+
"gradcheck.jl"
2425
#"pinn.jl", # Higher order control flow not yet supported (https://github.com/JuliaDiff/Diffractor.jl/issues/24)
2526
)
2627
include(file)

0 commit comments

Comments
 (0)