Skip to content

Commit 6fb3558

Browse files
authored
Range indexing: error with scalar bool index like all other arrays (#31829)
1 parent 5fab42a commit 6fb3558

File tree

4 files changed

+289
-33
lines changed

4 files changed

+289
-33
lines changed

NEWS.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ Standard library changes
5151
* `escape_string` can now receive a collection of characters in the keyword
5252
`keep` that are to be kept as they are. ([#38597]).
5353
* `getindex` can now be used on `NamedTuple`s with multiple values ([#38878])
54+
* Subtypes of `AbstractRange` now correctly follow the general array indexing
55+
behavior when indexed by `Bool`s, erroring for scalar `Bool`s and treating
56+
arrays (including ranges) of `Bool` as an logical index ([#31829])
5457
* `keys(::RegexMatch)` is now defined to return the capture's keys, by name if named, or by index if not ([#37299]).
5558
* `keys(::Generator)` is now defined to return the iterator's keys ([#34678])
5659
* `RegexMatch` now iterate to give their captures. ([#34355]).

base/range.jl

Lines changed: 111 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -392,12 +392,19 @@ be 1.
392392
"""
393393
struct OneTo{T<:Integer} <: AbstractUnitRange{T}
394394
stop::T
395-
OneTo{T}(stop) where {T<:Integer} = new(max(zero(T), stop))
395+
function OneTo{T}(stop) where {T<:Integer}
396+
throwbool(r) = (@_noinline_meta; throw(ArgumentError("invalid index: $r of type Bool")))
397+
T === Bool && throwbool(stop)
398+
return new(max(zero(T), stop))
399+
end
400+
396401
function OneTo{T}(r::AbstractRange) where {T<:Integer}
397402
throwstart(r) = (@_noinline_meta; throw(ArgumentError("first element must be 1, got $(first(r))")))
398403
throwstep(r) = (@_noinline_meta; throw(ArgumentError("step must be 1, got $(step(r))")))
404+
throwbool(r) = (@_noinline_meta; throw(ArgumentError("invalid index: $r of type Bool")))
399405
first(r) == 1 || throwstart(r)
400406
step(r) == 1 || throwstep(r)
407+
T === Bool && throwbool(r)
401408
return new(max(zero(T), last(r)))
402409
end
403410
end
@@ -748,6 +755,7 @@ _in_unit_range(v::UnitRange, val, i::Integer) = i > 0 && val <= v.stop && val >=
748755

749756
function getindex(v::UnitRange{T}, i::Integer) where T
750757
@_inline_meta
758+
i isa Bool && throw(ArgumentError("invalid index: $i of type Bool"))
751759
val = convert(T, v.start + (i - 1))
752760
@boundscheck _in_unit_range(v, val, i) || throw_boundserror(v, i)
753761
val
@@ -758,19 +766,22 @@ const OverflowSafe = Union{Bool,Int8,Int16,Int32,Int64,Int128,
758766

759767
function getindex(v::UnitRange{T}, i::Integer) where {T<:OverflowSafe}
760768
@_inline_meta
769+
i isa Bool && throw(ArgumentError("invalid index: $i of type Bool"))
761770
val = v.start + (i - 1)
762771
@boundscheck _in_unit_range(v, val, i) || throw_boundserror(v, i)
763772
val % T
764773
end
765774

766775
function getindex(v::OneTo{T}, i::Integer) where T
767776
@_inline_meta
777+
i isa Bool && throw(ArgumentError("invalid index: $i of type Bool"))
768778
@boundscheck ((i > 0) & (i <= v.stop)) || throw_boundserror(v, i)
769779
convert(T, i)
770780
end
771781

772782
function getindex(v::AbstractRange{T}, i::Integer) where T
773783
@_inline_meta
784+
i isa Bool && throw(ArgumentError("invalid index: $i of type Bool"))
774785
ret = convert(T, first(v) + (i - 1)*step_hp(v))
775786
ok = ifelse(step(v) > zero(step(v)),
776787
(ret <= last(v)) & (ret >= first(v)),
@@ -781,22 +792,26 @@ end
781792

782793
function getindex(r::Union{StepRangeLen,LinRange}, i::Integer)
783794
@_inline_meta
795+
i isa Bool && throw(ArgumentError("invalid index: $i of type Bool"))
784796
@boundscheck checkbounds(r, i)
785797
unsafe_getindex(r, i)
786798
end
787799

788800
# This is separate to make it useful even when running with --check-bounds=yes
789801
function unsafe_getindex(r::StepRangeLen{T}, i::Integer) where T
802+
i isa Bool && throw(ArgumentError("invalid index: $i of type Bool"))
790803
u = i - r.offset
791804
T(r.ref + u*r.step)
792805
end
793806

794807
function _getindex_hiprec(r::StepRangeLen, i::Integer) # without rounding by T
808+
i isa Bool && throw(ArgumentError("invalid index: $i of type Bool"))
795809
u = i - r.offset
796810
r.ref + u*r.step
797811
end
798812

799813
function unsafe_getindex(r::LinRange, i::Integer)
814+
i isa Bool && throw(ArgumentError("invalid index: $i of type Bool"))
800815
lerpi(i-1, r.lendiv, r.start, r.stop)
801816
end
802817

@@ -808,12 +823,27 @@ end
808823

809824
getindex(r::AbstractRange, ::Colon) = copy(r)
810825

811-
function getindex(r::AbstractUnitRange, s::AbstractUnitRange{<:Integer})
826+
function getindex(r::AbstractUnitRange, s::AbstractUnitRange{T}) where {T<:Integer}
812827
@_inline_meta
813828
@boundscheck checkbounds(r, s)
814-
f = first(r)
815-
st = oftype(f, f + first(s)-1)
816-
range(st, length=length(s))
829+
830+
if T === Bool
831+
if length(s) == 0
832+
return r
833+
elseif length(s) == 1
834+
if first(s)
835+
return r
836+
else
837+
return range(r[1], length=0)
838+
end
839+
else # length(s) == 2
840+
return range(r[2], length=1)
841+
end
842+
else
843+
f = first(r)
844+
st = oftype(f, f + first(s)-1)
845+
return range(st, length=length(s))
846+
end
817847
end
818848

819849
function getindex(r::OneTo{T}, s::OneTo) where T
@@ -822,36 +852,96 @@ function getindex(r::OneTo{T}, s::OneTo) where T
822852
OneTo(T(s.stop))
823853
end
824854

825-
function getindex(r::AbstractUnitRange, s::StepRange{<:Integer})
855+
function getindex(r::AbstractUnitRange, s::StepRange{T}) where {T<:Integer}
826856
@_inline_meta
827857
@boundscheck checkbounds(r, s)
828-
st = oftype(first(r), first(r) + s.start-1)
829-
range(st, step=step(s), length=length(s))
858+
859+
if T === Bool
860+
if length(s) == 0
861+
return range(first(r), step=one(eltype(r)), length=0)
862+
elseif length(s) == 1
863+
if first(s)
864+
return range(first(r), step=one(eltype(r)), length=1)
865+
else
866+
return range(first(r), step=one(eltype(r)), length=0)
867+
end
868+
else # length(s) == 2
869+
return range(r[2], step=one(eltype(r)), length=1)
870+
end
871+
else
872+
st = oftype(first(r), first(r) + s.start-1)
873+
return range(st, step=step(s), length=length(s))
874+
end
830875
end
831876

832-
function getindex(r::StepRange, s::AbstractRange{<:Integer})
877+
function getindex(r::StepRange, s::AbstractRange{T}) where {T<:Integer}
833878
@_inline_meta
834879
@boundscheck checkbounds(r, s)
835-
st = oftype(r.start, r.start + (first(s)-1)*step(r))
836-
range(st, step=step(r)*step(s), length=length(s))
880+
881+
if T === Bool
882+
if length(s) == 0
883+
return range(first(r), step=step(r), length=0)
884+
elseif length(s) == 1
885+
if first(s)
886+
return range(first(r), step=step(r), length=1)
887+
else
888+
return range(first(r), step=step(r), length=0)
889+
end
890+
else # length(s) == 2
891+
return range(r[2], step=step(r), length=1)
892+
end
893+
else
894+
st = oftype(r.start, r.start + (first(s)-1)*step(r))
895+
return range(st, step=step(r)*step(s), length=length(s))
896+
end
837897
end
838898

839-
function getindex(r::StepRangeLen{T}, s::OrdinalRange{<:Integer}) where {T}
899+
function getindex(r::StepRangeLen{T}, s::OrdinalRange{S}) where {T, S<:Integer}
840900
@_inline_meta
841901
@boundscheck checkbounds(r, s)
842-
# Find closest approach to offset by s
843-
ind = LinearIndices(s)
844-
offset = max(min(1 + round(Int, (r.offset - first(s))/step(s)), last(ind)), first(ind))
845-
ref = _getindex_hiprec(r, first(s) + (offset-1)*step(s))
846-
return StepRangeLen{T}(ref, r.step*step(s), length(s), offset)
902+
903+
if S === Bool
904+
if length(s) == 0
905+
return StepRangeLen{T}(first(r), step(r), 0, 1)
906+
elseif length(s) == 1
907+
if first(s)
908+
return StepRangeLen{T}(first(r), step(r), 1, 1)
909+
else
910+
return StepRangeLen{T}(first(r), step(r), 0, 1)
911+
end
912+
else # length(s) == 2
913+
return StepRangeLen{T}(r[2], step(r), 1, 1)
914+
end
915+
else
916+
# Find closest approach to offset by s
917+
ind = LinearIndices(s)
918+
offset = max(min(1 + round(Int, (r.offset - first(s))/step(s)), last(ind)), first(ind))
919+
ref = _getindex_hiprec(r, first(s) + (offset-1)*step(s))
920+
return StepRangeLen{T}(ref, r.step*step(s), length(s), offset)
921+
end
847922
end
848923

849-
function getindex(r::LinRange{T}, s::OrdinalRange{<:Integer}) where {T}
924+
function getindex(r::LinRange{T}, s::OrdinalRange{S}) where {T, S<:Integer}
850925
@_inline_meta
851926
@boundscheck checkbounds(r, s)
852-
vfirst = unsafe_getindex(r, first(s))
853-
vlast = unsafe_getindex(r, last(s))
854-
return LinRange{T}(vfirst, vlast, length(s))
927+
928+
if S === Bool
929+
if length(s) == 0
930+
return LinRange(first(r), first(r), 0)
931+
elseif length(s) == 1
932+
if first(s)
933+
return LinRange(first(r), first(r), 1)
934+
else
935+
return LinRange(first(r), first(r), 0)
936+
end
937+
else # length(s) == 2
938+
return LinRange(r[2], r[2], 1)
939+
end
940+
else
941+
vfirst = unsafe_getindex(r, first(s))
942+
vlast = unsafe_getindex(r, last(s))
943+
return LinRange{T}(vfirst, vlast, length(s))
944+
end
855945
end
856946

857947
show(io::IO, r::AbstractRange) = print(io, repr(first(r)), ':', repr(step(r)), ':', repr(last(r)))

base/twiceprecision.jl

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -448,34 +448,50 @@ end
448448
function unsafe_getindex(r::StepRangeLen{T,<:TwicePrecision,<:TwicePrecision}, i::Integer) where T
449449
# Very similar to _getindex_hiprec, but optimized to avoid a 2nd call to add12
450450
@_inline_meta
451+
i isa Bool && throw(ArgumentError("invalid index: $i of type Bool"))
451452
u = i - r.offset
452453
shift_hi, shift_lo = u*r.step.hi, u*r.step.lo
453454
x_hi, x_lo = add12(r.ref.hi, shift_hi)
454455
T(x_hi + (x_lo + (shift_lo + r.ref.lo)))
455456
end
456457

457458
function _getindex_hiprec(r::StepRangeLen{<:Any,<:TwicePrecision,<:TwicePrecision}, i::Integer)
459+
i isa Bool && throw(ArgumentError("invalid index: $i of type Bool"))
458460
u = i - r.offset
459461
shift_hi, shift_lo = u*r.step.hi, u*r.step.lo
460462
x_hi, x_lo = add12(r.ref.hi, shift_hi)
461463
x_hi, x_lo = add12(x_hi, x_lo + (shift_lo + r.ref.lo))
462464
TwicePrecision(x_hi, x_lo)
463465
end
464466

465-
function getindex(r::StepRangeLen{T,<:TwicePrecision,<:TwicePrecision}, s::OrdinalRange{<:Integer}) where T
467+
function getindex(r::StepRangeLen{T,<:TwicePrecision,<:TwicePrecision}, s::OrdinalRange{S}) where {T, S<:Integer}
466468
@boundscheck checkbounds(r, s)
467-
soffset = 1 + round(Int, (r.offset - first(s))/step(s))
468-
soffset = clamp(soffset, 1, length(s))
469-
ioffset = first(s) + (soffset-1)*step(s)
470-
if step(s) == 1 || length(s) < 2
471-
newstep = r.step
472-
else
473-
newstep = twiceprecision(r.step*step(s), nbitslen(T, length(s), soffset))
474-
end
475-
if ioffset == r.offset
476-
StepRangeLen(r.ref, newstep, length(s), max(1,soffset))
469+
if S === Bool
470+
if length(s) == 0
471+
return StepRangeLen(r.ref, r.step, 0, 1)
472+
elseif length(s) == 1
473+
if first(s)
474+
return StepRangeLen(r.ref, r.step, 1, 1)
475+
else
476+
return StepRangeLen(r.ref, r.step, 0, 1)
477+
end
478+
else # length(s) == 2
479+
return StepRangeLen(r[2], step(r), 1, 1)
480+
end
477481
else
478-
StepRangeLen(r.ref + (ioffset-r.offset)*r.step, newstep, length(s), max(1,soffset))
482+
soffset = 1 + round(Int, (r.offset - first(s))/step(s))
483+
soffset = clamp(soffset, 1, length(s))
484+
ioffset = first(s) + (soffset-1)*step(s)
485+
if step(s) == 1 || length(s) < 2
486+
newstep = r.step
487+
else
488+
newstep = twiceprecision(r.step*step(s), nbitslen(T, length(s), soffset))
489+
end
490+
if ioffset == r.offset
491+
return StepRangeLen(r.ref, newstep, length(s), max(1,soffset))
492+
else
493+
return StepRangeLen(r.ref + (ioffset-r.offset)*r.step, newstep, length(s), max(1,soffset))
494+
end
479495
end
480496
end
481497

0 commit comments

Comments
 (0)