@@ -69,7 +69,8 @@ affected by `parity`. See `flex_lt`. `swap` is an array for exchanging values
 and `sums` is an array of Ints used during the merge sort.
 Uses block y index to decide which values to operate on.
 """
-@inline function batch_partition(values, pivot, swap, sums, lo, hi, parity, lt::F1, by::F2) where {F1,F2}
+@inline function batch_partition(values, pivot, swap, sums, lo, hi, parity,
+                                 lt::F1, by::F2) where {F1,F2}
     sync_threads()
     blockIdx_yz = (blockIdx().z - 1i32) * gridDim().y + blockIdx().y
     idx0 = lo + (blockIdx_yz - 1i32) * blockDim().x + threadIdx().x
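Editorial note: the two context lines above linearize the (y, z) block coordinates into one slice index and then offset into the values from `lo`. A host-side sketch of that arithmetic, with the device intrinsics replaced by plain integer arguments and a hypothetical helper name:

```julia
# Host-side sketch; block_y/block_z/grid_y/block_x/thread_x stand in for the
# 1-based blockIdx().y, blockIdx().z, gridDim().y, blockDim().x and threadIdx().x.
function global_index(lo, block_y, block_z, grid_y, block_x, thread_x)
    blockIdx_yz = (block_z - 1) * grid_y + block_y   # linearize the (y, z) block coordinates
    return lo + (blockIdx_yz - 1) * block_x + thread_x
end

global_index(0, 2, 1, 4, 256, 1)   # first thread of block (y = 2, z = 1) handles element 257
```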
@@ -88,7 +89,11 @@ Uses block y index to decide which values to operate on.
     cumsum!(sums)

     @inbounds if idx0 <= hi
-        dest_idx = @inbounds comparison ? blockDim().x - sums[end] + sums[threadIdx().x] : threadIdx().x - sums[threadIdx().x]
+        dest_idx = @inbounds if comparison
+            blockDim().x - sums[end] + sums[threadIdx().x]
+        else
+            threadIdx().x - sums[threadIdx().x]
+        end
         if dest_idx <= length(swap)
             swap[dest_idx] = val
         end
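Editorial note: the `dest_idx` expression is a block-wide scatter driven by a prefix sum over the comparison flags: above-pivot elements are packed to the top of `swap`, the rest to the bottom, each group keeping its relative order. A CPU sketch of the same arithmetic, with a plain `Vector{Bool}` standing in for the per-thread `comparison` flag and its length playing the role of `blockDim().x` (illustrative function name, not part of the source):

```julia
# CPU analogue of the dest_idx computation in batch_partition.
function partition_destinations(comparison::Vector{Bool})
    sums = cumsum(comparison)          # sums[i] = number of above-pivot elements among 1:i
    n = length(comparison)             # plays the role of blockDim().x
    map(1:n) do i
        comparison[i] ? n - sums[end] + sums[i] : i - sums[i]
    end
end

partition_destinations([true, false, true, false])   # => [3, 1, 4, 2]
```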
@@ -211,7 +216,7 @@ Finds the median of `vals` starting after `lo` and going for `blockDim().x`
 elements spaced by `stride`. Performs bitonic sort in shmem, returns middle value.
 Faster than bubble sort, but not as flexible. Does not modify `vals`
 """
-function bitonic_median(vals :: AbstractArray{T}, swap, lo, L, stride, lt::F1, by::F2) where {T,F1,F2}
+function bitonic_median(vals::AbstractArray{T}, swap, lo, L, stride, lt::F1, by::F2) where {T,F1,F2}
     sync_threads()
     bitonic_lt(i1, i2) = @inbounds flex_lt(swap[i1 + 1], swap[i2 + 1], false, lt, by)

@@ -337,7 +342,7 @@ Quicksort recursion condition
 For a full sort, `partial` is nothing so it shouldn't affect whether recursion
 happens.
 """
-function partial_range_overlap(lo, hi, partial :: Nothing)
+function partial_range_overlap(lo, hi, partial::Nothing)
     true
 end

@@ -374,7 +379,8 @@ it's possible that the first pivot will be that value, which could lead to an in
 early end to recursion if we started `stuck` at 0.
 """
 function qsort_kernel(vals::AbstractArray{T,N}, lo, hi, parity, sync::Val{S}, sync_depth,
-                      prev_pivot, lt::F1, by::F2, ::Val{dims}, partial=nothing, stuck=-1) where {T, N, S, F1, F2, dims}
+                      prev_pivot, lt::F1, by::F2, ::Val{dims}, partial=nothing,
+                      stuck=-1) where {T, N, S, F1, F2, dims}
     b_sums = CuDynamicSharedArray(Int, blockDim().x)
     swap = CuDynamicSharedArray(T, blockDim().x, sizeof(b_sums))
     shmem = sizeof(b_sums) + sizeof(swap)
@@ -449,7 +455,7 @@ function qsort_kernel(vals::AbstractArray{T,N}, lo, hi, parity, sync::Val{S}, sy
     return
 end

-function sort_args(args, partial_k :: Nothing)
+function sort_args(args, partial_k::Nothing)
     return args
 end

@@ -578,10 +584,43 @@ end
     end
 end

+@inline function extraneous_block(vals::AbstractArray, dims)::Bool
+    other_linear_index = ((gridDim().z ÷ blockDim().z) * (blockIdx().y - 1)) + blockIdx().z
+    return other_linear_index > length(vals) ÷ size(vals)[dims]
+end
+
+@inline function extraneous_block(vals, dims)::Bool
+    return extraneous_block(vals[1], dims)
+end
+
+# methods are defined for Val{1} because using view has 2x speed penalty for 1D arrays
+@inline function view_along_dims(vals::AbstractArray{T, 1}, dimsval::Val{1}) where T
+    return vals
+end
+
+@inline function view_along_dims(vals::Tuple{AbstractArray{T,1},Any}, dimsval::Val{1}) where T
+    return vals[1], view_along_dims(vals[2], dimsval)
+end
+
+
+@inline function view_along_dims(vals::AbstractArray{T, N}, ::Val{dims}) where {T,N,dims}
+    otherdims = ntuple(i -> i == dims ? 1 : size(vals, i), N)
+    other_linear_index = ((gridDim().z ÷ blockDim().z) * (blockIdx().y - 1)) + blockIdx().z
+    other = CartesianIndices(otherdims)[other_linear_index]
+    # create a view that keeps the sorting dimension but indexes across the others
+    slicedims = map(Base.Slice, axes(vals))
+    idxs = ntuple(i -> i == dims ? slicedims[i] : other[i], N)
+    return view(vals, idxs...)
+end
+
+@inline function view_along_dims(vals, dimsval::Val{dims}) where dims
+    return vals[1], view_along_dims(vals[2], dimsval)
+end
+
+
 # Functions specifically for "large" bitonic steps (those that cannot use shmem)

-@inline function compare!(vals::AbstractArray{T}, i1::I, i2::I, dir::Bool, by, lt,
-                          rev) where {T,I}
+@inline function compare!(vals::AbstractArray{T, N}, i1::I, i2::I, dir::Bool, by, lt, rev) where {T,I,N}
     i1′, i2′ = i1 + one(I), i2 + one(I)
     @inbounds if dir != rev_lt(by(vals[i1′]), by(vals[i2′]), lt, rev)
         vals[i1′], vals[i2′] = vals[i2′], vals[i1′]
@@ -593,7 +632,8 @@
     i1′, i2′ = i1 + one(I), i2 + one(I)
     vals, inds = vals_inds
     # comparing tuples of (value, index) guarantees stability of sort
-    @inbounds if dir != rev_lt((by(vals[inds[i1′]]), inds[i1′]), (by(vals[inds[i2′]]), inds[i2′]), lt, rev)
+    @inbounds if dir != rev_lt((by(vals[inds[i1′]]), inds[i1′]),
+                               (by(vals[inds[i2′]]), inds[i2′]), lt, rev)
         inds[i1′], inds[i2′] = inds[i2′], inds[i1′]
     end
 end
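Editorial note: `view_along_dims` hands each block the 1-D slice it should sort, keeping every index of the sorting dimension and fixing the remaining dimensions from a linear slice index. A host-side sketch of that slicing, where `other_linear_index` is passed explicitly instead of being derived from block indices (the helper name is illustrative, not part of the source):

```julia
# CPU sketch of the slice selection performed by view_along_dims.
function slice_along_dims(vals::AbstractArray{T,N}, dims::Int, other_linear_index::Int) where {T,N}
    otherdims = ntuple(i -> i == dims ? 1 : size(vals, i), N)
    other = CartesianIndices(otherdims)[other_linear_index]
    # keep the sorting dimension, fix every other dimension to this slice's coordinates
    idxs = ntuple(i -> i == dims ? Base.Slice(axes(vals, i)) : other[i], N)
    return view(vals, idxs...)
end

A = reshape(collect(1:12), 3, 4)
slice_along_dims(A, 1, 2)   # the second column: a 3-element view along dims = 1
```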
@@ -645,16 +685,22 @@ Note that to avoid synchronization issues, only one thread from each pair of
 indices being swapped will actually move data.
 """
 function comparator_kernel(vals, length_vals::I, k::I, j::I, by::F1, lt::F2,
-                           rev) where {I,F1,F2}
+                           rev, dimsval::Val{dims}) where {I,F1,F2,dims}
+    if extraneous_block(vals, dims)
+        return nothing
+    end
+
     index = (blockDim().x * (blockIdx().x - one(I))) + threadIdx().x - one(I)

+    slice = view_along_dims(vals, dimsval)
+
     lo, n, dir = get_range(length_vals, index, k, j)

     if !(lo < zero(I) || n < zero(I)) && !(index >= length_vals)
         m = gp2lt(n)
         if lo <= index < lo + n - m
             i1, i2 = index, index + m
-            @inbounds compare!(vals, i1, i2, dir, by, lt, rev)
+            @inbounds compare!(slice, i1, i2, dir, by, lt, rev)
         end
     end
     return
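Editorial note: the early `extraneous_block` return exists because the y/z grid is sized with `ceil(sqrt(...))` (see the launch loop further down), so the last few blocks may not map to any slice. A host-side sketch of that check, mirroring the arithmetic in `extraneous_block` with the device intrinsics replaced by explicit arguments (illustrative name):

```julia
# CPU analogue of extraneous_block: a block whose linearized y/z index exceeds the
# number of non-sorting slices has no work to do and returns immediately.
function is_extraneous_block(nslices, grid_z, block_z, block_y_idx, block_z_idx)
    other_linear_index = (grid_z ÷ block_z) * (block_y_idx - 1) + block_z_idx
    return other_linear_index > nslices
end

is_extraneous_block(1000, 32, 1, 32, 9)   # => true: linear index 1001 on a 32×32 grid with only 1000 slices
```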
@@ -804,15 +850,19 @@ a unique range of indices corresponding to a comparator in the sorting network.
 Note that this moves the array values copied within shmem, but doesn't copy them
 back to global the way it does for indices.
 """
-function comparator_small_kernel(c, length_c::I, k::I, j_0::I, j_f::I,
-                                 by::F1, lt::F2, rev) where {I,F1,F2}
+function comparator_small_kernel(vals, length_vals::I, k::I, j_0::I, j_f::I,
+                                 by::F1, lt::F2, rev, dimsval::Val{dims}) where {I,F1,F2,dims}
+    if extraneous_block(vals, dims)
+        return nothing
+    end
+    slice = view_along_dims(vals, dimsval)
     pseudo_block_idx = (blockIdx().x - one(I)) * blockDim().y + threadIdx().y - one(I)
     # immutable info about the range used by this kernel
-    _lo, _n, dir = block_range(length_c, pseudo_block_idx, k, j_0)
+    _lo, _n, dir = block_range(length_vals, pseudo_block_idx, k, j_0)
     index = _lo + threadIdx().x - one(I)
     in_range = (threadIdx().x <= _n && _lo >= zero(I))

-    swap = initialize_shmem!(c, index, in_range)
+    swap = initialize_shmem!(slice, index, in_range)

     # mutable copies for pseudo-recursion
     lo, n = _lo, _n
@@ -829,7 +879,7 @@ function comparator_small_kernel(c, length_c::I, k::I, j_0::I, j_f::I,
         sync_threads()
     end

-    finalize_shmem!(c, swap, index, in_range)
+    finalize_shmem!(slice, swap, index, in_range)
     return
 end

@@ -849,20 +899,20 @@ of values and an index array for doing `sortperm!`. Cannot provide a stable
 `sort!` although `sortperm!` is properly stable. To reverse, set `rev=true`
 rather than `lt=!isless` (otherwise stability of sortperm breaks down).
 """
-function bitonic_sort!(c; by = identity, lt = isless, rev = false)
-    c_len = if typeof(c) <: Tuple
-        length(c[1])
+function bitonic_sort!(c; by = identity, lt = isless, rev = false, dims = 1)
+    c_len, otherdims_len = if typeof(c) <: Tuple
+        size(c[1])[dims], length(c[1]) ÷ size(c[1])[dims]
     else
-        length(c)
+        size(c)[dims], length(c) ÷ size(c)[dims]
     end

-    # compile kernels (using Int32 for indexing, if possible, yielding a 10% speedup)
+    # compile kernels (using Int32 for indexing, if possible, yielding a 70% speedup)
     I = c_len <= typemax(Int32) ? Int32 : Int
-    args1 = (c, I(c_len), one(I), one(I), one(I), by, lt, Val(rev))
+    args1 = (c, I(c_len), one(I), one(I), one(I), by, lt, Val(rev), Val(dims))
     kernel1 = @cuda launch=false comparator_small_kernel(args1...)
+
     config1 = launch_configuration(kernel1.fun, shmem = threads -> bitonic_shmem(c, threads))
-    threads1 = prevpow(2, config1.threads)
-    args2 = (c, I(c_len), one(I), one(I), by, lt, Val(rev))
+    args2 = (c, I(c_len), one(I), one(I), by, lt, Val(rev), Val(dims))
     kernel2 = @cuda launch=false comparator_kernel(args2...)
     config2 = launch_configuration(kernel2.fun, shmem = threads -> bitonic_shmem(c, threads))
     # blocksize for kernel2 MUST be a power of 2
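Editorial note: with the new `dims` keyword, `c_len` becomes the length of each slice along the sorting dimension and `otherdims_len` counts the independent slices. A hedged usage sketch of that split for a plain dense matrix:

```julia
using CUDA

c = CUDA.rand(Float32, 1024, 16)
dims = 1
c_len = size(c)[dims]                # 1024: length of each slice along the sorting dimension
otherdims_len = length(c) ÷ c_len    # 16: number of independent slices sorted in parallel
```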
@@ -877,26 +927,30 @@ function bitonic_sort!(c; by = identity, lt = isless, rev = false)
     k0 = ceil(Int, log2(c_len))
     for k = k0:-1:1
         j_final = 1 + k0 - k
+
+        # non-sorting dims are put into blocks along grid y/z. Using sqrt minimizes wasted blocks
+        other_block_dims = Int(ceil(sqrt(otherdims_len))), Int(ceil(sqrt(otherdims_len)))
+
         for j = 1:j_final
-            args1 = (c, I.((c_len, k, j, j_final))..., by, lt, Val(rev))
-            args2 = (c, I.((c_len, k, j))..., by, lt, Val(rev))
+            args1 = (c, I.((c_len, k, j, j_final))..., by, lt, Val(rev), Val(dims))
+            args2 = (c, I.((c_len, k, j))..., by, lt, Val(rev), Val(dims))
             if k0 - k - j + 2 <= log_threads
-                # pseudo_block_length = max(nextpow(2, length(comparator)) for all comparators in this layer of the network)
+                # pseudo_block_length = max(nextpow(2, length(comparator))
+                #     for all comparators in this layer of the network)
                 pseudo_block_length = 1 << abs(j_final + 1 - j)
                 # N_pseudo_blocks = how many pseudo-blocks are in this layer of the network
                 N_pseudo_blocks = nextpow(2, c_len) ÷ pseudo_block_length
                 pseudo_blocks_per_block = threads2 ÷ pseudo_block_length

                 # grid dimensions
-                N_blocks = max(1, N_pseudo_blocks ÷ pseudo_blocks_per_block)
+                N_blocks = max(1, N_pseudo_blocks ÷ pseudo_blocks_per_block), other_block_dims...
                 block_size = pseudo_block_length, threads2 ÷ pseudo_block_length
-
                 kernel1(args1...; blocks=N_blocks, threads=block_size,
                         shmem=bitonic_shmem(c, block_size))
                 break
             else
-                N_blocks = cld(c_len, threads1)
-                kernel2(args2...; blocks = N_blocks, threads=threads1)
+                N_blocks = cld(c_len, threads2), other_block_dims...
+                kernel2(args2...; blocks = N_blocks, threads=threads2)
             end
         end
     end
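Editorial note: the non-sorting slices are spread over the grid's y and z dimensions as a roughly square `ceil(sqrt(n)) × ceil(sqrt(n))` tile, which keeps the number of extraneous blocks small. A quick illustration of that sizing (numbers chosen only for illustration):

```julia
# How many y/z blocks the sqrt split produces for 1000 non-sorting slices.
otherdims_len = 1000
other_block_dims = ceil(Int, sqrt(otherdims_len)), ceil(Int, sqrt(otherdims_len))
# => (32, 32): 1024 blocks in the y/z grid; the 24 past slice 1000 return early as extraneous
```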
@@ -930,54 +984,58 @@ function Base.sort!(c::AnyCuVector, alg::QuickSortAlg; lt=isless, by=identity, r
     return c
 end

-function Base.sort!(c::AnyCuVector, alg::BitonicSortAlg; kwargs...)
+function Base.sort!(c::AnyCuArray, alg::BitonicSortAlg; kwargs...)
     return bitonic_sort!(c; kwargs...)
 end

-function Base.sort!(c::AnyCuVector; alg :: SortingAlgorithm = BitonicSort, kwargs...)
+function Base.sort!(c::AnyCuArray; alg::SortingAlgorithm = BitonicSort, kwargs...)
     return sort!(c, alg; kwargs...)
 end

-function Base.sort!(c::AnyCuArray; dims::Integer, lt=isless, by=identity, rev=false)
-    # for multi dim sorting, only quicksort is supported so no alg keyword
-    if rev
-        lt = !lt
-    end
-
-    quicksort!(c; lt, by, dims)
-    return c
-end
-
 function Base.sort(c::AnyCuArray; kwargs...)
     return sort!(copy(c); kwargs...)
 end

-function Base.partialsort!(c::AnyCuVector, k::Union{Integer, OrdinalRange};
-                           lt=isless, by=identity, rev=false)
+function Base.partialsort!(c::AnyCuVector, k::Union{Integer, OrdinalRange},
+                           alg::BitonicSortAlg; lt=isless, by=identity, rev=false)
+
+    sort!(c, alg; lt, by, rev)
+    return @allowscalar copy(c[k])
+end
+
+function Base.partialsort!(c::AnyCuVector, k::Union{Integer, OrdinalRange},
+                           alg::QuickSortAlg; lt=isless, by=identity, rev=false)
     # for reverse sorting, invert the less-than function
     if rev
         lt = !lt
     end

-    function out(k :: OrdinalRange)
+    function out(k::OrdinalRange)
         return copy(c[k])
     end

     # work around disallowed scalar index
-    function out(k :: Integer)
+    function out(k::Integer)
         return Array(c[k:k])[1]
     end

     quicksort!(c; lt, by, dims=1, partial_k=k)
     return out(k)
 end

+function Base.partialsort!(c::AnyCuArray, k::Union{Integer, OrdinalRange};
+                           alg::SortingAlgorithm = BitonicSort, kwargs...)
+    return partialsort!(c, k, alg; kwargs...)
+end
+
 function Base.partialsort(c::AnyCuArray, k::Union{Integer, OrdinalRange}; kwargs...)
     return partialsort!(copy(c), k; kwargs...)
 end

-function Base.sortperm!(ix::AnyCuArray{T}, A::AnyCuArray; initialized=false, kwargs...) where T
-    axes(ix) == axes(A) || throw(ArgumentError("index array must have the same size/axes as the source array, $(axes(ix)) != $(axes(A))"))
+function Base.sortperm!(ix::AnyCuArray, A::AnyCuArray; initialized=false, kwargs...)
+    if axes(ix) != axes(A)
+        throw(ArgumentError("index array must have the same size/axes as the source array, $(axes(ix)) != $(axes(A))"))
+    end

     if !initialized
         ix .= LinearIndices(A)
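Editorial note: taken together, these method changes route any `AnyCuArray` (not just vectors) through `bitonic_sort!` by default and let `partialsort!` select an algorithm via `alg`. A hedged usage sketch of the resulting API, assuming the methods added above:

```julia
using CUDA

v = CUDA.rand(Float32, 10_000)
sort!(v)                                   # BitonicSort is the default algorithm
partialsort!(v, 1:10; alg=BitonicSort)     # dispatches to the BitonicSortAlg method above

A = CUDA.rand(Float32, 512, 64)
sort!(A; dims=2)                           # multi-dimensional sort now routed through bitonic_sort!
```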
@@ -986,6 +1044,11 @@ function Base.sortperm!(ix::AnyCuArray{T}, A::AnyCuArray; initialized=false, kwa
     return ix
 end

-function Base.sortperm(c::AnyCuArray; kwargs...)
+function Base.sortperm(c::AnyCuVector; kwargs...)
     sortperm!(CuArray(1:length(c)), c; initialized=true, kwargs...)
 end
+
+function Base.sortperm(c::AnyCuArray; dims, kwargs...)
+    # Base errors for Matrices without dims arg, we should too
+    sortperm!(reshape(CuArray(1:length(c)), size(c)), c; initialized=true, dims, kwargs...)
+end
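Editorial note: the matching `sortperm` behavior keeps the 1-D fast path for vectors, while higher-dimensional arrays require `dims`, mirroring Base. A hedged usage sketch:

```julia
using CUDA

v = CUDA.rand(Float32, 1000)
sortperm(v)              # permutation of 1:1000

A = CUDA.rand(Float32, 8, 8)
sortperm(A; dims=1)      # array of linear indices ordering each column
# sortperm(A) without dims errors, matching Base's behavior for matrices
```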