Skip to content

Commit 36f3ce7

Browse files
authored
RFC: add GPU testing (#646)
* GPU testing
* better macro, use registered JLArrays
* tidy
* more tests
* change macro to record source location
* fix broken, macro now points to test/rulesets/Base/arraymath.jl:67
* comments
* mark a test broken on 1.6
* rm CUDA dep
1 parent f9cf6f4 commit 36f3ce7

File tree

8 files changed

+214
-92
lines changed

8 files changed

+214
-92
lines changed

Project.toml

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "1.39.2"
3+
version = "1.40.0"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
77
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
88
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
9+
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
910
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
1011
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1112
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -18,19 +19,23 @@ ChainRulesCore = "1.15.3"
1819
ChainRulesTestUtils = "1.5"
1920
Compat = "3.42.0, 4"
2021
FiniteDifferences = "0.12.20"
22+
GPUArraysCore = "0.1.0"
2123
IrrationalConstants = "0.1.1"
24+
JLArrays = "0.1"
2225
JuliaInterpreter = "0.8,0.9"
2326
RealDot = "0.1"
2427
StaticArrays = "1.2"
2528
julia = "1.6"
2629

2730
[extras]
31+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
2832
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
2933
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
34+
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
3035
JuliaInterpreter = "aa1ae85d-cabe-5617-a682-6adf51b2e16a"
3136
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
3237
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
3338
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3439

3540
[targets]
36-
test = ["ChainRulesTestUtils", "FiniteDifferences", "JuliaInterpreter", "Random", "StaticArrays", "Test"]
41+
test = ["Adapt", "ChainRulesTestUtils", "FiniteDifferences", "JLArrays", "JuliaInterpreter", "Random", "StaticArrays", "Test"]

test/rulesets/Base/array.jl

Lines changed: 27 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -68,12 +68,12 @@ end
6868

6969
@testset "reshape" begin
7070
# Forward
71-
test_frule(reshape, rand(4, 3), 2, :)
71+
@gpu test_frule(reshape, rand(4, 3), 2, :)
7272
test_frule(reshape, rand(4, 3), axes(rand(6, 2)))
7373
@test_skip test_frule(reshape, Diagonal(rand(4)), 2, :) # https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/239
7474

7575
# Reverse
76-
test_rrule(reshape, rand(4, 5), (2, 10))
76+
@gpu test_rrule(reshape, rand(4, 5), (2, 10))
7777
test_rrule(reshape, rand(4, 5), 2, 10)
7878
test_rrule(reshape, rand(4, 5), 2, :)
7979
test_rrule(reshape, rand(4, 5), axes(rand(10, 2)))
@@ -98,14 +98,14 @@ end
9898

9999
@testset "permutedims + PermutedDimsArray" begin
100100
# Forward
101-
test_frule(permutedims, rand(5))
102-
test_frule(permutedims, rand(3, 4), (2, 1))
101+
@gpu test_frule(permutedims, rand(5))
102+
@gpu test_frule(permutedims, rand(3, 4), (2, 1))
103103
test_frule(permutedims!, rand(4,3), rand(3, 4), (2, 1))
104104
test_frule(PermutedDimsArray, rand(3, 4, 5), (3, 1, 2))
105105

106106
# Reverse
107-
test_rrule(permutedims, rand(5))
108-
test_rrule(permutedims, rand(3, 4), (2, 1))
107+
@gpu test_rrule(permutedims, rand(5))
108+
@gpu test_rrule(permutedims, rand(3, 4), (2, 1))
109109
test_rrule(permutedims, Diagonal(rand(5)), (2, 1))
110110
# Note BTW that permutedims(Diagonal(rand(5))) does not use the rule at all
111111

@@ -127,12 +127,12 @@ end
127127
test_rrule(repeat, rand(4, ))
128128
test_rrule(repeat, rand(4, 5))
129129
test_rrule(repeat, rand(4, 5); fkwargs = (outer=(1,2),))
130-
test_rrule(repeat, rand(4, 5); fkwargs = (inner=(1,2), outer=(1,3)))
131-
test_rrule(repeat, rand(4, 5); fkwargs = (outer=2,))
130+
@gpu_broken test_rrule(repeat, rand(4, 5); fkwargs = (inner=(1,2), outer=(1,3)))
131+
@gpu_broken test_rrule(repeat, rand(4, 5); fkwargs = (outer=2,))
132132

133-
test_rrule(repeat, rand(4, ), 2)
134-
test_rrule(repeat, rand(4, 5), 2)
135-
test_rrule(repeat, rand(4, 5), 2, 3)
133+
@gpu test_rrule(repeat, rand(4, ), 2)
134+
@gpu test_rrule(repeat, rand(4, 5), 2)
135+
@gpu test_rrule(repeat, rand(4, 5), 2, 3)
136136
test_rrule(repeat, rand(1,2,3), 2,3,4; check_inferred=VERSION>v"1.6")
137137
test_rrule(repeat, rand(0,2,3), 2,0,4; check_inferred=VERSION>v"1.6")
138138
test_rrule(repeat, rand(1,1,1,1), 2,3,4,5; check_inferred=VERSION>v"1.6")
@@ -153,16 +153,16 @@ end
153153

154154
@test rrule(repeat, [1,2,3], 4)[2](ones(12))[2] == [4,4,4]
155155
@test rrule(repeat, [1,2,3], outer=4)[2](ones(12))[2] == [4,4,4]
156-
157156
end
158157

159158
@testset "hcat" begin
160159
# forward
161-
test_frule(hcat, randn(3, 2), randn(3))
162-
test_frule(hcat, randn(), randn(1,3))
160+
@gpu test_frule(hcat, randn(3, 2), randn(3))
161+
@gpu test_frule(hcat, randn(), randn(1,3))
163162

164163
# reverse
165-
test_rrule(hcat, randn(3, 2), randn(3), randn(3, 3))
164+
@gpu test_rrule(hcat, randn(3, 2), randn(3), randn(3, 3))
165+
@gpu test_rrule(hcat, rand(1,2), rand(), rand(1,3))
166166
test_rrule(hcat, rand(), rand(1,2), rand(1,2,1))
167167
test_rrule(hcat, rand(3,1,1,2), rand(3,3,1,2))
168168

@@ -194,13 +194,14 @@ end
194194
end
195195

196196
@testset "vcat" begin
197-
198197
# forward
199198
test_frule(vcat, randn(), randn(3), rand())
200-
test_frule(vcat, randn(3, 1), randn(3))
199+
@gpu test_frule(vcat, randn(3), rand(), randn(3))
200+
@gpu test_frule(vcat, randn(3, 1), randn(3))
201201

202202
# reverse
203-
test_rrule(vcat, randn(2, 4), randn(1, 4), randn(3, 4))
203+
@gpu test_rrule(vcat, randn(3), rand(), randn(3))
204+
@gpu test_rrule(vcat, randn(2, 4), randn(1, 4), randn(3, 4))
204205
test_rrule(vcat, rand(), rand())
205206
test_rrule(vcat, rand(), rand(3), rand(3,1,1))
206207
test_rrule(vcat, rand(3,1,2), rand(4,1,2))
@@ -230,8 +231,8 @@ end
230231
test_frule(cat, rand(), rand(2,3); fkwargs=(dims=(1,2),))
231232

232233
# reverse
233-
test_rrule(cat, rand(2, 4), rand(1, 4); fkwargs=(dims=1,))
234-
test_rrule(cat, rand(2, 4), rand(2); fkwargs=(dims=Val(2),))
234+
@gpu test_rrule(cat, rand(2, 4), rand(1, 4); fkwargs=(dims=1,))
235+
@gpu test_rrule(cat, rand(2, 4), rand(2); fkwargs=(dims=Val(2),))
235236
test_rrule(cat, rand(), rand(2, 3); fkwargs=(dims=[1,2],))
236237
test_rrule(cat, rand(1), rand(3, 2, 1); fkwargs=(dims=(1,2),), check_inferred=false) # infers Tuple{Zero, Vector{Float64}, Any}
237238

@@ -263,7 +264,7 @@ end
263264
end
264265
@testset "Array" begin
265266
# Forward
266-
test_frule(reverse, rand(5))
267+
@gpu_broken test_frule(reverse, rand(5))
267268
test_frule(reverse, rand(5), 2, 4)
268269
test_frule(reverse, rand(5), fkwargs=(dims=1,))
269270
test_frule(reverse, rand(3,4), fkwargs=(dims=2,))
@@ -275,7 +276,7 @@ end
275276
test_frule(reverse!, rand(3,4), fkwargs=(dims=2,))
276277

277278
# Reverse
278-
test_rrule(reverse, rand(5))
279+
@gpu_broken test_rrule(reverse, rand(5))
279280
test_rrule(reverse, rand(5), 2, 4)
280281
test_rrule(reverse, rand(5), fkwargs=(dims=1,))
281282

@@ -293,15 +294,15 @@ end
293294

294295
@testset "circshift" begin
295296
# Forward
296-
test_frule(circshift, rand(10), 1)
297+
@gpu test_frule(circshift, rand(10), 1)
297298
test_frule(circshift, rand(10), (1,))
298299
test_frule(circshift, rand(3,4), (-7,2))
299300

300301
test_frule(circshift!, rand(10), rand(10), 1)
301302
test_frule(circshift!, rand(3,4), rand(3,4), (-7,2))
302303

303304
# Reverse
304-
test_rrule(circshift, rand(10), 1)
305+
@gpu test_rrule(circshift, rand(10), 1)
305306
test_rrule(circshift, rand(10) .+ im, -2)
306307
test_rrule(circshift, rand(10), (1,))
307308
test_rrule(circshift, rand(3,4), (-7,2))
@@ -379,14 +380,14 @@ end
379380
# Forward
380381
test_frule(imum, rand(10))
381382
test_frule(imum, rand(3,4))
382-
test_frule(imum, rand(3,4), fkwargs=(dims=1,))
383+
@gpu_broken test_frule(imum, rand(3,4), fkwargs=(dims=1,))
383384
test_frule(imum, [rand(2) for _ in 1:3])
384385
test_frule(imum, [rand(2) for _ in 1:3, _ in 1:4]; fkwargs=(dims=1,))
385386

386387
# Reverse
387388
test_rrule(imum, rand(10))
388389
test_rrule(imum, rand(3,4))
389-
test_rrule(imum, rand(3,4), fkwargs=(dims=1,))
390+
@gpu_broken test_rrule(imum, rand(3,4), fkwargs=(dims=1,))
390391
test_rrule(imum, rand(3,4,5), fkwargs=(dims=(1,3),))
391392

392393
# Arrays of arrays

test/rulesets/Base/arraymath.jl

Lines changed: 38 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,26 @@
11
@testset "arraymath.jl" begin
22
@testset "inv(::Matrix{$T})" for T in (Float64, ComplexF64)
33
B = generate_well_conditioned_matrix(T, 3)
4-
test_frule(inv, B)
5-
test_rrule(inv, B)
4+
if VERSION >= v"1.7"
5+
@gpu test_frule(inv, B)
6+
@gpu test_rrule(inv, B)
7+
else
8+
@gpu_broken test_frule(inv, B)
9+
@gpu_broken test_rrule(inv, B)
10+
end
611
end
712

813
@testset "*: $T" for T in (Float64, ComplexF64)
914
⋆(a) = round.(5*randn(T, a)) # Helper to generate nice random values
1015
⋆(a, b) = ⋆((a, b)) # matrix
1116
⋆() = only(⋆(())) # scalar
1217

13-
@testset "Scalar-Array $dims" for dims in ((3,), (5,4), (2, 3, 4, 5))
14-
test_frule(*, ⋆(), ⋆(dims))
15-
test_frule(*, ⋆(dims), ⋆())
18+
@testset "Scalar-Array $dims" for dims in ((3,), (2, 3, 4))
19+
@gpu test_frule(*, ⋆(), ⋆(dims))
20+
@gpu test_frule(*, ⋆(dims), ⋆())
1621

17-
test_rrule(*, ⋆(), ⋆(dims))
18-
test_rrule(*, ⋆(dims), ⋆())
22+
@gpu test_rrule(*, ⋆(), ⋆(dims))
23+
@gpu test_rrule(*, ⋆(dims), ⋆())
1924
end
2025

2126
@testset "AbstractMatrix-AbstractVector n=$n, m=$m" for n in (2, 3), m in (4, 5)
@@ -60,41 +65,39 @@
6065

6166
@testset "Diagonal" begin
6267
# fwd
63-
test_frule(*, Diagonal([1.0, 2.0, 3.0]), Diagonal([4.0, 5.0, 6.0]))
64-
test_frule(*, Diagonal([1.0, 2.0, 3.0]), rand(3))
68+
@gpu test_frule(*, Diagonal([1.0, 2.0, 3.0]), Diagonal([4.0, 5.0, 6.0]))
69+
@gpu test_frule(*, Diagonal([1.0, 2.0, 3.0]), rand(3))
6570

6671
# rev
67-
test_rrule(*, Diagonal([1.0, 2.0, 3.0]), Diagonal([4.0, 5.0, 6.0]))
68-
test_rrule(*, Diagonal([1.0, 2.0, 3.0]), rand(3))
72+
@gpu test_rrule(*, Diagonal([1.0, 2.0, 3.0]), Diagonal([4.0, 5.0, 6.0]))
73+
@gpu test_rrule(*, Diagonal([1.0, 2.0, 3.0]), rand(3))
6974

7075
# Needs to not try and inplace, as `mul!` will do wrong.
7176
# see https://github.com/JuliaDiff/ChainRulesCore.jl/issues/411
72-
test_rrule(*, Diagonal([1.0, 2.0, 3.0]), rand(3,3))
77+
@gpu test_rrule(*, Diagonal([1.0, 2.0, 3.0]), rand(3,3))
7378
end
7479

75-
@testset "Covector * Vector n=$n" for n in (3, 5)
76-
@testset "$f" for f in (adjoint, transpose)
77-
# This should be same as dot product and give a scalar
78-
test_rrule(*, f(⋆(n)) ⊢ f(⋆(n)), ⋆(n))
79-
end
80+
@testset "$adj * Vector" for adj in (adjoint, transpose)
81+
# This should be same as dot product and give a scalar
82+
test_rrule(*, adj(⋆(5)) ⊢ adj(⋆(5)), ⋆(5))
8083
end
8184
end
8285

8386
@testset "muladd: $T" for T in (Float64, ComplexF64)
84-
@testset "add $(typeof(z))" for z in [rand(T), rand(T, 3), rand(T, 3, 3), false]
87+
@testset "add $(typeof(z))" for z in [rand(), rand(T, 3), rand(T, 3, 3), false]
8588
@testset "forward mode" begin
86-
test_frule(muladd, rand(T, 3, 5), rand(T, 5, 3), z)
89+
@gpu test_frule(muladd, rand(T, 3, 5), rand(T, 5, 3), z)
8790
end
8891
@testset "matrix * matrix" begin
8992
A = rand(T, 3, 3)
9093
B = rand(T, 3, 3)
91-
test_rrule(muladd, A, B, z)
92-
test_rrule(muladd, A', B, z)
93-
test_rrule(muladd, A , B', z)
94+
@gpu test_rrule(muladd, A, B, z)
95+
@gpu test_rrule(muladd, A', B, z)
96+
@gpu test_rrule(muladd, A , B', z)
9497

9598
C = rand(T, 3, 5)
9699
D = rand(T, 5, 3)
97-
test_rrule(muladd, C, D, z)
100+
@gpu test_rrule(muladd, C, D, z)
98101
end
99102
if ndims(z) <= 1
100103
@testset "matrix * vector" begin
@@ -181,32 +184,32 @@
181184
@testset "/ and \\ Scalar-AbstractArray" begin
182185
A = round.(10 .* randn(3, 4, 5), digits=1)
183186
# fwd
184-
test_frule(/, A, 7.2)
185-
test_frule(\, 7.2, A)
187+
@gpu test_frule(/, A, 7.2)
188+
@gpu test_frule(\, 7.2, A)
186189
# rev
187-
test_rrule(/, A, 7.2)
188-
test_rrule(\, 7.2, A)
190+
@gpu test_rrule(/, A, 7.2)
191+
@gpu test_rrule(\, 7.2, A)
189192

190193
C = round.(10 .* randn(6) .+ im .* 10 .* randn(6), digits=1)
191-
test_rrule(/, C, 7.2+8.3im)
192-
test_rrule(\, 7.2+8.3im, C)
194+
@gpu test_rrule(/, C, 7.2+8.3im)
195+
@gpu test_rrule(\, 7.2+8.3im, C)
193196
end
194197

195198
@testset "negation" begin
196199
A = randn(4, 4)
197200
Ā = randn(4, 4)
198201
# fwd
199-
test_frule(-, A)
202+
@gpu test_frule(-, A)
200203
# rev
201-
test_rrule(-, A)
202-
test_rrule(-, Diagonal(A); output_tangent=Diagonal(Ā))
204+
@gpu test_rrule(-, A)
205+
@gpu test_rrule(-, Diagonal(A); output_tangent=Diagonal(Ā))
203206
end
204207

205208
@testset "addition" begin
206209
# fwd
207-
test_frule(+, randn(2), randn(2), randn(2))
210+
@gpu test_frule(+, randn(2), randn(2), randn(2))
208211
# rev
209-
test_rrule(+, randn(4, 4), randn(4, 4), randn(4, 4))
210-
test_rrule(+, randn(3), randn(3,1), randn(3,1,1))
212+
@gpu test_rrule(+, randn(4, 4), randn(4, 4), randn(4, 4))
213+
@gpu test_rrule(+, randn(3), randn(3,1), randn(3,1,1))
211214
end
212215
end

test/rulesets/Base/mapreduce.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,12 @@ const CFG = ChainRulesTestUtils.ADviaRuleConfig()
1515
end
1616
@testset "sum(x; dims=$dims)" for dims in (:, 2, (1,3))
1717
# Forward
18-
test_frule(sum, rand(5); fkwargs=(;dims=dims))
19-
test_frule(sum, rand(ComplexF64, 2,3,4); fkwargs=(;dims=dims))
18+
@gpu test_frule(sum, rand(5); fkwargs=(;dims=dims))
19+
@gpu test_frule(sum, rand(ComplexF64, 2,3,4); fkwargs=(;dims=dims))
2020

2121
# Reverse
22-
test_rrule(sum, rand(5); fkwargs=(;dims=dims))
23-
test_rrule(sum, rand(ComplexF64, 2,3,4); fkwargs=(;dims=dims))
22+
@gpu test_rrule(sum, rand(5); fkwargs=(;dims=dims))
23+
@gpu test_rrule(sum, rand(ComplexF64, 2,3,4); fkwargs=(;dims=dims))
2424

2525
# Structured matrices
2626
test_rrule(sum, rand(5)'; fkwargs=(;dims=dims))
@@ -58,8 +58,8 @@ const CFG = ChainRulesTestUtils.ADviaRuleConfig()
5858
@testset "dims = $dims" for dims in (:, 1)
5959
@testset "Array{$N, $T}" for N in eachindex(sizes), T in (Float64, ComplexF64)
6060
x = randn(T, sizes[1:N]...)
61-
test_frule(sum, abs2, x; fkwargs=(;dims=dims))
62-
test_rrule(sum, abs2, x; fkwargs=(;dims=dims))
61+
@gpu test_frule(sum, abs2, x; fkwargs=(;dims=dims))
62+
@gpu test_rrule(sum, abs2, x; fkwargs=(;dims=dims))
6363
end
6464

6565
# Boolean -- via @non_differentiable, test that this isn't ambiguous
@@ -156,10 +156,10 @@ const CFG = ChainRulesTestUtils.ADviaRuleConfig()
156156
((3,4), 1), ((3,4), 2), ((3,4), :), ((3,4), [1,2]),
157157
((3,4,1), 1), ((3,2,2), 3), ((3,2,2), 2:3),
158158
]
159-
x = randn(T, sz)
160-
test_rrule(prod, x; fkwargs=(dims=dims,), check_inferred=true)
159+
x = rand(T, sz) .+ 1 # no zeros
160+
@gpu test_rrule(prod, x; fkwargs=(dims=dims,), check_inferred=true)
161161
x[1] = 0
162-
test_rrule(prod, x; fkwargs=(dims=dims,), check_inferred=true)
162+
@gpu_broken test_rrule(prod, x; fkwargs=(dims=dims,), check_inferred=true)
163163
x[5] = 0
164164
test_rrule(prod, x; fkwargs=(dims=dims,), check_inferred=true)
165165
x[3] = x[7] = 0 # two zeros along some slice, for any dims

0 commit comments

Comments
 (0)