Skip to content

Commit 6248170

Browse files
authored
Make searchsorted*/findnext/findprev return values of keytype (#32978)
1 parent 0bab06f commit 6248170

File tree

10 files changed

+108
-23
lines changed

10 files changed

+108
-23
lines changed

base/array.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1669,7 +1669,7 @@ CartesianIndex(2, 1)
16691669
"""
16701670
function findnext(A, start)
16711671
l = last(keys(A))
1672-
i = start
1672+
i = oftype(l, start)
16731673
i > l && return nothing
16741674
while true
16751675
A[i] && return i
@@ -1751,7 +1751,7 @@ CartesianIndex(1, 1)
17511751
"""
17521752
function findnext(testf::Function, A, start)
17531753
l = last(keys(A))
1754-
i = start
1754+
i = oftype(l, start)
17551755
i > l && return nothing
17561756
while true
17571757
testf(A[i]) && return i
@@ -1855,8 +1855,8 @@ CartesianIndex(2, 1)
18551855
```
18561856
"""
18571857
function findprev(A, start)
1858-
i = start
18591858
f = first(keys(A))
1859+
i = oftype(f, start)
18601860
i < f && return nothing
18611861
while true
18621862
A[i] && return i
@@ -1946,8 +1946,8 @@ CartesianIndex(2, 1)
19461946
```
19471947
"""
19481948
function findprev(testf::Function, A, start)
1949-
i = start
19501949
f = first(keys(A))
1950+
i = oftype(f, start)
19511951
i < f && return nothing
19521952
while true
19531953
testf(A[i]) && return i

base/bitarray.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1374,7 +1374,7 @@ end
13741374

13751375
count(B::BitArray) = bitcount(B.chunks)
13761376

1377-
function unsafe_bitfindnext(Bc::Vector{UInt64}, start::Integer)
1377+
function unsafe_bitfindnext(Bc::Vector{UInt64}, start::Int)
13781378
chunk_start = _div64(start-1)+1
13791379
within_chunk_start = _mod64(start-1)
13801380
mask = _msk64 << within_chunk_start
@@ -1397,13 +1397,14 @@ end
13971397
function findnext(B::BitArray, start::Integer)
13981398
start > 0 || throw(BoundsError(B, start))
13991399
start > length(B) && return nothing
1400-
unsafe_bitfindnext(B.chunks, start)
1400+
unsafe_bitfindnext(B.chunks, Int(start))
14011401
end
14021402

14031403
#findfirst(B::BitArray) = findnext(B, 1) ## defined in array.jl
14041404

14051405
# aux function: same as findnext(~B, start), but performed without temporaries
14061406
function findnextnot(B::BitArray, start::Integer)
1407+
start = Int(start)
14071408
start > 0 || throw(BoundsError(B, start))
14081409
start > length(B) && return nothing
14091410

@@ -1458,7 +1459,7 @@ function findnext(testf::Function, B::BitArray, start::Integer)
14581459
end
14591460
#findfirst(testf::Function, B::BitArray) = findnext(testf, B, 1) ## defined in array.jl
14601461

1461-
function unsafe_bitfindprev(Bc::Vector{UInt64}, start::Integer)
1462+
function unsafe_bitfindprev(Bc::Vector{UInt64}, start::Int)
14621463
chunk_start = _div64(start-1)+1
14631464
mask = _msk_end(start)
14641465

@@ -1480,10 +1481,11 @@ end
14801481
function findprev(B::BitArray, start::Integer)
14811482
start > 0 || return nothing
14821483
start > length(B) && throw(BoundsError(B, start))
1483-
unsafe_bitfindprev(B.chunks, start)
1484+
unsafe_bitfindprev(B.chunks, Int(start))
14841485
end
14851486

14861487
function findprevnot(B::BitArray, start::Integer)
1488+
start = Int(start)
14871489
start > 0 || return nothing
14881490
start > length(B) && throw(BoundsError(B, start))
14891491

base/sort.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ using .Base: copymutable, LinearIndices, length, (:),
1010
AbstractVector, @inbounds, AbstractRange, @eval, @inline, Vector, @noinline,
1111
AbstractMatrix, AbstractUnitRange, isless, identity, eltype, >, <, <=, >=, |, +, -, *, !,
1212
extrema, sub_with_overflow, add_with_overflow, oneunit, div, getindex, setindex!,
13-
length, resize!, fill, Missing, require_one_based_indexing
13+
length, resize!, fill, Missing, require_one_based_indexing, keytype
1414

1515
using .Base: >>>, !==
1616

@@ -174,7 +174,7 @@ midpoint(lo::Integer, hi::Integer) = midpoint(promote(lo, hi)...)
174174

175175
# index of the first value of vector a that is greater than or equal to x;
176176
# returns length(v)+1 if x is greater than all values in v.
177-
function searchsortedfirst(v::AbstractVector, x, lo::T, hi::T, o::Ordering) where T<:Integer
177+
function searchsortedfirst(v::AbstractVector, x, lo::T, hi::T, o::Ordering)::keytype(v) where T<:Integer
178178
u = T(1)
179179
lo = lo - u
180180
hi = hi + u
@@ -191,7 +191,7 @@ end
191191

192192
# index of the last value of vector a that is less than or equal to x;
193193
# returns 0 if x is less than all values of v.
194-
function searchsortedlast(v::AbstractVector, x, lo::T, hi::T, o::Ordering) where T<:Integer
194+
function searchsortedlast(v::AbstractVector, x, lo::T, hi::T, o::Ordering)::keytype(v) where T<:Integer
195195
u = T(1)
196196
lo = lo - u
197197
hi = hi + u
@@ -209,7 +209,7 @@ end
209209
# returns the range of indices of v equal to x
210210
# if v does not contain x, returns a 0-length range
211211
# indicating the insertion point of x
212-
function searchsorted(v::AbstractVector, x, ilo::T, ihi::T, o::Ordering) where T<:Integer
212+
function searchsorted(v::AbstractVector, x, ilo::T, ihi::T, o::Ordering)::UnitRange{keytype(v)} where T<:Integer
213213
u = T(1)
214214
lo = ilo - u
215215
hi = ihi + u
@@ -228,7 +228,7 @@ function searchsorted(v::AbstractVector, x, ilo::T, ihi::T, o::Ordering) where T
228228
return (lo + 1) : (hi - 1)
229229
end
230230

231-
function searchsortedlast(a::AbstractRange{<:Real}, x::Real, o::DirectOrdering)
231+
function searchsortedlast(a::AbstractRange{<:Real}, x::Real, o::DirectOrdering)::keytype(a)
232232
require_one_based_indexing(a)
233233
if step(a) == 0
234234
lt(o, x, first(a)) ? 0 : length(a)
@@ -238,7 +238,7 @@ function searchsortedlast(a::AbstractRange{<:Real}, x::Real, o::DirectOrdering)
238238
end
239239
end
240240

241-
function searchsortedfirst(a::AbstractRange{<:Real}, x::Real, o::DirectOrdering)
241+
function searchsortedfirst(a::AbstractRange{<:Real}, x::Real, o::DirectOrdering)::keytype(a)
242242
require_one_based_indexing(a)
243243
if step(a) == 0
244244
lt(o, first(a), x) ? length(a) + 1 : 1
@@ -248,7 +248,7 @@ function searchsortedfirst(a::AbstractRange{<:Real}, x::Real, o::DirectOrdering)
248248
end
249249
end
250250

251-
function searchsortedlast(a::AbstractRange{<:Integer}, x::Real, o::DirectOrdering)
251+
function searchsortedlast(a::AbstractRange{<:Integer}, x::Real, o::DirectOrdering)::keytype(a)
252252
require_one_based_indexing(a)
253253
h = step(a)
254254
if h == 0
@@ -270,7 +270,7 @@ function searchsortedlast(a::AbstractRange{<:Integer}, x::Real, o::DirectOrderin
270270
end
271271
end
272272

273-
function searchsortedfirst(a::AbstractRange{<:Integer}, x::Real, o::DirectOrdering)
273+
function searchsortedfirst(a::AbstractRange{<:Integer}, x::Real, o::DirectOrdering)::keytype(a)
274274
require_one_based_indexing(a)
275275
h = step(a)
276276
if h == 0
@@ -285,14 +285,14 @@ function searchsortedfirst(a::AbstractRange{<:Integer}, x::Real, o::DirectOrderi
285285
lastindex(a) + 1
286286
else
287287
if o isa ForwardOrdering
288-
-fld(floor(Integer, -x) + first(a), h) + 1
288+
-fld(floor(Integer, -x) + Signed(first(a)), h) + 1
289289
else
290-
-fld(ceil(Integer, -x) + first(a), h) + 1
290+
-fld(ceil(Integer, -x) + Signed(first(a)), h) + 1
291291
end
292292
end
293293
end
294294

295-
function searchsortedfirst(a::AbstractRange{<:Integer}, x::Unsigned, o::DirectOrdering)
295+
function searchsortedfirst(a::AbstractRange{<:Integer}, x::Unsigned, o::DirectOrdering)::keytype(a)
296296
require_one_based_indexing(a)
297297
if lt(o, first(a), x)
298298
if step(a) == 0
@@ -305,7 +305,7 @@ function searchsortedfirst(a::AbstractRange{<:Integer}, x::Unsigned, o::DirectOr
305305
end
306306
end
307307

308-
function searchsortedlast(a::AbstractRange{<:Integer}, x::Unsigned, o::DirectOrdering)
308+
function searchsortedlast(a::AbstractRange{<:Integer}, x::Unsigned, o::DirectOrdering)::keytype(a)
309309
require_one_based_indexing(a)
310310
if lt(o, x, first(a))
311311
0

base/strings/search.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ findfirst(ch::AbstractChar, string::AbstractString) = findfirst(==(ch), string)
125125

126126
# AbstractString implementation of the generic findnext interface
127127
function findnext(testf::Function, s::AbstractString, i::Integer)
128+
i = Int(i)
128129
z = ncodeunits(s) + 1
129130
1 i  z || throw(BoundsError(s, i))
130131
@inbounds i == z || isvalid(s, i) || string_index_err(s, i)
@@ -272,7 +273,7 @@ julia> findnext("Lang", "JuliaLang", 2)
272273
6:9
273274
```
274275
"""
275-
findnext(t::AbstractString, s::AbstractString, i::Integer) = _search(s, t, i)
276+
findnext(t::AbstractString, s::AbstractString, i::Integer) = _search(s, t, Int(i))
276277

277278
"""
278279
findnext(ch::AbstractChar, string::AbstractString, start::Integer)
@@ -484,7 +485,7 @@ julia> findprev("Julia", "JuliaLang", 6)
484485
1:5
485486
```
486487
"""
487-
findprev(t::AbstractString, s::AbstractString, i::Integer) = _rsearch(s, t, i)
488+
findprev(t::AbstractString, s::AbstractString, i::Integer) = _rsearch(s, t, Int(i))
488489

489490
"""
490491
findprev(ch::AbstractChar, string::AbstractString, start::Integer)

stdlib/SparseArrays/test/sparse.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2468,6 +2468,22 @@ end
24682468
@test findnext(!iszero, z,i) == findnext(!iszero, z_sp,i)
24692469
@test findprev(!iszero, z,i) == findprev(!iszero, z_sp,i)
24702470
end
2471+
2472+
# issue 32568
2473+
for T = (UInt, BigInt)
2474+
@test findnext(!iszero, x_sp, T(4)) isa keytype(x_sp)
2475+
@test findnext(!iszero, x_sp, T(5)) isa keytype(x_sp)
2476+
@test findprev(!iszero, x_sp, T(5)) isa keytype(x_sp)
2477+
@test findprev(!iszero, x_sp, T(6)) isa keytype(x_sp)
2478+
@test findnext(iseven, x_sp, T(4)) isa keytype(x_sp)
2479+
@test findnext(iseven, x_sp, T(5)) isa keytype(x_sp)
2480+
@test findprev(iseven, x_sp, T(4)) isa keytype(x_sp)
2481+
@test findprev(iseven, x_sp, T(5)) isa keytype(x_sp)
2482+
@test findnext(!iszero, z_sp, T(4)) isa keytype(z_sp)
2483+
@test findnext(!iszero, z_sp, T(5)) isa keytype(z_sp)
2484+
@test findprev(!iszero, z_sp, T(4)) isa keytype(z_sp)
2485+
@test findprev(!iszero, z_sp, T(5)) isa keytype(z_sp)
2486+
end
24712487
end
24722488

24732489
# #20711

test/arrayops.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -579,6 +579,18 @@ end
579579
@test findlast(isequal(0x00), [0x01, 0x00]) == 2
580580
@test findnext(isequal(0x00), [0x00, 0x01, 0x00], 2) == 3
581581
@test findprev(isequal(0x00), [0x00, 0x01, 0x00], 2) == 1
582+
583+
@testset "issue 32568" for T = (UInt, BigInt)
584+
@test findnext(!iszero, a, T(1)) isa keytype(a)
585+
@test findnext(!iszero, a, T(2)) isa keytype(a)
586+
@test findprev(!iszero, a, T(4)) isa keytype(a)
587+
@test findprev(!iszero, a, T(5)) isa keytype(a)
588+
b = [true, false, true]
589+
@test findnext(b, T(2)) isa keytype(b)
590+
@test findnext(b, T(3)) isa keytype(b)
591+
@test findprev(b, T(1)) isa keytype(b)
592+
@test findprev(b, T(2)) isa keytype(b)
593+
end
582594
end
583595
@testset "find with Matrix" begin
584596
A = [1 2 0; 3 4 0]

test/bitarray.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1307,6 +1307,21 @@ timesofar("find")
13071307
@test_throws BoundsError findprev(x->true, b1, 11)
13081308
@test_throws BoundsError findnext(x->true, b1, -1)
13091309

1310+
@testset "issue 32568" for T = (UInt, BigInt)
1311+
for x = (1, 2)
1312+
@test findnext(evens, T(x)) isa keytype(evens)
1313+
@test findnext(iseven, evens, T(x)) isa keytype(evens)
1314+
@test findnext(isequal(true), evens, T(x)) isa keytype(evens)
1315+
@test findnext(isequal(false), evens, T(x)) isa keytype(evens)
1316+
end
1317+
for x = (3, 4)
1318+
@test findprev(evens, T(x)) isa keytype(evens)
1319+
@test findprev(iseven, evens, T(x)) isa keytype(evens)
1320+
@test findprev(isequal(true), evens, T(x)) isa keytype(evens)
1321+
@test findprev(isequal(false), evens, T(x)) isa keytype(evens)
1322+
end
1323+
end
1324+
13101325
for l = [1, 63, 64, 65, 127, 128, 129]
13111326
f = falses(l)
13121327
t = trues(l)

test/sorting.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,20 @@ end
141141
@test searchsortedlast(500:1.0:600, -1.0e20) == 0
142142
@test searchsortedlast(500:1.0:600, 1.0e20) == 101
143143
end
144+
145+
@testset "issue 32568" begin
146+
for R in numTypes, T in numTypes
147+
for arr in [R[1:5;], R(1):R(5), R(1):2:R(5)]
148+
@test eltype(searchsorted(arr, T(2))) == keytype(arr)
149+
@test eltype(searchsorted(arr, T(2), big(1), big(4), Forward)) == keytype(arr)
150+
@test searchsortedfirst(arr, T(2)) isa keytype(arr)
151+
@test searchsortedfirst(arr, T(2), big(1), big(4), Forward) isa keytype(arr)
152+
@test searchsortedlast(arr, T(2)) isa keytype(arr)
153+
@test searchsortedlast(arr, T(2), big(1), big(4), Forward) isa keytype(arr)
154+
end
155+
end
156+
end
157+
144158
@testset "issue #34157" begin
145159
@test searchsorted(1:2.0, -Inf) === 1:0
146160
@test searchsorted([1,2], -Inf) === 1:0
@@ -173,7 +187,6 @@ end
173187
@test searchsortedlast(reverse(coll), -huge, rev=true) === lastindex(coll)
174188
@test searchsorted(reverse(coll), huge, rev=true) === firstindex(coll):firstindex(coll) - 1
175189
@test searchsorted(reverse(coll), -huge, rev=true) === lastindex(coll)+1:lastindex(coll)
176-
177190
end
178191
end
179192
end

test/strings/search.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,3 +389,22 @@ s_18109 = "fooα🐨βcd3"
389389
@test findall("aa", "aaaaaa") == [1:2, 3:4, 5:6]
390390
@test findall("aa", "aaaaaa", overlap=true) == [1:2, 2:3, 3:4, 4:5, 5:6]
391391
end
392+
393+
# issue 32568
394+
for T = (UInt, BigInt)
395+
for x = (4, 5)
396+
@test eltype(findnext(r"l", astr, T(x))) == Int
397+
@test findnext(isequal('l'), astr, T(x)) isa Int
398+
@test findprev(isequal('l'), astr, T(x)) isa Int
399+
@test findnext('l', astr, T(x)) isa Int
400+
@test findprev('l', astr, T(x)) isa Int
401+
end
402+
for x = (5, 6)
403+
@test eltype(findprev(",b", "foo,bar,baz", T(x))) == Int
404+
end
405+
for x = (7, 8)
406+
@test eltype(findnext(",b", "foo,bar,baz", T(x))) == Int
407+
@test findnext(isletter, astr, T(x)) isa Int
408+
@test findprev(isletter, astr, T(x)) isa Int
409+
end
410+
end

test/tuple.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -479,6 +479,13 @@ end
479479
@test findprev(isequal(1), (1, 1), 1) == 1
480480
@test findnext(isequal(1), (2, 3), 1) === nothing
481481
@test findprev(isequal(1), (2, 3), 2) === nothing
482+
483+
@testset "issue 32568" begin
484+
@test findnext(isequal(1), (1, 2), big(1)) isa Int
485+
@test findprev(isequal(1), (1, 2), big(2)) isa Int
486+
@test findnext(isequal(1), (1, 1), UInt(2)) isa Int
487+
@test findprev(isequal(1), (1, 1), UInt(1)) isa Int
488+
end
482489
end
483490

484491
@testset "properties" begin

0 commit comments

Comments
 (0)