Skip to content

Commit d6c1b6e

Browse files
committed
Use correct backend in examples
1 parent 110d784 commit d6c1b6e

File tree

7 files changed

+28
-25
lines changed

7 files changed

+28
-25
lines changed

examples/histogram.jl

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,31 +5,28 @@ include(joinpath(dirname(pathof(KernelAbstractions)), "../examples/utils.jl")) #
55

66
# Function to use as a baseline for CPU metrics
77
function create_histogram(input)
8-
histogram_output = zeros(Int, maximum(input))
8+
histogram_output = zeros(eltype(input), maximum(input))
99
for i in input
1010
histogram_output[i] += 1
1111
end
1212
return histogram_output
1313
end
1414

1515
# This is a 1D histogram kernel where the histogramming happens on shmem
16-
@kernel function histogram_kernel!(histogram_output, input)
16+
@kernel unsafe_indices=true function histogram_kernel!(histogram_output, input)
1717
tid = @index(Global, Linear)
1818
lid = @index(Local, Linear)
1919

20-
@uniform warpsize = Int(32)
21-
22-
@uniform gs = @groupsize()[1]
20+
@uniform gs = prod(@groupsize())
2321
@uniform N = length(histogram_output)
2422

25-
shared_histogram = @localmem Int (gs)
23+
shared_histogram = @localmem eltype(input) (gs)
2624

2725
# This will go through all input elements and assign them to a location in
2826
# shmem. Note that if there is not enough shmem, we create different shmem
2927
# blocks to write to. For example, if shmem is of size 256, but it's
3028
# possible to get a value of 312, then we will have 2 separate shmem blocks,
3129
# one from 1->256, and another from 257->512
32-
@uniform max_element = 1
3330
for min_element in 1:gs:N
3431

3532
# Setting shared_histogram to 0
@@ -42,7 +39,7 @@ end
4239
end
4340

4441
# Defining bin on shared memory and writing to it if possible
45-
bin = input[tid]
42+
bin = tid <= length(input) ? input[tid] : 0
4643
if bin >= min_element && bin < max_element
4744
bin -= min_element - 1
4845
@atomic shared_histogram[bin] += 1
@@ -58,10 +55,10 @@ end
5855

5956
end
6057

61-
function histogram!(histogram_output, input)
58+
function histogram!(histogram_output, input, groupsize=256)
6259
backend = get_backend(histogram_output)
6360
# Need static block size
64-
kernel! = histogram_kernel!(backend, (256,))
61+
kernel! = histogram_kernel!(backend, (groupsize,))
6562
kernel!(histogram_output, input, ndrange = size(input))
6663
return
6764
end
@@ -74,9 +71,10 @@ function move(backend, input)
7471
end
7572

7673
@testset "histogram tests" begin
77-
rand_input = [rand(1:128) for i in 1:1000]
78-
linear_input = [i for i in 1:1024]
79-
all_two = [2 for i in 1:512]
74+
# Use Int32 as some backends don't support 64-bit atomics
75+
rand_input = Int32.(rand(1:128, 1000))
76+
linear_input = Int32.(1:1024)
77+
all_two = fill(Int32(2), 512)
8078

8179
histogram_rand_baseline = create_histogram(rand_input)
8280
histogram_linear_baseline = create_histogram(linear_input)
@@ -86,14 +84,14 @@ end
8684
linear_input = move(backend, linear_input)
8785
all_two = move(backend, all_two)
8886

89-
rand_histogram = KernelAbstractions.zeros(backend, Int, 128)
90-
linear_histogram = KernelAbstractions.zeros(backend, Int, 1024)
91-
two_histogram = KernelAbstractions.zeros(backend, Int, 2)
87+
rand_histogram = KernelAbstractions.zeros(backend, eltype(rand_input), Int(maximum(rand_input)))
88+
linear_histogram = KernelAbstractions.zeros(backend, eltype(linear_input), Int(maximum(linear_input)))
89+
two_histogram = KernelAbstractions.zeros(backend, eltype(all_two), Int(maximum(all_two)))
9290

93-
histogram!(rand_histogram, rand_input)
91+
histogram!(rand_histogram, rand_input, 6)
9492
histogram!(linear_histogram, linear_input)
9593
histogram!(two_histogram, all_two)
96-
KernelAbstractions.synchronize(CPU())
94+
KernelAbstractions.synchronize(backend)
9795

9896
@test isapprox(Array(rand_histogram), histogram_rand_baseline)
9997
@test isapprox(Array(linear_histogram), histogram_linear_baseline)

examples/memcopy.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ function mycopy!(A, B)
1616
return
1717
end
1818

19-
A = KernelAbstractions.zeros(backend, Float64, 128, 128)
20-
B = KernelAbstractions.ones(backend, Float64, 128, 128)
19+
A = KernelAbstractions.zeros(backend, f_type, 128, 128)
20+
B = KernelAbstractions.ones(backend, f_type, 128, 128)
2121
mycopy!(A, B)
2222
KernelAbstractions.synchronize(backend)
2323
@test A == B

examples/memcopy_static.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ function mycopy_static!(A, B)
1616
return
1717
end
1818

19-
A = KernelAbstractions.zeros(backend, Float64, 128, 128)
20-
B = KernelAbstractions.ones(backend, Float64, 128, 128)
19+
A = KernelAbstractions.zeros(backend, f_type, 128, 128)
20+
B = KernelAbstractions.ones(backend, f_type, 128, 128)
2121
mycopy_static!(A, B)
2222
KernelAbstractions.synchronize(backend)
2323
@test A == B

examples/performant_matmul.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using Test
44
using Random
55
include(joinpath(dirname(pathof(KernelAbstractions)), "../examples/utils.jl")) # Load backend
66

7-
const TILE_DIM = 32
7+
const TILE_DIM = 16
88

99
@kernel function coalesced_matmul_kernel!(
1010
output, @Const(input1), @Const(input2), N, R, M,

examples/utils.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# EXCLUDE FROM TESTING
2+
if !(@isdefined backend)
23
if Base.find_package("CUDA") !== nothing
34
using CUDA
45
using CUDA.CUDAKernels
@@ -7,3 +8,6 @@ if Base.find_package("CUDA") !== nothing
78
else
89
const backend = CPU()
910
end
11+
end
12+
13+
const f_type = KernelAbstractions.supports_float64(backend) ? Float64 : Float32

test/examples.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ function find_sources(path::String, sources = String[])
99
return sources
1010
end
1111

12-
function examples_testsuite(backend_str)
12+
function examples_testsuite(backend, backend_str)
1313
@testset "examples" begin
1414
examples_dir = joinpath(@__DIR__, "..", "examples")
1515
examples = find_sources(examples_dir)
@@ -21,6 +21,7 @@ function examples_testsuite(backend_str)
2121
@testset "$(basename(example))" for example in examples
2222
@eval module $(gensym())
2323
backend_str = $backend_str
24+
const backend = ($backend)()
2425
include($example)
2526
end
2627
@test true

test/testsuite.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ function testsuite(backend, backend_str, backend_mod, AT, DAT; skip_tests = Set{
8484
end
8585

8686
@conditional_testset "Examples" skip_tests begin
87-
examples_testsuite(backend_str)
87+
examples_testsuite(backend, backend_str)
8888
end
8989

9090
return

0 commit comments

Comments
 (0)