Skip to content

Commit adb914e

Browse files
authored
Extend the type coverage, and give more control to back-ends. (#373)
1 parent 0ada3f9 commit adb914e

File tree

14 files changed

+219
-189
lines changed

14 files changed

+219
-189
lines changed

test/testsuite.jl

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,17 +43,28 @@ function compare(f, AT::Type{<:Array}, xs...; kwargs...)
4343
return true
4444
end
4545

46-
function supported_eltypes()
47-
(Float32, Float64, Int32, Int64, ComplexF32, ComplexF64)
48-
end
46+
# element types that are supported by the array type
47+
supported_eltypes(AT, test) = supported_eltypes(AT)
48+
supported_eltypes(AT) = supported_eltypes()
49+
supported_eltypes() = (Int16, Int32, Int64,
50+
Float16, Float32, Float64,
51+
ComplexF16, ComplexF32, ComplexF64,
52+
Complex{Int16}, Complex{Int32}, Complex{Int64})
53+
54+
# some convenience predicates for filtering test eltypes
55+
isrealtype(T) = T <: Real
56+
iscomplextype(T) = T <: Complex
57+
isfloattype(T) = T <: AbstractFloat || T <: Complex{<:AbstractFloat}
4958

5059
# list of tests
5160
const tests = Dict()
5261
macro testsuite(name, ex)
53-
safe_name = lowercase(replace(name, " "=>"_"))
62+
safe_name = lowercase(replace(replace(name, " "=>"_"), "/"=>"_"))
5463
fn = Symbol("test_$(safe_name)")
5564
quote
56-
$(esc(fn))(AT) = $(esc(ex))(AT)
65+
# the supported element types can be overrided by passing in a different set,
66+
# or by specializing the `supported_eltypes` function on the array type and test.
67+
$(esc(fn))(AT; eltypes=supported_eltypes(AT, $(esc(fn)))) = $(esc(ex))(AT, eltypes)
5768

5869
@assert !haskey(tests, $name)
5970
tests[$name] = $fn

test/testsuite/base.jl

Lines changed: 27 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ function ntuple_closure(ctx, result, ::Val{N}, testval) where N
2525
return
2626
end
2727

28-
@testsuite "base" AT->begin
28+
@testsuite "base" (AT, eltypes)->begin
2929
@testset "copyto!" begin
3030
x = fill(0f0, (10, 10))
3131
y = rand(Float32, (20, 10))
@@ -70,15 +70,15 @@ end
7070
copyto!(a, r1, b, r2)
7171
@test x == Array(a)
7272

73-
x = fill(0., (10,))
74-
y = fill(1, (10,))
73+
x = fill(0f0, (10,))
74+
y = fill(1f0, (10,))
7575
a = AT(x)
7676
b = AT(y)
7777
copyto!(a, b)
78-
@test Float64.(y) == Array(a)
78+
@test Float32.(y) == Array(a)
7979

8080
# wrapped gpu array to wrapped gpu array
81-
x = rand(4, 4)
81+
x = rand(Float32, 4, 4)
8282
a = AT(x)
8383
b = view(a, 2:3, 2:3)
8484
c = AT{eltype(b)}(undef, size(b))
@@ -95,21 +95,23 @@ end
9595

9696
# bug in copyto!
9797
## needless N type parameter
98-
@test compare((x,y)->copyto!(y, selectdim(x, 2, 1)), AT, ones(2,2,2), zeros(2,2))
98+
@test compare((x,y)->copyto!(y, selectdim(x, 2, 1)), AT, ones(Float32, 2, 2, 2), zeros(Float32, 2, 2))
9999
## inability to copyto! smaller destination
100100
## (this was broken on Julia <1.5)
101-
@test compare((x,y)->copyto!(y, selectdim(x, 2, 1)), AT, ones(2,2,2), zeros(3,3))
102-
103-
# mismatched types
104-
let src = rand(Float32, 4)
105-
dst = AT{Float64}(undef, size(src))
106-
copyto!(dst, src)
107-
@test Array(dst) == src
108-
end
109-
let dst = Array{Float64}(undef, 4)
110-
src = AT(rand(Float32, size(dst)))
111-
copyto!(dst, src)
112-
@test Array(src) == dst
101+
@test compare((x,y)->copyto!(y, selectdim(x, 2, 1)), AT, ones(Float32, 2, 2, 2), zeros(Float32, 3, 3))
102+
103+
if (Float32 in eltypes && Float64 in eltypes)
104+
# mismatched types
105+
let src = rand(Float32, 4)
106+
dst = AT{Float64}(undef, size(src))
107+
copyto!(dst, src)
108+
@test Array(dst) == src
109+
end
110+
let dst = Array{Float64}(undef, 4)
111+
src = AT(rand(Float32, size(dst)))
112+
copyto!(dst, src)
113+
@test Array(src) == dst
114+
end
113115
end
114116
end
115117

@@ -123,11 +125,11 @@ end
123125
end
124126

125127
@testset "reshape" begin
126-
@test compare(reshape, AT, rand(10), Ref((10,)))
127-
@test compare(reshape, AT, rand(10), Ref((10,1)))
128-
@test compare(reshape, AT, rand(10), Ref((1,10)))
128+
@test compare(reshape, AT, rand(Float32, 10), Ref((10,)))
129+
@test compare(reshape, AT, rand(Float32, 10), Ref((10,1)))
130+
@test compare(reshape, AT, rand(Float32, 10), Ref((1,10)))
129131

130-
@test_throws Exception reshape(AT(rand(10)), (10,2))
132+
@test_throws Exception reshape(AT(rand(Float32, 10)), (10,2))
131133
end
132134

133135
@testset "reinterpret" begin
@@ -158,7 +160,7 @@ end
158160
AT <: AbstractGPUArray && @testset "cartesian iteration" begin
159161
Ac = rand(Float32, 32, 32)
160162
A = AT(Ac)
161-
result = fill!(copy(A), 0.0)
163+
result = fill!(copy(A), 0.0f0)
162164
gpu_call(cartesian_iter, result, A, size(A))
163165
Array(result) == Ac
164166
end
@@ -188,10 +190,10 @@ end
188190
end
189191

190192
@testset "permutedims" begin
191-
@test compare(x->permutedims(x, [1, 2]), AT, rand(4, 4))
193+
@test compare(x->permutedims(x, [1, 2]), AT, rand(Float32, 4, 4))
192194

193195
inds = rand(1:100, 150, 150)
194-
@test compare(x->permutedims(view(x, inds, :), (3, 2, 1)), AT, rand(100, 100))
196+
@test compare(x->permutedims(view(x, inds, :), (3, 2, 1)), AT, rand(Float32, 100, 100))
195197
end
196198

197199
@testset "circshift" begin

test/testsuite/broadcasting.jl

Lines changed: 34 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
@testsuite "broadcasting" AT->begin
2-
broadcasting(AT)
3-
vec3(AT)
1+
@testsuite "broadcasting" (AT, eltypes)->begin
2+
broadcasting(AT, eltypes)
3+
vec3(AT, eltypes)
44

55
@testset "type instabilities" begin
66
f(x) = x ? 1.0 : 0
@@ -37,8 +37,8 @@ function test_kernel(a::T, b) where T
3737
return c
3838
end
3939

40-
function broadcasting(AT)
41-
for ET in supported_eltypes()
40+
function broadcasting(AT, eltypes)
41+
for ET in eltypes
4242
N = 10
4343
@testset "broadcast $ET" begin
4444
@testset "RefValue" begin
@@ -91,7 +91,8 @@ function broadcasting(AT)
9191
# since GPUArrays adds some arguments to the function, it becomes longer longer, hitting the 12
9292
# so this wont fix for now
9393
@test compare(AT, rand(ET, dim), rand(ET, dim), rand(ET, dim), rand(ET, dim), rand(ET, dim), rand(ET, dim)) do a1, a2, a3, a4, a5, a6
94-
@. a1 = a2 + (1.2) *((1.3)*a3 + (1.4)*a4 + (1.5)*a5 + (1.6)*a6)
94+
c1, c2, c3, c4, c5 = ET(1.2), ET(1.3), ET(1.4), ET(1.5), ET(1.6)
95+
@. a1 = a2 + c1 * (c2 * a3 + c3 * a4 + c4 * a5 + c5 * a6)
9596
end
9697

9798
@test compare(AT, rand(ET, dim), rand(ET, dim), rand(ET, dim), rand(ET, dim)) do u, uprev, duprev, ku
@@ -110,6 +111,14 @@ function broadcasting(AT)
110111
dt = ET(1)
111112
@. utilde = dt*(btilde1*k1 + btilde2*k2 + btilde3*k3 + btilde4*k4)
112113
end
114+
115+
@testset "0D" begin
116+
x = AT{ET}(undef)
117+
x .= ET(1)
118+
@test collect(x)[] == ET(1)
119+
x /= ET(2)
120+
@test collect(x)[] == ET(0.5)
121+
end
113122
end
114123

115124
@test compare((x) -> fill!(x, 1), AT, rand(ET, 3,3))
@@ -127,59 +136,51 @@ function broadcasting(AT)
127136
end
128137

129138
@testset "map! $ET" begin
130-
@test compare(AT, rand(2,2), rand(2,2)) do x,y
139+
@test compare(AT, rand(ET, 2,2), rand(ET, 2,2)) do x,y
131140
map!(+, x, y)
132141
end
133-
@test compare(AT, rand(2), rand(2,2)) do x,y
142+
@test compare(AT, rand(ET, 2), rand(ET, 2,2)) do x,y
134143
map!(+, x, y)
135144
end
136-
@test compare(AT, rand(2,2), rand(2)) do x,y
145+
@test compare(AT, rand(ET, 2,2), rand(ET, 2)) do x,y
137146
map!(+, x, y)
138147
end
139148
end
140149

141150
@testset "map $ET" begin
142-
@test compare(AT, rand(2,2), rand(2,2)) do x,y
151+
@test compare(AT, rand(ET, 2,2), rand(ET, 2,2)) do x,y
143152
map(+, x, y)
144153
end
145-
@test compare(AT, rand(2), rand(2,2)) do x,y
154+
@test compare(AT, rand(ET, 2), rand(ET, 2,2)) do x,y
146155
map(+, x, y)
147156
end
148-
@test compare(AT, rand(2,2), rand(2)) do x,y
157+
@test compare(AT, rand(ET, 2,2), rand(ET, 2)) do x,y
149158
map(+, x, y)
150159
end
151160
end
152-
end
153161

154-
@testset "0D" begin
155-
x = AT{Float64}(undef)
156-
x .= 1
157-
@test collect(x)[] == 1
158-
x /= 2
159-
@test collect(x)[] == 0.5
160-
end
161-
162-
@testset "Ref" begin
163-
# as first arg, 0d broadcast
164-
@test compare(x->getindex.(Ref(x),1), AT, [0])
162+
@testset "Ref" begin
163+
# as first arg, 0d broadcast
164+
@test compare(x->getindex.(Ref(x), 1), AT, ET[0])
165165

166-
void_setindex!(args...) = (setindex!(args...); return)
167-
@test compare(x->(void_setindex!.(Ref(x),1); x), AT, [0])
166+
void_setindex!(args...) = (setindex!(args...); return)
167+
@test compare(x->(void_setindex!.(Ref(x), ET(1)); x), AT, ET[0])
168168

169-
# regular broadcast
170-
a = AT(rand(10))
171-
b = AT(rand(10))
172-
cpy(i,a,b) = (a[i] = b[i]; return)
173-
cpy.(1:10, Ref(a), Ref(b))
174-
@test Array(a) == Array(b)
169+
# regular broadcast
170+
a = AT(rand(ET, 10))
171+
b = AT(rand(ET, 10))
172+
cpy(i,a,b) = (a[i] = b[i]; return)
173+
cpy.(1:10, Ref(a), Ref(b))
174+
@test Array(a) == Array(b)
175+
end
175176
end
176177

177178
@testset "stackoverflow in copy(::Broadcast)" begin
178179
copy(Base.broadcasted(identity, AT(Int[])))
179180
end
180181
end
181182

182-
function vec3(AT)
183+
function vec3(AT, eltypes)
183184
@testset "vec 3" begin
184185
N = 20
185186

test/testsuite/construction.jl

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
@testsuite "construct/direct" AT->begin
2-
for T in supported_eltypes()
1+
@testsuite "construct/direct" (AT, eltypes)->begin
2+
for T in eltypes
33
B = AT{T}(undef, 10)
44
@test B isa AT{T,1}
55
@test size(B) == (10,)
@@ -45,8 +45,8 @@
4545
end
4646
end
4747

48-
@testsuite "construct/similar" AT->begin
49-
for T in supported_eltypes()
48+
@testsuite "construct/similar" (AT, eltypes)->begin
49+
for T in eltypes
5050
B = AT{T}(undef, 10)
5151

5252
B = similar(B, Int32, 11, 15)
@@ -96,8 +96,8 @@ end
9696
end
9797
end
9898

99-
@testsuite "construct/convenience" AT->begin
100-
for T in supported_eltypes()
99+
@testsuite "construct/convenience" (AT, eltypes)->begin
100+
for T in eltypes
101101
A = AT(rand(T, 3))
102102
b = rand(T)
103103
fill!(A, b)
@@ -126,8 +126,8 @@ end
126126
end
127127
end
128128

129-
@testsuite "construct/conversions" AT->begin
130-
for T in supported_eltypes()
129+
@testsuite "construct/conversions" (AT, eltypes)->begin
130+
for T in eltypes
131131
Bc = round.(rand(10, 10) .* 10.0)
132132
B = AT{T}(Bc)
133133
@test size(B) == (10, 10)
@@ -146,16 +146,22 @@ end
146146
@test eltype(B) == T
147147
@test Array(B) Bc
148148

149-
Bc = rand(Int32, 3, 3, 3)
149+
intervals = Dict(
150+
Float16 => -2^11:2^11,
151+
Float32 => -2^24:2^24,
152+
Float64 => -2^53:2^53,
153+
)
154+
155+
Bc = rand(Int8, 3, 3, 3)
150156
B = convert(AT{T, 3}, Bc)
151157
@test size(B) == (3, 3, 3)
152158
@test eltype(B) == T
153159
@test Array(B) Bc
154160
end
155161
end
156162

157-
@testsuite "construct/uniformscaling" AT->begin
158-
for T in supported_eltypes()
163+
@testsuite "construct/uniformscaling" (AT, eltypes)->begin
164+
for T in eltypes
159165
x = Matrix{T}(I, 4, 2)
160166

161167
x1 = AT{T, 2}(I, 4, 2)

test/testsuite/gpuinterface.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
@testsuite "interface" AT->begin
1+
@testsuite "interface" (AT, eltypes)->begin
22
AT <: AbstractGPUArray || return
33

44
N = 10

0 commit comments

Comments
 (0)