Skip to content

Commit 4a8ea8c

Browse files
authored
Sign-aware computation of midpoint for sorting (fixes #33977) (#34106)
1 parent 6f76d16 commit 4a8ea8c

File tree

3 files changed

+37
-5
lines changed

3 files changed

+37
-5
lines changed

base/sort.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,10 @@ same thing as `partialsort!` but leaving `v` unmodified.
164164
partialsort(v::AbstractVector, k::Union{Integer,OrdinalRange}; kws...) =
165165
partialsort!(copymutable(v), k; kws...)
166166

167+
# This implementation of `midpoint` is performance-optimized but safe
168+
# only if `lo <= hi`.
169+
midpoint(lo::T, hi::T) where T<:Integer = lo + ((hi - lo) >>> 0x01)
170+
midpoint(lo::Integer, hi::Integer) = midpoint(promote(lo, hi)...)
167171

168172
# reference on sorted binary search:
169173
# http://www.tbray.org/ongoing/When/200x/2003/03/22/Binary
@@ -175,7 +179,7 @@ function searchsortedfirst(v::AbstractVector, x, lo::T, hi::T, o::Ordering) wher
175179
lo = lo - u
176180
hi = hi + u
177181
@inbounds while lo < hi - u
178-
m = (lo + hi) >>> 1
182+
m = midpoint(lo, hi)
179183
if lt(o, v[m], x)
180184
lo = m
181185
else
@@ -192,7 +196,7 @@ function searchsortedlast(v::AbstractVector, x, lo::T, hi::T, o::Ordering) where
192196
lo = lo - u
193197
hi = hi + u
194198
@inbounds while lo < hi - u
195-
m = (lo + hi) >>> 1
199+
m = midpoint(lo, hi)
196200
if lt(o, x, v[m])
197201
hi = m
198202
else
@@ -210,7 +214,7 @@ function searchsorted(v::AbstractVector, x, ilo::T, ihi::T, o::Ordering) where T
210214
lo = ilo - u
211215
hi = ihi + u
212216
@inbounds while lo < hi - u
213-
m = (lo + hi) >>> 1
217+
m = midpoint(lo, hi)
214218
if lt(o, v[m], x)
215219
lo = m
216220
elseif lt(o, x, v[m])
@@ -487,7 +491,7 @@ end
487491

488492
@inline function selectpivot!(v::AbstractVector, lo::Integer, hi::Integer, o::Ordering)
489493
@inbounds begin
490-
mi = (lo+hi)>>>1
494+
mi = midpoint(lo, hi)
491495

492496
# sort v[mi] <= v[lo] <= v[hi] such that the pivot is immediately in place
493497
if lt(o, v[lo], v[mi])
@@ -552,7 +556,7 @@ function sort!(v::AbstractVector, lo::Integer, hi::Integer, a::MergeSortAlg, o::
552556
@inbounds if lo < hi
553557
hi-lo <= SMALL_THRESHOLD && return sort!(v, lo, hi, SMALL_ALGORITHM, o)
554558

555-
m = (lo+hi)>>>1
559+
m = midpoint(lo, hi)
556560
(length(t) < m-lo+1) && resize!(t, m-lo+1)
557561

558562
sort!(v, lo, m, a, o, t)

test/offsetarray.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,19 @@ v = OffsetArray(rand(8), (-2,))
475475
@test sortslices(A, dims=2) == OffsetArray(sortslices(parent(A), dims=2), A.offsets)
476476
@test sort(A, dims=1) == OffsetArray(sort(parent(A), dims=1), A.offsets)
477477
@test sort(A, dims=2) == OffsetArray(sort(parent(A), dims=2), A.offsets)
478+
# Issue #33977
479+
soa = OffsetArray([2,2,3], -2)
480+
@test searchsorted(soa, 1) == -1:-2
481+
@test searchsortedfirst(soa, 1) == -1
482+
@test searchsortedlast(soa, 1) == -2
483+
@test first(sort!(soa; alg=QuickSort)) == 2
484+
@test first(sort!(soa; alg=MergeSort)) == 2
485+
soa = OffsetArray([2,2,3], typemax(Int)-4)
486+
@test searchsorted(soa, 1) == typemax(Int)-3:typemax(Int)-4
487+
@test searchsortedfirst(soa, 2) == typemax(Int) - 3
488+
@test searchsortedlast(soa, 2) == typemax(Int) - 2
489+
@test first(sort!(soa; alg=QuickSort)) == 2
490+
@test first(sort!(soa; alg=MergeSort)) == 2
478491

479492
@test mapslices(sort, A, dims=1) == OffsetArray(mapslices(sort, parent(A), dims=1), A.offsets)
480493
@test mapslices(sort, A, dims=2) == OffsetArray(mapslices(sort, parent(A), dims=2), A.offsets)

test/sorting.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,21 @@ using Test
1111
@test ReverseOrdering(Forward) == ReverseOrdering() == Reverse
1212
end
1313

14+
@testset "midpoint" begin
15+
@test Base.Sort.midpoint(1, 3) === 2
16+
@test Base.Sort.midpoint(2, 4) === 3
17+
@test 2 <= Base.Sort.midpoint(1, 4) <= 3
18+
@test Base.Sort.midpoint(-3, -1) === -2
19+
@test Base.Sort.midpoint(-4, -2) === -3
20+
@test -3 <= Base.Sort.midpoint(-4, -1) <= -2
21+
@test Base.Sort.midpoint(-1, 1) === 0
22+
@test -1 <= Base.Sort.midpoint(-2, 1) <= 0
23+
@test 0 <= Base.Sort.midpoint(-1, 2) <= 1
24+
@test Base.Sort.midpoint(-2, 2) === 0
25+
@test Base.Sort.midpoint(typemax(Int)-2, typemax(Int)) === typemax(Int)-1
26+
@test Base.Sort.midpoint(typemin(Int), typemin(Int)+2) === typemin(Int)+1
27+
@test -1 <= Base.Sort.midpoint(typemin(Int), typemax(Int)) <= 0
28+
end
1429

1530
@testset "sort" begin
1631
@test sort([2,3,1]) == [1,2,3] == sort([2,3,1]; order=Forward)

0 commit comments

Comments
 (0)