Skip to content

Commit dab03b9

Browse files
Test correct backend in examples test (#597)
* Use Float32 in examples with backends that don't support Float64 * Add `backend` argument for `examples_testsuite` * Reduce `TILE_DIM` for compatibility Metal doesn't always support 1024 threads per group, which causes intermittent errors with 32x32 tiles * Fix histogram implementation The final part of the loop expects every thread to exist, so we cannot skip launching them. Avoid work on extra threads until then. Also use Int32 since some backends lack Int64 atomics, and make one of the tests use a weird groupsize since that's when the errors used to pop up.
1 parent 474050e commit dab03b9

File tree

7 files changed

+40
-33
lines changed

7 files changed

+40
-33
lines changed

examples/histogram.jl

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,31 +5,29 @@ include(joinpath(dirname(pathof(KernelAbstractions)), "../examples/utils.jl")) #
55

66
# Function to use as a baseline for CPU metrics
77
function create_histogram(input)
8-
histogram_output = zeros(Int, maximum(input))
8+
histogram_output = zeros(eltype(input), maximum(input))
99
for i in input
1010
histogram_output[i] += 1
1111
end
1212
return histogram_output
1313
end
1414

1515
# This is a 1D histogram kernel where the histogramming happens on shmem
16-
@kernel function histogram_kernel!(histogram_output, input)
17-
tid = @index(Global, Linear)
16+
@kernel unsafe_indices = true function histogram_kernel!(histogram_output, input)
17+
gid = @index(Group, Linear)
1818
lid = @index(Local, Linear)
1919

20-
@uniform warpsize = Int(32)
21-
22-
@uniform gs = @groupsize()[1]
20+
@uniform gs = prod(@groupsize())
21+
tid = (gid - 1) * gs + lid
2322
@uniform N = length(histogram_output)
2423

25-
shared_histogram = @localmem Int (gs)
24+
shared_histogram = @localmem eltype(input) (gs)
2625

2726
# This will go through all input elements and assign them to a location in
2827
# shmem. Note that if there is not enough shmem, we create different shmem
2928
# blocks to write to. For example, if shmem is of size 256, but it's
3029
# possible to get a value of 312, then we will have 2 separate shmem blocks,
3130
# one from 1->256, and another from 257->512
32-
@uniform max_element = 1
3331
for min_element in 1:gs:N
3432

3533
# Setting shared_histogram to 0
@@ -42,7 +40,7 @@ end
4240
end
4341

4442
# Defining bin on shared memory and writing to it if possible
45-
bin = input[tid]
43+
bin = tid <= length(input) ? input[tid] : 0
4644
if bin >= min_element && bin < max_element
4745
bin -= min_element - 1
4846
@atomic shared_histogram[bin] += 1
@@ -58,10 +56,10 @@ end
5856

5957
end
6058

61-
function histogram!(histogram_output, input)
59+
function histogram!(histogram_output, input, groupsize = 256)
6260
backend = get_backend(histogram_output)
6361
# Need static block size
64-
kernel! = histogram_kernel!(backend, (256,))
62+
kernel! = histogram_kernel!(backend, (groupsize,))
6563
kernel!(histogram_output, input, ndrange = size(input))
6664
return
6765
end
@@ -74,9 +72,10 @@ function move(backend, input)
7472
end
7573

7674
@testset "histogram tests" begin
77-
rand_input = [rand(1:128) for i in 1:1000]
78-
linear_input = [i for i in 1:1024]
79-
all_two = [2 for i in 1:512]
75+
# Use Int32 as some backends don't support 64-bit atomics
76+
rand_input = Int32.(rand(1:128, 1000))
77+
linear_input = Int32.(1:1024)
78+
all_two = fill(Int32(2), 512)
8079

8180
histogram_rand_baseline = create_histogram(rand_input)
8281
histogram_linear_baseline = create_histogram(linear_input)
@@ -86,14 +85,14 @@ end
8685
linear_input = move(backend, linear_input)
8786
all_two = move(backend, all_two)
8887

89-
rand_histogram = KernelAbstractions.zeros(backend, Int, 128)
90-
linear_histogram = KernelAbstractions.zeros(backend, Int, 1024)
91-
two_histogram = KernelAbstractions.zeros(backend, Int, 2)
88+
rand_histogram = KernelAbstractions.zeros(backend, eltype(rand_input), Int(maximum(rand_input)))
89+
linear_histogram = KernelAbstractions.zeros(backend, eltype(linear_input), Int(maximum(linear_input)))
90+
two_histogram = KernelAbstractions.zeros(backend, eltype(all_two), Int(maximum(all_two)))
9291

93-
histogram!(rand_histogram, rand_input)
92+
histogram!(rand_histogram, rand_input, 6)
9493
histogram!(linear_histogram, linear_input)
9594
histogram!(two_histogram, all_two)
96-
KernelAbstractions.synchronize(CPU())
95+
KernelAbstractions.synchronize(backend)
9796

9897
@test isapprox(Array(rand_histogram), histogram_rand_baseline)
9998
@test isapprox(Array(linear_histogram), histogram_linear_baseline)

examples/memcopy.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ function mycopy!(A, B)
1616
return
1717
end
1818

19-
A = KernelAbstractions.zeros(backend, Float64, 128, 128)
20-
B = KernelAbstractions.ones(backend, Float64, 128, 128)
19+
A = KernelAbstractions.zeros(backend, f_type, 128, 128)
20+
B = KernelAbstractions.ones(backend, f_type, 128, 128)
2121
mycopy!(A, B)
2222
KernelAbstractions.synchronize(backend)
2323
@test A == B

examples/memcopy_static.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ function mycopy_static!(A, B)
1616
return
1717
end
1818

19-
A = KernelAbstractions.zeros(backend, Float64, 128, 128)
20-
B = KernelAbstractions.ones(backend, Float64, 128, 128)
19+
A = KernelAbstractions.zeros(backend, f_type, 128, 128)
20+
B = KernelAbstractions.ones(backend, f_type, 128, 128)
2121
mycopy_static!(A, B)
2222
KernelAbstractions.synchronize(backend)
2323
@test A == B

examples/performant_matmul.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@ using Test
44
using Random
55
include(joinpath(dirname(pathof(KernelAbstractions)), "../examples/utils.jl")) # Load backend
66

7-
const TILE_DIM = 32
7+
# We use a TILE_DIM of 16 as a safe value since while
8+
# most backends support up to 1024 threads per group,
9+
# Metal sometimes supports fewer.
10+
const TILE_DIM = 16
811

912
@kernel function coalesced_matmul_kernel!(
1013
output, @Const(input1), @Const(input2), N, R, M,

examples/utils.jl

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
# EXCLUDE FROM TESTING
2-
if Base.find_package("CUDA") !== nothing
3-
using CUDA
4-
using CUDA.CUDAKernels
5-
const backend = CUDABackend()
6-
CUDA.allowscalar(false)
7-
else
8-
const backend = CPU()
2+
if !(@isdefined backend)
3+
if Base.find_package("CUDA") !== nothing
4+
using CUDA
5+
using CUDA.CUDAKernels
6+
const backend = CUDABackend()
7+
CUDA.allowscalar(false)
8+
else
9+
const backend = CPU()
10+
end
911
end
12+
13+
const f_type = KernelAbstractions.supports_float64(backend) ? Float64 : Float32

test/examples.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ function find_sources(path::String, sources = String[])
99
return sources
1010
end
1111

12-
function examples_testsuite(backend_str)
12+
function examples_testsuite(backend, backend_str)
1313
@testset "examples" begin
1414
examples_dir = joinpath(@__DIR__, "..", "examples")
1515
examples = find_sources(examples_dir)
@@ -21,6 +21,7 @@ function examples_testsuite(backend_str)
2121
@testset "$(basename(example))" for example in examples
2222
@eval module $(gensym())
2323
backend_str = $backend_str
24+
const backend = ($backend)()
2425
include($example)
2526
end
2627
@test true

test/testsuite.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ function testsuite(backend, backend_str, backend_mod, AT, DAT; skip_tests = Set{
8484
end
8585

8686
@conditional_testset "Examples" skip_tests begin
87-
examples_testsuite(backend_str)
87+
examples_testsuite(backend, backend_str)
8888
end
8989

9090
return

0 commit comments

Comments
 (0)