Skip to content

Commit 897c8e6

Browse files
authored
Add intersect(::AbstractRange, ::AbstractVector) (#41769)
1 parent d241297 commit 897c8e6

File tree

7 files changed

+80
-12
lines changed

7 files changed

+80
-12
lines changed

NEWS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ Standard library changes
5050
arithmetic to error if the result may be wrapping. Or use a package such as SaferIntegers.jl when
5151
constructing the range. ([#40382])
5252
* TCP socket objects now expose `closewrite` functionality and support half-open mode usage ([#40783]).
53+
* Intersect returns a result with the eltype of the type-promoted eltypes of the two inputs ([#41769]).
5354

5455
#### InteractiveUtils
5556
* A new macro `@time_imports` for reporting any time spent importing packages and their dependencies ([#41612])

base/abstractset.jl

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,6 @@ Set{Int64} with 3 elements:
4545
"""
4646
function union end
4747

48-
_in(itr) = x -> x in itr
49-
5048
union(s, sets...) = union!(emptymutable(s, promote_eltype(s, sets...)), s, sets...)
5149
union(s::AbstractSet) = copy(s)
5250

@@ -109,6 +107,10 @@ Maintain order with arrays.
109107
110108
See also: [`setdiff`](@ref), [`isdisjoint`](@ref), [`issubset`](@ref Base.issubset), [`issetequal`](@ref).
111109
110+
!!! compat "Julia 1.8"
111+
As of Julia 1.8 intersect returns a result with the eltype of the
112+
type-promoted eltypes of the two inputs
113+
112114
# Examples
113115
```jldoctest
114116
julia> intersect([1, 2, 3], [3, 4, 5])
@@ -125,9 +127,12 @@ Set{Int64} with 1 element:
125127
2
126128
```
127129
"""
128-
intersect(s::AbstractSet, itr, itrs...) = intersect!(intersect(s, itr), itrs...)
130+
function intersect(s::AbstractSet, itr, itrs...)
131+
T = promote_eltype(s, itr, itrs...)
132+
return intersect!(Set{T}(s), itr, itrs...)
133+
end
129134
intersect(s) = union(s)
130-
intersect(s::AbstractSet, itr) = mapfilter(_in(s), push!, itr, emptymutable(s))
135+
intersect(s::AbstractSet, itr) = mapfilter(in(s), push!, itr, emptymutable(s, promote_eltype(s, itr)))
131136

132137
const = intersect
133138

@@ -143,7 +148,7 @@ function intersect!(s::AbstractSet, itrs...)
143148
end
144149
return s
145150
end
146-
intersect!(s::AbstractSet, s2::AbstractSet) = filter!(_in(s2), s)
151+
intersect!(s::AbstractSet, s2::AbstractSet) = filter!(in(s2), s)
147152
intersect!(s::AbstractSet, itr) =
148153
intersect!(s, union!(emptymutable(s, eltype(itr)), itr))
149154

base/array.jl

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2603,19 +2603,27 @@ function _shrink!(shrinker!, v::AbstractVector, itrs)
26032603
seen = Set{eltype(v)}()
26042604
filter!(_grow_filter!(seen), v)
26052605
shrinker!(seen, itrs...)
2606-
filter!(_in(seen), v)
2606+
filter!(in(seen), v)
26072607
end
26082608

26092609
intersect!(v::AbstractVector, itrs...) = _shrink!(intersect!, v, itrs)
26102610
setdiff!( v::AbstractVector, itrs...) = _shrink!(setdiff!, v, itrs)
26112611

2612-
vectorfilter(f, v::AbstractVector) = filter(f, v) # TODO: do we want this special case?
2613-
vectorfilter(f, v) = [x for x in v if f(x)]
2612+
vectorfilter(T::Type, f, v) = T[x for x in v if f(x)]
26142613

26152614
function _shrink(shrinker!, itr, itrs)
2616-
keep = shrinker!(Set(itr), itrs...)
2617-
vectorfilter(_shrink_filter!(keep), itr)
2615+
T = promote_eltype(itr, itrs...)
2616+
keep = shrinker!(Set{T}(itr), itrs...)
2617+
vectorfilter(T, _shrink_filter!(keep), itr)
26182618
end
26192619

26202620
intersect(itr, itrs...) = _shrink(intersect!, itr, itrs)
26212621
setdiff( itr, itrs...) = _shrink(setdiff!, itr, itrs)
2622+
2623+
function intersect(v::AbstractVector, r::AbstractRange)
2624+
T = promote_eltype(v, r)
2625+
common = Iterators.filter(in(r), v)
2626+
seen = Set{T}(common)
2627+
return vectorfilter(T, _shrink_filter!(seen), common)
2628+
end
2629+
intersect(r::AbstractRange, v::AbstractVector) = intersect(v, r)

base/range.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1181,6 +1181,16 @@ function intersect(r::StepRange, s::StepRange)
11811181
step(r) < zero(step(r)) ? StepRange{T,S}(n, -a, m) : StepRange{T,S}(m, a, n)
11821182
end
11831183

1184+
function intersect(r1::AbstractRange, r2::AbstractRange)
1185+
# To iterate over the shorter range
1186+
length(r1) > length(r2) && return intersect(r2, r1)
1187+
1188+
r1 = unique(r1)
1189+
T = promote_eltype(r1, r2)
1190+
1191+
return T[x for x in r1 if x r2]
1192+
end
1193+
11841194
function intersect(r1::AbstractRange, r2::AbstractRange, r3::AbstractRange, r::AbstractRange...)
11851195
i = intersect(intersect(r1, r2), r3)
11861196
for t in r

test/arrayops.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1106,6 +1106,11 @@ end
11061106
@test isequal(intersect([1,2,3], Float64[]), Float64[])
11071107
@test isequal(intersect(Int64[], [1,2,3]), Int64[])
11081108
@test isequal(intersect(Int64[]), Int64[])
1109+
@test isequal(intersect([1, 3], 1:typemax(Int)), [1, 3])
1110+
@test isequal(intersect(1:typemax(Int), [1, 3]), [1, 3])
1111+
@test isequal(intersect([1, 2, 3], 2:0.1:5), [2., 3.])
1112+
@test isequal(intersect([1.0, 2.0, 3.0], 2:5), [2., 3.])
1113+
11091114
@test isequal(setdiff([1,2,3,4], [2,5,4]), [1,3])
11101115
@test isequal(setdiff([1,2,3,4], [7,8,9]), [1,2,3,4])
11111116
@test isequal(setdiff([1,2,3,4], Int64[]), Int64[1,2,3,4])

test/ranges.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,9 @@ end
417417
@test intersect(1:3, 2) === intersect(2, 1:3) === 2:2
418418
@test intersect(1.0:3.0, 2) == intersect(2, 1.0:3.0) == [2.0]
419419

420+
@test intersect(1:typemax(Int), [1, 3]) == [1, 3]
421+
@test intersect([1, 3], 1:typemax(Int)) == [1, 3]
422+
420423
@testset "Support StepRange with a non-numeric step" begin
421424
start = Date(1914, 7, 28)
422425
stop = Date(1918, 11, 11)
@@ -426,6 +429,21 @@ end
426429
@test intersect(start-Day(10):Day(1):stop-Day(10), start:Day(5):stop) ==
427430
start:Day(5):stop-Day(10)-mod(stop-start, Day(5))
428431
end
432+
433+
@testset "Two AbstractRanges" begin
434+
struct DummyRange{T} <: AbstractRange{T}
435+
r
436+
end
437+
Base.iterate(dr::DummyRange) = iterate(dr.r)
438+
Base.iterate(dr::DummyRange, state) = iterate(dr.r, state)
439+
Base.length(dr::DummyRange) = length(dr.r)
440+
Base.in(x::Int, dr::DummyRange) = in(x, dr.r)
441+
Base.unique(dr::DummyRange) = unique(dr.r)
442+
r1 = DummyRange{Int}([1, 2, 3, 3, 4, 5])
443+
r2 = DummyRange{Int}([3, 3, 4, 5, 6])
444+
@test intersect(r1, r2) == [3, 4, 5]
445+
@test intersect(r2, r1) == [3, 4, 5]
446+
end
429447
end
430448
@testset "issubset" begin
431449
@test issubset(1:3, 1:typemax(Int)) #32461

test/sets.jl

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,16 @@ end
220220
s2 = Set([nothing])
221221
union!(s2, [nothing])
222222
@test s2 == Set([nothing])
223+
224+
@testset "promotion" begin
225+
ints = [1:5, [1, 2], Set([1, 2])]
226+
floats = [2:0.1:3, [2.0, 3.5], Set([2.0, 3.5])]
227+
228+
for a in ints, b in floats
229+
@test eltype(union(a, b)) == Float64
230+
@test eltype(union(b, a)) == Float64
231+
end
232+
end
223233
end
224234

225235
@testset "intersect" begin
@@ -238,7 +248,7 @@ end
238248
end
239249
end
240250
@test intersect(Set([1]), BitSet()) isa Set{Int}
241-
@test intersect(BitSet([1]), Set()) isa BitSet
251+
@test intersect(BitSet([1]), Set()) isa Set{Any}
242252
@test intersect([1], BitSet()) isa Vector{Int}
243253
# intersect must uniquify
244254
@test intersect([1, 2, 1]) == intersect!([1, 2, 1]) == [1, 2]
@@ -249,7 +259,18 @@ end
249259
y = () (42,)
250260
@test isempty(x)
251261
@test isempty(y)
252-
@test eltype(x) == eltype(y) == Union{}
262+
263+
# Discussed in PR#41769
264+
@testset "promotion" begin
265+
ints = [1:5, [1, 2], Set([1, 2])]
266+
floats = [2:0.1:3, [2.0, 3.5], Set([2.0, 3.5])]
267+
268+
for a in ints, b in floats
269+
@test eltype(intersect(a, b)) == Float64
270+
@test eltype(intersect(b, a)) == Float64
271+
@test eltype(intersect(a, a, b)) == Float64
272+
end
273+
end
253274
end
254275

255276
@testset "setdiff" begin

0 commit comments

Comments
 (0)