Skip to content

Commit 70290d6

Browse files
Lilith HafnerLilith Hafner
authored andcommitted
fix unexpected allocations in Radix Sort
fixes #47474 in this PR rather than separate to avoid dealing with the merge
1 parent 05de36e commit 70290d6

File tree

1 file changed

+38
-18
lines changed

1 file changed

+38
-18
lines changed

base/sort.jl

Lines changed: 38 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -864,17 +864,30 @@ function _sort!(v::AbstractVector, a::RadixSort, o::DirectOrdering, kw)
864864

865865
len = hi-lo + 1
866866
U = UIntMappable(eltype(v), o)
867+
# A large if-else chain to avoid type instabilities and dynamic dispatch
867868
if scratch !== nothing && checkbounds(Bool, scratch, lo:hi) # Fully preallocated and aligned scratch
868-
u2 = radix_sort!(u, lo, hi, bits, reinterpret(U, scratch))
869-
uint_unmap!(v, u2, lo, hi, o, umn)
869+
t = reinterpret(U, scratch)
870+
if radix_sort!(u, lo, hi, bits, t)
871+
uint_unmap!(v, u, lo, hi, o, umn)
872+
else
873+
uint_unmap!(v, t, lo, hi, o, umn)
874+
end
870875
elseif scratch !== nothing && (applicable(resize!, scratch, len) || length(scratch) >= len) # Viable scratch
871876
length(scratch) >= len || resize!(scratch, len)
872877
t1 = axes(scratch, 1) isa OneTo ? scratch : view(scratch, firstindex(scratch):lastindex(scratch))
873-
u2 = radix_sort!(view(u, lo:hi), 1, len, bits, reinterpret(U, t1))
874-
uint_unmap!(view(v, lo:hi), u2, 1, len, o, umn)
878+
t = reinterpret(U, t1)
879+
if radix_sort!(view(u, lo:hi), 1, len, bits, t)
880+
uint_unmap!(view(v, lo:hi), view(u, lo:hi), 1, len, o, umn)
881+
else
882+
uint_unmap!(view(v, lo:hi), t, 1, len, o, umn)
883+
end
875884
else # No viable scratch
876-
u2 = radix_sort!(u, lo, hi, bits, similar(u))
877-
uint_unmap!(v, u2, lo, hi, o, umn)
885+
t = similar(u)
886+
if radix_sort!(u, lo, hi, bits, t)
887+
uint_unmap!(v, u, lo, hi, o, umn)
888+
else
889+
uint_unmap!(v, t, lo, hi, o, umn)
890+
end
878891
end
879892
end
880893

@@ -1025,16 +1038,28 @@ function _sort!(v::AbstractVector, a::StableCheckSorted, o::Ordering, kw)
10251038
end
10261039

10271040

1028-
# In the case of an odd number of passes, the returned vector will === the input vector t,
1029-
# not v. This is one of the many reasons radix_sort! is not exported.
1041+
# The return value indicates whether v is sorted (true) or t is sorted (false)
1042+
# This is one of the many reasons radix_sort! is not exported.
10301043
function radix_sort!(v::AbstractVector{U}, lo::Integer, hi::Integer, bits::Unsigned,
10311044
t::AbstractVector{U}, chunk_size=radix_chunk_size_heuristic(lo, hi, bits)) where U <: Unsigned
10321045
# bits is unsigned for performance reasons.
1033-
mask = UInt(1) << chunk_size - 1
1034-
counts = Vector{Int}(undef, mask+2)
1035-
1036-
@inbounds for shift in 0:chunk_size:bits-1
1037-
1046+
counts = Vector{Int}(undef, 1 << chunk_size + 1)
1047+
1048+
shift = 0
1049+
while true
1050+
@noinline radix_sort_pass!(t, lo, hi, counts, v, shift, chunk_size)
1051+
# the latest data resides in t
1052+
shift += chunk_size
1053+
shift < bits || return false
1054+
@noinline radix_sort_pass!(v, lo, hi, counts, t, shift, chunk_size)
1055+
# the latest data resides in v
1056+
shift += chunk_size
1057+
shift < bits || return true
1058+
end
1059+
end
1060+
function radix_sort_pass!(t, lo, hi, counts, v, shift, chunk_size)
1061+
mask = UInt(1) << chunk_size - 1 # mask is defined in pass so that the compiler
1062+
@inbounds begin # ↳ knows it's shape
10381063
# counts[2:mask+2] will store the number of elements that fall into each bucket.
10391064
# if chunk_size = 8, counts[2] is bucket 0x00 and counts[257] is bucket 0xff.
10401065
counts .= 0
@@ -1058,12 +1083,7 @@ function radix_sort!(v::AbstractVector{U}, lo::Integer, hi::Integer, bits::Unsig
10581083
t[j] = x # put the element where it belongs
10591084
counts[i] = j + 1 # increment the target index for the next
10601085
end # ↳ element in this bucket
1061-
1062-
v, t = t, v # swap the now sorted destination vector t back into primary vector v
1063-
10641086
end
1065-
1066-
v
10671087
end
10681088
function radix_chunk_size_heuristic(lo::Integer, hi::Integer, bits::Unsigned)
10691089
# chunk_size is the number of bits to radix over at once.

0 commit comments

Comments
 (0)