Skip to content

Commit c1afac3

Browse files
committed
Fix histogram implementation
The final part of the loop expects every thread to exist, so we cannot skip launching any of them. Instead, avoid doing work on the extra threads until then. Also use Int32, since some backends lack Int64 atomics, and give one of the tests an unusual groupsize, since that is when the errors used to appear.
1 parent 88bc615 commit c1afac3

File tree

1 file changed

+18
-19
lines changed

1 file changed

+18
-19
lines changed

examples/histogram.jl

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,31 +5,29 @@ include(joinpath(dirname(pathof(KernelAbstractions)), "../examples/utils.jl")) #
55

66
# Function to use as a baseline for CPU metrics
77
function create_histogram(input)
8-
histogram_output = zeros(Int, maximum(input))
8+
histogram_output = zeros(eltype(input), maximum(input))
99
for i in input
1010
histogram_output[i] += 1
1111
end
1212
return histogram_output
1313
end
1414

1515
# This is a 1D histogram kernel where the histogramming happens on shmem
16-
@kernel function histogram_kernel!(histogram_output, input)
17-
tid = @index(Global, Linear)
16+
@kernel unsafe_indices = true function histogram_kernel!(histogram_output, input)
17+
gid = @index(Group, Linear)
1818
lid = @index(Local, Linear)
1919

20-
@uniform warpsize = Int(32)
21-
22-
@uniform gs = @groupsize()[1]
20+
@uniform gs = prod(@groupsize())
21+
tid = (gid - 1) * gs + lid
2322
@uniform N = length(histogram_output)
2423

25-
shared_histogram = @localmem Int (gs)
24+
shared_histogram = @localmem eltype(input) (gs)
2625

2726
# This will go through all input elements and assign them to a location in
2827
# shmem. Note that if there is not enough shmem, we create different shmem
2928
# blocks to write to. For example, if shmem is of size 256, but it's
3029
# possible to get a value of 312, then we will have 2 separate shmem blocks,
3130
# one from 1->256, and another from 257->512
32-
@uniform max_element = 1
3331
for min_element in 1:gs:N
3432

3533
# Setting shared_histogram to 0
@@ -42,7 +40,7 @@ end
4240
end
4341

4442
# Defining bin on shared memory and writing to it if possible
45-
bin = input[tid]
43+
bin = tid <= length(input) ? input[tid] : 0
4644
if bin >= min_element && bin < max_element
4745
bin -= min_element - 1
4846
@atomic shared_histogram[bin] += 1
@@ -58,10 +56,10 @@ end
5856

5957
end
6058

61-
function histogram!(histogram_output, input)
59+
function histogram!(histogram_output, input, groupsize = 256)
6260
backend = get_backend(histogram_output)
6361
# Need static block size
64-
kernel! = histogram_kernel!(backend, (256,))
62+
kernel! = histogram_kernel!(backend, (groupsize,))
6563
kernel!(histogram_output, input, ndrange = size(input))
6664
return
6765
end
@@ -74,9 +72,10 @@ function move(backend, input)
7472
end
7573

7674
@testset "histogram tests" begin
77-
rand_input = [rand(1:128) for i in 1:1000]
78-
linear_input = [i for i in 1:1024]
79-
all_two = [2 for i in 1:512]
75+
# Use Int32 as some backends don't support 64-bit atomics
76+
rand_input = Int32.(rand(1:128, 1000))
77+
linear_input = Int32.(1:1024)
78+
all_two = fill(Int32(2), 512)
8079

8180
histogram_rand_baseline = create_histogram(rand_input)
8281
histogram_linear_baseline = create_histogram(linear_input)
@@ -86,14 +85,14 @@ end
8685
linear_input = move(backend, linear_input)
8786
all_two = move(backend, all_two)
8887

89-
rand_histogram = KernelAbstractions.zeros(backend, Int, 128)
90-
linear_histogram = KernelAbstractions.zeros(backend, Int, 1024)
91-
two_histogram = KernelAbstractions.zeros(backend, Int, 2)
88+
rand_histogram = KernelAbstractions.zeros(backend, eltype(rand_input), Int(maximum(rand_input)))
89+
linear_histogram = KernelAbstractions.zeros(backend, eltype(linear_input), Int(maximum(linear_input)))
90+
two_histogram = KernelAbstractions.zeros(backend, eltype(all_two), Int(maximum(all_two)))
9291

93-
histogram!(rand_histogram, rand_input)
92+
histogram!(rand_histogram, rand_input, 6)
9493
histogram!(linear_histogram, linear_input)
9594
histogram!(two_histogram, all_two)
96-
KernelAbstractions.synchronize(CPU())
95+
KernelAbstractions.synchronize(backend)
9796

9897
@test isapprox(Array(rand_histogram), histogram_rand_baseline)
9998
@test isapprox(Array(linear_histogram), histogram_linear_baseline)

0 commit comments

Comments
 (0)