Skip to content

Commit d35c25c

Browse files
committed
Use correck backend in examples
1 parent 110d784 commit d35c25c

File tree

7 files changed

+25
-16
lines changed

7 files changed

+25
-16
lines changed

examples/histogram.jl

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ include(joinpath(dirname(pathof(KernelAbstractions)), "../examples/utils.jl")) #
55

66
# Function to use as a baseline for CPU metrics
77
function create_histogram(input)
8-
histogram_output = zeros(Int, maximum(input))
8+
histogram_output = zeros(eltype(input), maximum(input))
99
for i in input
1010
histogram_output[i] += 1
1111
end
@@ -22,7 +22,7 @@ end
2222
@uniform gs = @groupsize()[1]
2323
@uniform N = length(histogram_output)
2424

25-
shared_histogram = @localmem Int (gs)
25+
shared_histogram = @localmem eltype(input) (gs)
2626

2727
# This will go through all input elements and assign them to a location in
2828
# shmem. Note that if there is not enough shem, we create different shmem
@@ -74,9 +74,10 @@ function move(backend, input)
7474
end
7575

7676
@testset "histogram tests" begin
77-
rand_input = [rand(1:128) for i in 1:1000]
78-
linear_input = [i for i in 1:1024]
79-
all_two = [2 for i in 1:512]
77+
# Use Int32 as some backends don't support 64-bit atomics
78+
rand_input = Int32.(rand(1:128, 1000))
79+
linear_input = Int32.(rand(1:128, 1024))
80+
all_two = fill(Int32(2), 512)
8081

8182
histogram_rand_baseline = create_histogram(rand_input)
8283
histogram_linear_baseline = create_histogram(linear_input)
@@ -86,14 +87,14 @@ end
8687
linear_input = move(backend, linear_input)
8788
all_two = move(backend, all_two)
8889

89-
rand_histogram = KernelAbstractions.zeros(backend, Int, 128)
90-
linear_histogram = KernelAbstractions.zeros(backend, Int, 1024)
91-
two_histogram = KernelAbstractions.zeros(backend, Int, 2)
90+
rand_histogram = KernelAbstractions.zeros(backend, eltype(rand_input), 128)
91+
linear_histogram = KernelAbstractions.zeros(backend, eltype(linear_input), 1024)
92+
two_histogram = KernelAbstractions.zeros(backend, eltype(all_two), 2)
9293

9394
histogram!(rand_histogram, rand_input)
9495
histogram!(linear_histogram, linear_input)
9596
histogram!(two_histogram, all_two)
96-
KernelAbstractions.synchronize(CPU())
97+
KernelAbstractions.synchronize(backend)
9798

9899
@test isapprox(Array(rand_histogram), histogram_rand_baseline)
99100
@test isapprox(Array(linear_histogram), histogram_linear_baseline)

examples/memcopy.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ function mycopy!(A, B)
1616
return
1717
end
1818

19-
A = KernelAbstractions.zeros(backend, Float64, 128, 128)
20-
B = KernelAbstractions.ones(backend, Float64, 128, 128)
19+
A = KernelAbstractions.zeros(backend, f_type, 128, 128)
20+
B = KernelAbstractions.ones(backend, f_type, 128, 128)
2121
mycopy!(A, B)
2222
KernelAbstractions.synchronize(backend)
2323
@test A == B

examples/memcopy_static.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ function mycopy_static!(A, B)
1616
return
1717
end
1818

19-
A = KernelAbstractions.zeros(backend, Float64, 128, 128)
20-
B = KernelAbstractions.ones(backend, Float64, 128, 128)
19+
A = KernelAbstractions.zeros(backend, f_type, 128, 128)
20+
B = KernelAbstractions.ones(backend, f_type, 128, 128)
2121
mycopy_static!(A, B)
2222
KernelAbstractions.synchronize(backend)
2323
@test A == B

examples/performant_matmul.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using Test
44
using Random
55
include(joinpath(dirname(pathof(KernelAbstractions)), "../examples/utils.jl")) # Load backend
66

7-
const TILE_DIM = 32
7+
const TILE_DIM = 16
88

99
@kernel function coalesced_matmul_kernel!(
1010
output, @Const(input1), @Const(input2), N, R, M,

examples/utils.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# EXCLUDE FROM TESTING
2+
if !(@isdefined backend)
23
if Base.find_package("CUDA") !== nothing
34
using CUDA
45
using CUDA.CUDAKernels
@@ -7,3 +8,6 @@ if Base.find_package("CUDA") !== nothing
78
else
89
const backend = CPU()
910
end
11+
end
12+
13+
const f_type = KernelAbstractions.supports_float64(backend) ? Float64 : Float32

test/examples.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,22 @@ function find_sources(path::String, sources = String[])
99
return sources
1010
end
1111

12-
function examples_testsuite(backend_str)
12+
function examples_testsuite(backend, backend_str)
1313
@testset "examples" begin
1414
examples_dir = joinpath(@__DIR__, "..", "examples")
1515
examples = find_sources(examples_dir)
1616
filter!(file -> readline(file) != "# EXCLUDE FROM TESTING", examples)
17+
filter!(examples) do file
18+
last(splitpath(file)) == "histogram.jl"
19+
end
1720
if backend_str == "ROCM"
1821
filter!(file -> occursin("# INCLUDE ROCM", String(read(file))), examples)
1922
end
2023

2124
@testset "$(basename(example))" for example in examples
2225
@eval module $(gensym())
2326
backend_str = $backend_str
27+
const backend = ($backend)()
2428
include($example)
2529
end
2630
@test true

test/testsuite.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ function testsuite(backend, backend_str, backend_mod, AT, DAT; skip_tests = Set{
8484
end
8585

8686
@conditional_testset "Examples" skip_tests begin
87-
examples_testsuite(backend_str)
87+
examples_testsuite(backend, backend_str)
8888
end
8989

9090
return

0 commit comments

Comments
 (0)