Skip to content

Commit b47ecf1

Browse files
authored
Simplify and fix launch configuration in bitonic sort (#1979)
1 parent c86f21f commit b47ecf1

File tree

1 file changed

+23
-23
lines changed

1 file changed

+23
-23
lines changed

src/sorting.jl

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -640,8 +640,7 @@ than the size of block allows (eg, 1 <--> 10000)
640640
The grid index directly maps to the index of `c` that will be used in the swap.
641641
642642
Note that to avoid synchronization issues, only one thread from each pair of
643-
indices being swapped will actually move data. This does mean half of the threads
644-
do nothing, but it works for non-power2 arrays while allowing direct indexing.
643+
indices being swapped will actually move data.
645644
"""
646645
function comparator_kernel(vals, length_vals::I, k::I, j::I, by::F1, lt::F2,
647646
rev) where {I,F1,F2}
@@ -793,7 +792,12 @@ swaps in shared mem.
793792
Note that the x dimension of a thread block is treated as a comparator,
794793
so when the maximum size of a comparator in this kernel is small, multiple
795794
may be executed along the block y dimension, allowing for higher occupancy.
796-
This is captured by `pseudo_block_idx`.
795+
These threads in a block with the same threadIdx().x are a 'pseudo-block',
796+
and are indexed by `pseudo_block_idx`.
797+
798+
Unlike `comparator_kernel`, a thread's grid_index does not directly map to the
799+
index of `c` it will read from. `block_range` gives each pseudo-block
800+
a unique range of indices corresponding to a comparator in the sorting network.
797801
798802
Note that this moves the array values copied within shmem, but doesn't copy them
799803
back to global the way it does for indices.
@@ -859,18 +863,11 @@ function bitonic_sort!(c; by = identity, lt = isless, rev = false)
859863
args2 = (c, I(c_len), one(I), one(I), by, lt, Val(rev))
860864
kernel2 = @cuda launch=false comparator_kernel(args2...)
861865
config2 = launch_configuration(kernel2.fun, shmem = threads -> bitonic_shmem(c, threads))
866+
# blocksize for kernel2 MUST be a power of 2
862867
threads2 = prevpow(2, config2.threads)
863868

864-
# determine launch configuration
865-
blocks_per_mp = if CUDA.driver_version() >= v"11.0"
866-
CUDA.attribute(device(), CUDA.DEVICE_ATTRIBUTE_MAX_BLOCKS_PER_MULTIPROCESSOR)
867-
else
868-
16
869-
end
870-
blocks_per_mp = 16 # XXX: JuliaGPU/CUDA.jl#1874
871-
threads = min(threads1, threads2)
872-
min_pseudo_block = threads ÷ blocks_per_mp
873-
log_threads = threads |> log2 |> Int
869+
# determines cutoff for when to use kernel1 vs kernel2
870+
log_threads = threads2 |> log2 |> Int
874871

875872
# These two outer loops are the same as the serial version outlined here:
876873
# https://en.wikipedia.org/wiki/Bitonic_sorter#Example_code
@@ -882,19 +879,22 @@ function bitonic_sort!(c; by = identity, lt = isless, rev = false)
882879
args1 = (c, I.((c_len, k, j, j_final))..., by, lt, Val(rev))
883880
args2 = (c, I.((c_len, k, j))..., by, lt, Val(rev))
884881
if k0 - k - j + 2 <= log_threads
885-
pseudo_block_size = 1 << abs(j_final + 1 - j)
886-
block_size = if pseudo_block_size <= min_pseudo_block
887-
(pseudo_block_size, min_pseudo_block ÷ pseudo_block_size)
888-
else
889-
(pseudo_block_size, 1)
890-
end
891-
b = nextpow(2, cld(c_len, prod(block_size)))
892-
kernel1(args1...; blocks=b, threads=block_size,
882+
# pseudo_block_length = max(nextpow(2, length(comparator)) for all comparators in this layer of the network)
883+
pseudo_block_length = 1 << abs(j_final + 1 - j)
884+
# N_pseudo_blocks = how many pseudo-blocks are in this layer of the network
885+
N_pseudo_blocks = nextpow(2, c_len) ÷ pseudo_block_length
886+
pseudo_blocks_per_block = threads2 ÷ pseudo_block_length
887+
888+
# grid dimensions
889+
N_blocks = max(1, N_pseudo_blocks ÷ pseudo_blocks_per_block)
890+
block_size = pseudo_block_length, threads2 ÷ pseudo_block_length
891+
892+
kernel1(args1...; blocks=N_blocks, threads=block_size,
893893
shmem=bitonic_shmem(c, block_size))
894894
break
895895
else
896-
b = nextpow(2, cld(c_len, threads))
897-
kernel2(args2...; blocks = b, threads)
896+
N_blocks = cld(c_len, threads1)
897+
kernel2(args2...; blocks = N_blocks, threads=threads1)
898898
end
899899
end
900900
end

0 commit comments

Comments
 (0)