@@ -5,31 +5,29 @@ include(joinpath(dirname(pathof(KernelAbstractions)), "../examples/utils.jl")) #
5
5
6
6
# Function to use as a baseline for CPU metrics
7
7
function create_histogram (input)
8
- histogram_output = zeros (Int , maximum (input))
8
+ histogram_output = zeros (eltype (input) , maximum (input))
9
9
for i in input
10
10
histogram_output[i] += 1
11
11
end
12
12
return histogram_output
13
13
end
14
14
15
15
# This a 1D histogram kernel where the histogramming happens on shmem
16
- @kernel function histogram_kernel! (histogram_output, input)
17
- tid = @index (Global , Linear)
16
+ @kernel unsafe_indices = true function histogram_kernel! (histogram_output, input)
17
+ gid = @index (Group , Linear)
18
18
lid = @index (Local, Linear)
19
19
20
- @uniform warpsize = Int (32 )
21
-
22
- @uniform gs = @groupsize ()[1 ]
20
+ @uniform gs = prod (@groupsize ())
21
+ tid = (gid - 1 ) * gs + lid
23
22
@uniform N = length (histogram_output)
24
23
25
- shared_histogram = @localmem Int (gs)
24
+ shared_histogram = @localmem eltype (input) (gs)
26
25
27
26
# This will go through all input elements and assign them to a location in
28
27
# shmem. Note that if there is not enough shem, we create different shmem
29
28
# blocks to write to. For example, if shmem is of size 256, but it's
30
29
# possible to get a value of 312, then we will have 2 separate shmem blocks,
31
30
# one from 1->256, and another from 256->512
32
- @uniform max_element = 1
33
31
for min_element in 1 : gs: N
34
32
35
33
# Setting shared_histogram to 0
42
40
end
43
41
44
42
# Defining bin on shared memory and writing to it if possible
45
- bin = input[tid]
43
+ bin = tid <= length ( input) ? input [tid] : 0
46
44
if bin >= min_element && bin < max_element
47
45
bin -= min_element - 1
48
46
@atomic shared_histogram[bin] += 1
58
56
59
57
end
60
58
61
- function histogram! (histogram_output, input)
59
+ function histogram! (histogram_output, input, groupsize = 256 )
62
60
backend = get_backend (histogram_output)
63
61
# Need static block size
64
- kernel! = histogram_kernel! (backend, (256 ,))
62
+ kernel! = histogram_kernel! (backend, (groupsize ,))
65
63
kernel! (histogram_output, input, ndrange = size (input))
66
64
return
67
65
end
@@ -74,9 +72,10 @@ function move(backend, input)
74
72
end
75
73
76
74
@testset " histogram tests" begin
77
- rand_input = [rand (1 : 128 ) for i in 1 : 1000 ]
78
- linear_input = [i for i in 1 : 1024 ]
79
- all_two = [2 for i in 1 : 512 ]
75
+ # Use Int32 as some backends don't support 64-bit atomics
76
+ rand_input = Int32 .(rand (1 : 128 , 1000 ))
77
+ linear_input = Int32 .(1 : 1024 )
78
+ all_two = fill (Int32 (2 ), 512 )
80
79
81
80
histogram_rand_baseline = create_histogram (rand_input)
82
81
histogram_linear_baseline = create_histogram (linear_input)
86
85
linear_input = move (backend, linear_input)
87
86
all_two = move (backend, all_two)
88
87
89
- rand_histogram = KernelAbstractions. zeros (backend, Int, 128 )
90
- linear_histogram = KernelAbstractions. zeros (backend, Int, 1024 )
91
- two_histogram = KernelAbstractions. zeros (backend, Int, 2 )
88
+ rand_histogram = KernelAbstractions. zeros (backend, eltype (rand_input), Int ( maximum (rand_input)) )
89
+ linear_histogram = KernelAbstractions. zeros (backend, eltype (linear_input), Int ( maximum (linear_input)) )
90
+ two_histogram = KernelAbstractions. zeros (backend, eltype (all_two), Int ( maximum (all_two)) )
92
91
93
- histogram! (rand_histogram, rand_input)
92
+ histogram! (rand_histogram, rand_input, 6 )
94
93
histogram! (linear_histogram, linear_input)
95
94
histogram! (two_histogram, all_two)
96
- KernelAbstractions. synchronize (CPU () )
95
+ KernelAbstractions. synchronize (backend )
97
96
98
97
@test isapprox (Array (rand_histogram), histogram_rand_baseline)
99
98
@test isapprox (Array (linear_histogram), histogram_linear_baseline)
0 commit comments