Skip to content

Commit b3e1bdf

Browse files
author
Alex Ellison
authored
Sorting: sortperm(; dims) and bitonic partialsort (#2308)
1 parent 977177f commit b3e1bdf

File tree

2 files changed

+121
-51
lines changed

2 files changed

+121
-51
lines changed

src/sorting.jl

Lines changed: 113 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,8 @@ affected by `parity`. See `flex_lt`. `swap` is an array for exchanging values
6969
and `sums` is an array of Ints used during the merge sort.
7070
Uses block y index to decide which values to operate on.
7171
"""
72-
@inline function batch_partition(values, pivot, swap, sums, lo, hi, parity, lt::F1, by::F2) where {F1,F2}
72+
@inline function batch_partition(values, pivot, swap, sums, lo, hi, parity,
73+
lt::F1, by::F2) where {F1,F2}
7374
sync_threads()
7475
blockIdx_yz = (blockIdx().z - 1i32) * gridDim().y + blockIdx().y
7576
idx0 = lo + (blockIdx_yz - 1i32) * blockDim().x + threadIdx().x
@@ -88,7 +89,11 @@ Uses block y index to decide which values to operate on.
8889
cumsum!(sums)
8990

9091
@inbounds if idx0 <= hi
91-
dest_idx = @inbounds comparison ? blockDim().x - sums[end] + sums[threadIdx().x] : threadIdx().x - sums[threadIdx().x]
92+
dest_idx = @inbounds if comparison
93+
blockDim().x - sums[end] + sums[threadIdx().x]
94+
else
95+
threadIdx().x - sums[threadIdx().x]
96+
end
9297
if dest_idx <= length(swap)
9398
swap[dest_idx] = val
9499
end
@@ -211,7 +216,7 @@ Finds the median of `vals` starting after `lo` and going for `blockDim().x`
211216
elements spaced by `stride`. Performs bitonic sort in shmem, returns middle value.
212217
Faster than bubble sort, but not as flexible. Does not modify `vals`
213218
"""
214-
function bitonic_median(vals :: AbstractArray{T}, swap, lo, L, stride, lt::F1, by::F2) where {T,F1,F2}
219+
function bitonic_median(vals::AbstractArray{T}, swap, lo, L, stride, lt::F1, by::F2) where {T,F1,F2}
215220
sync_threads()
216221
bitonic_lt(i1, i2) = @inbounds flex_lt(swap[i1 + 1], swap[i2 + 1], false, lt, by)
217222

@@ -337,7 +342,7 @@ Quicksort recursion condition
337342
For a full sort, `partial` is nothing so it shouldn't affect whether recursion
338343
happens.
339344
"""
340-
function partial_range_overlap(lo, hi, partial :: Nothing)
345+
function partial_range_overlap(lo, hi, partial::Nothing)
341346
true
342347
end
343348

@@ -374,7 +379,8 @@ it's possible that the first pivot will be that value, which could lead to an in
374379
early end to recursion if we started `stuck` at 0.
375380
"""
376381
function qsort_kernel(vals::AbstractArray{T,N}, lo, hi, parity, sync::Val{S}, sync_depth,
377-
prev_pivot, lt::F1, by::F2, ::Val{dims}, partial=nothing, stuck=-1) where {T, N, S, F1, F2, dims}
382+
prev_pivot, lt::F1, by::F2, ::Val{dims}, partial=nothing,
383+
stuck=-1) where {T, N, S, F1, F2, dims}
378384
b_sums = CuDynamicSharedArray(Int, blockDim().x)
379385
swap = CuDynamicSharedArray(T, blockDim().x, sizeof(b_sums))
380386
shmem = sizeof(b_sums) + sizeof(swap)
@@ -449,7 +455,7 @@ function qsort_kernel(vals::AbstractArray{T,N}, lo, hi, parity, sync::Val{S}, sy
449455
return
450456
end
451457

452-
function sort_args(args, partial_k :: Nothing)
458+
function sort_args(args, partial_k::Nothing)
453459
return args
454460
end
455461

@@ -578,10 +584,43 @@ end
578584
end
579585
end
580586

587+
@inline function extraneous_block(vals::AbstractArray, dims):: Bool
588+
other_linear_index = ((gridDim().z ÷ blockDim().z) * (blockIdx().y - 1)) + blockIdx().z
589+
return other_linear_index > length(vals) ÷ size(vals)[dims]
590+
end
591+
592+
@inline function extraneous_block(vals, dims)::Bool
593+
return extraneous_block(vals[1], dims)
594+
end
595+
596+
# methods are defined for Val{1} because using view has 2x speed penalty for 1D arrays
597+
@inline function view_along_dims(vals::AbstractArray{T, 1}, dimsval::Val{1}) where T
598+
return vals
599+
end
600+
601+
@inline function view_along_dims(vals::Tuple{AbstractArray{T,1},Any}, dimsval::Val{1}) where T
602+
return vals[1], view_along_dims(vals[2], dimsval)
603+
end
604+
605+
606+
@inline function view_along_dims(vals::AbstractArray{T, N}, ::Val{dims}) where {T,N,dims}
607+
otherdims = ntuple(i -> i == dims ? 1 : size(vals, i), N)
608+
other_linear_index = ((gridDim().z ÷ blockDim().z) * (blockIdx().y - 1)) + blockIdx().z
609+
other = CartesianIndices(otherdims)[other_linear_index]
610+
# create a view that keeps the sorting dimension but indexes across the others
611+
slicedims = map(Base.Slice, axes(vals))
612+
idxs = ntuple(i->i==dims ? slicedims[i] : other[i], N)
613+
return view(vals, idxs...)
614+
end
615+
616+
@inline function view_along_dims(vals, dimsval::Val{dims}) where dims
617+
return vals[1], view_along_dims(vals[2], dimsval)
618+
end
619+
620+
581621
# Functions specifically for "large" bitonic steps (those that cannot use shmem)
582622

583-
@inline function compare!(vals::AbstractArray{T}, i1::I, i2::I, dir::Bool, by, lt,
584-
rev) where {T,I}
623+
@inline function compare!(vals::AbstractArray{T, N}, i1::I, i2::I, dir::Bool, by, lt, rev) where {T,I,N}
585624
i1′, i2′ = i1 + one(I), i2 + one(I)
586625
@inbounds if dir != rev_lt(by(vals[i1′]), by(vals[i2′]), lt, rev)
587626
vals[i1′], vals[i2′] = vals[i2′], vals[i1′]
@@ -593,7 +632,8 @@ end
593632
i1′, i2′ = i1 + one(I), i2 + one(I)
594633
vals, inds = vals_inds
595634
# comparing tuples of (value, index) guarantees stability of sort
596-
@inbounds if dir != rev_lt((by(vals[inds[i1′]]), inds[i1′]), (by(vals[inds[i2′]]), inds[i2′]), lt, rev)
635+
@inbounds if dir != rev_lt((by(vals[inds[i1′]]), inds[i1′]),
636+
(by(vals[inds[i2′]]), inds[i2′]), lt, rev)
597637
inds[i1′], inds[i2′] = inds[i2′], inds[i1′]
598638
end
599639
end
@@ -645,16 +685,22 @@ Note that to avoid synchronization issues, only one thread from each pair of
645685
indices being swapped will actually move data.
646686
"""
647687
function comparator_kernel(vals, length_vals::I, k::I, j::I, by::F1, lt::F2,
648-
rev) where {I,F1,F2}
688+
rev, dimsval::Val{dims}) where {I,F1,F2,dims}
689+
if extraneous_block(vals, dims)
690+
return nothing
691+
end
692+
649693
index = (blockDim().x * (blockIdx().x - one(I))) + threadIdx().x - one(I)
650694

695+
slice = view_along_dims(vals, dimsval)
696+
651697
lo, n, dir = get_range(length_vals, index, k, j)
652698

653699
if !(lo < zero(I) || n < zero(I)) && !(index >= length_vals)
654700
m = gp2lt(n)
655701
if lo <= index < lo + n - m
656702
i1, i2 = index, index + m
657-
@inbounds compare!(vals, i1, i2, dir, by, lt, rev)
703+
@inbounds compare!(slice, i1, i2, dir, by, lt, rev)
658704
end
659705
end
660706
return
@@ -804,15 +850,19 @@ a unique range of indices corresponding to a comparator in the sorting network.
804850
Note that this moves the array values copied within shmem, but doesn't copy them
805851
back to global the way it does for indices.
806852
"""
807-
function comparator_small_kernel(c, length_c::I, k::I, j_0::I, j_f::I,
808-
by::F1, lt::F2, rev) where {I,F1,F2}
853+
function comparator_small_kernel(vals, length_vals::I, k::I, j_0::I, j_f::I,
854+
by::F1, lt::F2, rev, dimsval::Val{dims}) where {I,F1,F2,dims}
855+
if extraneous_block(vals, dims)
856+
return nothing
857+
end
858+
slice = view_along_dims(vals, dimsval)
809859
pseudo_block_idx = (blockIdx().x - one(I)) * blockDim().y + threadIdx().y - one(I)
810860
# immutable info about the range used by this kernel
811-
_lo, _n, dir = block_range(length_c, pseudo_block_idx, k, j_0)
861+
_lo, _n, dir = block_range(length_vals, pseudo_block_idx, k, j_0)
812862
index = _lo + threadIdx().x - one(I)
813863
in_range = (threadIdx().x <= _n && _lo >= zero(I))
814864

815-
swap = initialize_shmem!(c, index, in_range)
865+
swap = initialize_shmem!(slice, index, in_range)
816866

817867
# mutable copies for pseudo-recursion
818868
lo, n = _lo, _n
@@ -829,7 +879,7 @@ function comparator_small_kernel(c, length_c::I, k::I, j_0::I, j_f::I,
829879
sync_threads()
830880
end
831881

832-
finalize_shmem!(c, swap, index, in_range)
882+
finalize_shmem!(slice, swap, index, in_range)
833883
return
834884
end
835885

@@ -849,20 +899,20 @@ of values and an index array for doing `sortperm!`. Cannot provide a stable
849899
`sort!` although `sortperm!` is properly stable. To reverse, set `rev=true`
850900
rather than `lt=!isless` (otherwise stability of sortperm breaks down).
851901
"""
852-
function bitonic_sort!(c; by = identity, lt = isless, rev = false)
853-
c_len = if typeof(c) <: Tuple
854-
length(c[1])
902+
function bitonic_sort!(c; by = identity, lt = isless, rev = false, dims=1)
903+
c_len, otherdims_len = if typeof(c) <: Tuple
904+
size(c[1])[dims], length(c[1]) ÷ size(c[1])[dims]
855905
else
856-
length(c)
906+
size(c)[dims], length(c) ÷ size(c)[dims]
857907
end
858908

859-
# compile kernels (using Int32 for indexing, if possible, yielding a 10% speedup)
909+
# compile kernels (using Int32 for indexing, if possible, yielding a 70% speedup)
860910
I = c_len <= typemax(Int32) ? Int32 : Int
861-
args1 = (c, I(c_len), one(I), one(I), one(I), by, lt, Val(rev))
911+
args1 = (c, I(c_len), one(I), one(I), one(I), by, lt, Val(rev), Val(dims))
862912
kernel1 = @cuda launch=false comparator_small_kernel(args1...)
913+
863914
config1 = launch_configuration(kernel1.fun, shmem = threads -> bitonic_shmem(c, threads))
864-
threads1 = prevpow(2, config1.threads)
865-
args2 = (c, I(c_len), one(I), one(I), by, lt, Val(rev))
915+
args2 = (c, I(c_len), one(I), one(I), by, lt, Val(rev), Val(dims))
866916
kernel2 = @cuda launch=false comparator_kernel(args2...)
867917
config2 = launch_configuration(kernel2.fun, shmem = threads -> bitonic_shmem(c, threads))
868918
# blocksize for kernel2 MUST be a power of 2
@@ -877,26 +927,30 @@ function bitonic_sort!(c; by = identity, lt = isless, rev = false)
877927
k0 = ceil(Int, log2(c_len))
878928
for k = k0:-1:1
879929
j_final = 1 + k0 - k
930+
931+
# non-sorting dims are put into blocks along grid y/z. Using sqrt minimizes wasted blocks
932+
other_block_dims = Int(ceil(sqrt(otherdims_len))), Int(ceil(sqrt(otherdims_len)))
933+
880934
for j = 1:j_final
881-
args1 = (c, I.((c_len, k, j, j_final))..., by, lt, Val(rev))
882-
args2 = (c, I.((c_len, k, j))..., by, lt, Val(rev))
935+
args1 = (c, I.((c_len, k, j, j_final))..., by, lt, Val(rev), Val(dims))
936+
args2 = (c, I.((c_len, k, j))..., by, lt, Val(rev), Val(dims))
883937
if k0 - k - j + 2 <= log_threads
884-
# pseudo_block_length = max(nextpow(2, length(comparator)) for all comparators in this layer of the network)
938+
# pseudo_block_length = max(nextpow(2, length(comparator))
939+
# for all comparators in this layer of the network)
885940
pseudo_block_length = 1 << abs(j_final + 1 - j)
886941
# N_pseudo_blocks = how many pseudo-blocks are in this layer of the network
887942
N_pseudo_blocks = nextpow(2, c_len) ÷ pseudo_block_length
888943
pseudo_blocks_per_block = threads2 ÷ pseudo_block_length
889944

890945
# grid dimensions
891-
N_blocks = max(1, N_pseudo_blocks ÷ pseudo_blocks_per_block)
946+
N_blocks = max(1, N_pseudo_blocks ÷ pseudo_blocks_per_block), other_block_dims...
892947
block_size = pseudo_block_length, threads2 ÷ pseudo_block_length
893-
894948
kernel1(args1...; blocks=N_blocks, threads=block_size,
895949
shmem=bitonic_shmem(c, block_size))
896950
break
897951
else
898-
N_blocks = cld(c_len, threads1)
899-
kernel2(args2...; blocks = N_blocks, threads=threads1)
952+
N_blocks = cld(c_len, threads2), other_block_dims...
953+
kernel2(args2...; blocks = N_blocks, threads=threads2)
900954
end
901955
end
902956
end
@@ -930,54 +984,58 @@ function Base.sort!(c::AnyCuVector, alg::QuickSortAlg; lt=isless, by=identity, r
930984
return c
931985
end
932986

933-
function Base.sort!(c::AnyCuVector, alg::BitonicSortAlg; kwargs...)
987+
function Base.sort!(c::AnyCuArray, alg::BitonicSortAlg; kwargs...)
934988
return bitonic_sort!(c; kwargs...)
935989
end
936990

937-
function Base.sort!(c::AnyCuVector; alg :: SortingAlgorithm = BitonicSort, kwargs...)
991+
function Base.sort!(c::AnyCuArray; alg::SortingAlgorithm = BitonicSort, kwargs...)
938992
return sort!(c, alg; kwargs...)
939993
end
940994

941-
function Base.sort!(c::AnyCuArray; dims::Integer, lt=isless, by=identity, rev=false)
942-
# for multi dim sorting, only quicksort is supported so no alg keyword
943-
if rev
944-
lt = !lt
945-
end
946-
947-
quicksort!(c; lt, by, dims)
948-
return c
949-
end
950-
951995
function Base.sort(c::AnyCuArray; kwargs...)
952996
return sort!(copy(c); kwargs...)
953997
end
954998

955-
function Base.partialsort!(c::AnyCuVector, k::Union{Integer, OrdinalRange};
956-
lt=isless, by=identity, rev=false)
999+
function Base.partialsort!(c::AnyCuVector, k::Union{Integer, OrdinalRange},
1000+
alg::BitonicSortAlg; lt=isless, by=identity, rev=false)
1001+
1002+
sort!(c, alg; lt, by, rev)
1003+
return @allowscalar copy(c[k])
1004+
end
1005+
1006+
function Base.partialsort!(c::AnyCuVector, k::Union{Integer, OrdinalRange},
1007+
alg::QuickSortAlg; lt=isless, by=identity, rev=false)
9571008
# for reverse sorting, invert the less-than function
9581009
if rev
9591010
lt = !lt
9601011
end
9611012

962-
function out(k :: OrdinalRange)
1013+
function out(k::OrdinalRange)
9631014
return copy(c[k])
9641015
end
9651016

9661017
# work around disallowed scalar index
967-
function out(k :: Integer)
1018+
function out(k::Integer)
9681019
return Array(c[k:k])[1]
9691020
end
9701021

9711022
quicksort!(c; lt, by, dims=1, partial_k=k)
9721023
return out(k)
9731024
end
9741025

1026+
function Base.partialsort!(c::AnyCuArray, k::Union{Integer, OrdinalRange};
1027+
alg::SortingAlgorithm=BitonicSort, kwargs...)
1028+
return partialsort!(c, k, alg; kwargs...)
1029+
end
1030+
9751031
function Base.partialsort(c::AnyCuArray, k::Union{Integer, OrdinalRange}; kwargs...)
9761032
return partialsort!(copy(c), k; kwargs...)
9771033
end
9781034

979-
function Base.sortperm!(ix::AnyCuArray{T}, A::AnyCuArray; initialized=false, kwargs...) where T
980-
axes(ix) == axes(A) || throw(ArgumentError("index array must have the same size/axes as the source array, $(axes(ix)) != $(axes(A))"))
1035+
function Base.sortperm!(ix::AnyCuArray, A::AnyCuArray; initialized=false, kwargs...)
1036+
if axes(ix) != axes(A)
1037+
throw(ArgumentError("index array must have the same size/axes as the source array, $(axes(ix)) != $(axes(A))"))
1038+
end
9811039

9821040
if !initialized
9831041
ix .= LinearIndices(A)
@@ -986,6 +1044,11 @@ function Base.sortperm!(ix::AnyCuArray{T}, A::AnyCuArray; initialized=false, kwa
9861044
return ix
9871045
end
9881046

989-
function Base.sortperm(c::AnyCuArray; kwargs...)
1047+
function Base.sortperm(c::AnyCuVector; kwargs...)
9901048
sortperm!(CuArray(1:length(c)), c; initialized=true, kwargs...)
9911049
end
1050+
1051+
function Base.sortperm(c::AnyCuArray; dims, kwargs...)
1052+
# Base errors for Matrices without dims arg, we should too
1053+
sortperm!(reshape(CuArray(1:length(c)), size(c)), c; initialized=true, dims, kwargs...)
1054+
end

test/base/sorting.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,7 @@ end
315315

316316
# multiple dimensions
317317
@test check_sort!(Int32, (4, 50000, 4); dims=2)
318-
@test check_sort!(Int32, (4, 4, 50000); dims=3, rev=true)
318+
@test check_sort!(Int32, (2, 2, 50000); dims=3, rev=true)
319319

320320
# large sizes
321321
@test check_sort!(Float32, 2^25; alg=CUDA.QuickSort)
@@ -389,6 +389,13 @@ end
389389
@test check_sortperm(Float64, 1000000; rev=true)
390390
@test check_sortperm(Float64, 1000000; by=x->abs(x-0.5))
391391
@test check_sortperm(Float64, 1000000; rev=true, by=x->abs(x-0.5))
392+
393+
if VERSION >= v"1.9"
394+
# Base.jl didn't implement sortperm(;dims) until 1.9
395+
@test check_sortperm(Float32, (100_000, 16); dims=1)
396+
@test check_sortperm(Float32, (100_000, 16); dims=2)
397+
@test check_sortperm(Float32, (100, 256, 256); dims=1)
398+
end
392399
# check with Int32 indices
393400
@test check_sortperm!(collect(Int32(1):Int32(1000000)), Float32, 1000000)
394401
# `initialized` kwarg

0 commit comments

Comments
 (0)