@@ -640,8 +640,7 @@ than the size of block allows (eg, 1 <--> 10000)
640
640
The grid index directly maps to the index of `c` that will be used in the swap.
641
641
642
642
Note that to avoid synchronization issues, only one thread from each pair of
643
- indices being swapped will actually move data. This does mean half of the threads
644
- do nothing, but it works for non-power2 arrays while allowing direct indexing.
643
+ indices being swapped will actually move data.
645
644
"""
646
645
function comparator_kernel (vals, length_vals:: I , k:: I , j:: I , by:: F1 , lt:: F2 ,
647
646
rev) where {I,F1,F2}
@@ -793,7 +792,12 @@ swaps in shared mem.
793
792
Note that the x dimension of a thread block is treated as a comparator,
794
793
so when the maximum size of a comparator in this kernel is small, multiple
795
794
may be executed along the block y dimension, allowing for higher occupancy.
796
- This is captured by `pseudo_block_idx`.
795
+ These threads in a block with the same threadIdx().x are a 'pseudo-block',
796
+ and are indexed by `pseudo_block_idx`.
797
+
798
+ Unlike `comparator_kernel`, a thread's grid_index does not directly map to the
799
+ index of `c` it will read from. `block_range` gives gives each pseudo-block
800
+ a unique range of indices corresponding to a comparator in the sorting network.
797
801
798
802
Note that this moves the array values copied within shmem, but doesn't copy them
799
803
back to global the way it does for indices.
@@ -859,18 +863,11 @@ function bitonic_sort!(c; by = identity, lt = isless, rev = false)
859
863
args2 = (c, I (c_len), one (I), one (I), by, lt, Val (rev))
860
864
kernel2 = @cuda launch= false comparator_kernel (args2... )
861
865
config2 = launch_configuration (kernel2. fun, shmem = threads -> bitonic_shmem (c, threads))
866
+ # blocksize for kernel2 MUST be a power of 2
862
867
threads2 = prevpow (2 , config2. threads)
863
868
864
- # determine launch configuration
865
- blocks_per_mp = if CUDA. driver_version () >= v " 11.0"
866
- CUDA. attribute (device (), CUDA. DEVICE_ATTRIBUTE_MAX_BLOCKS_PER_MULTIPROCESSOR)
867
- else
868
- 16
869
- end
870
- blocks_per_mp = 16 # XXX : JuliaGPU/CUDA.jl#1874
871
- threads = min (threads1, threads2)
872
- min_pseudo_block = threads ÷ blocks_per_mp
873
- log_threads = threads |> log2 |> Int
869
+ # determines cutoff for when to use kernel1 vs kernel2
870
+ log_threads = threads2 |> log2 |> Int
874
871
875
872
# These two outer loops are the same as the serial version outlined here:
876
873
# https://en.wikipedia.org/wiki/Bitonic_sorter#Example_code
@@ -882,19 +879,22 @@ function bitonic_sort!(c; by = identity, lt = isless, rev = false)
882
879
args1 = (c, I .((c_len, k, j, j_final))... , by, lt, Val (rev))
883
880
args2 = (c, I .((c_len, k, j))... , by, lt, Val (rev))
884
881
if k0 - k - j + 2 <= log_threads
885
- pseudo_block_size = 1 << abs (j_final + 1 - j)
886
- block_size = if pseudo_block_size <= min_pseudo_block
887
- (pseudo_block_size, min_pseudo_block ÷ pseudo_block_size)
888
- else
889
- (pseudo_block_size, 1 )
890
- end
891
- b = nextpow (2 , cld (c_len, prod (block_size)))
892
- kernel1 (args1... ; blocks= b, threads= block_size,
882
+ # pseudo_block_length = max(nextpow(2, length(comparator)) for all comparators in this layer of the network)
883
+ pseudo_block_length = 1 << abs (j_final + 1 - j)
884
+ # N_pseudo_blocks = how many pseudo-blocks are in this layer of the network
885
+ N_pseudo_blocks = nextpow (2 , c_len) ÷ pseudo_block_length
886
+ pseudo_blocks_per_block = threads2 ÷ pseudo_block_length
887
+
888
+ # grid dimensions
889
+ N_blocks = max (1 , N_pseudo_blocks ÷ pseudo_blocks_per_block)
890
+ block_size = pseudo_block_length, threads2 ÷ pseudo_block_length
891
+
892
+ kernel1 (args1... ; blocks= N_blocks, threads= block_size,
893
893
shmem= bitonic_shmem (c, block_size))
894
894
break
895
895
else
896
- b = nextpow ( 2 , cld (c_len, threads) )
897
- kernel2 (args2... ; blocks = b , threads)
896
+ N_blocks = cld (c_len, threads1 )
897
+ kernel2 (args2... ; blocks = N_blocks , threads= threads1 )
898
898
end
899
899
end
900
900
end
0 commit comments