Skip to content

Commit 71e8fa1

Browse files
committed
Introduce PermUnstable to speed up sortperm
1 parent 32a6f54 commit 71e8fa1

File tree

2 files changed

+35
-10
lines changed

2 files changed

+35
-10
lines changed

base/ordering.jl

Lines changed: 20 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@ import .Base:
1313

1414
export # not exported by Base
1515
Ordering, Forward, Reverse,
16-
By, Lt, Perm,
16+
By, Lt, Perm, PermUnstable,
1717
ReverseOrdering, ForwardOrdering,
1818
DirectOrdering,
1919
lt, ord, ordtype
@@ -106,8 +106,21 @@ struct Perm{O<:Ordering,V<:AbstractVector} <: Ordering
106106
data::V
107107
end
108108

"""
    PermUnstable(order::Ordering, data::AbstractVector)

`Ordering` on the indices of `data`: index `i` is less than index `j` exactly when
`data[i]` is less than `data[j]` according to `order`.

No tie-break on the indices themselves is performed. Two indices whose data compare as
equivalent are therefore neither less than the other, and their relative position in a
sorted result is left to the sorting algorithm. Skipping the tie-break is what is meant
to make this cheaper than `Perm` when the algorithm in use is already stable.
"""
struct PermUnstable{O <: Ordering, V <: AbstractVector} <: Ordering
    order::O    # ordering applied to the elements of `data`
    data::V     # vector whose indices are being compared
end
120+
109121
# Reversing a `By` keeps the key function and reverses the ordering it wraps.
function ReverseOrdering(by::By)
    reversed = ReverseOrdering(by.order)
    return By(by.by, reversed)
end
110122
# Reversing a `Perm` reverses the element ordering and keeps the same data vector.
function ReverseOrdering(perm::Perm)
    return Perm(ReverseOrdering(perm.order), perm.data)
end
123+
# Reversing a `PermUnstable` reverses the element ordering and keeps the same data vector.
function ReverseOrdering(perm::PermUnstable)
    return PermUnstable(ReverseOrdering(perm.order), perm.data)
end
111124

112125
"""
113126
lt(o::Ordering, a, b)
@@ -125,6 +138,12 @@ lt(o::Lt, a, b) = o.lt(a,b)
125138
(lt(p.order, da, db)::Bool) | (!(lt(p.order, db, da)::Bool) & (a < b))
126139
end
127140

141+
# Compare indices `a` and `b` solely by the data they refer to. There is no tie-break on
# the index values, so equivalent data compares as "not less" in both directions.
@propagate_inbounds function lt(p::PermUnstable, a::Integer, b::Integer)
    return lt(p.order, p.data[a], p.data[b])::Bool
end
146+
128147
# With the default `isless` comparison and `identity` key, `order` is used unchanged.
function _ord(lt::typeof(isless), by::typeof(identity), order::Ordering)
    return order
end
129148
# With the default `isless` comparison and a custom key `by`, wrap `order` in `By`.
function _ord(lt::typeof(isless), by, order::Ordering)
    return By(by, order)
end
130149

base/sort.jl

Lines changed: 15 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -573,9 +573,11 @@ function _sort!(v::AbstractVector, a::MissingOptimization, o::Ordering, kw)
573573
if nonmissingtype(eltype(v)) != eltype(v) && o isa DirectOrdering
574574
lo, hi = send_to_end!(ismissing, v, o; lo, hi)
575575
_sort!(WithoutMissingVector(v, unsafe=true), a.next, o, (;kw..., lo, hi))
576-
elseif eltype(v) <: Integer && o isa Perm{DirectOrdering} && nonmissingtype(eltype(o.data)) != eltype(o.data)
576+
elseif eltype(v) <: Integer &&
        (o isa Perm{DirectOrdering} || o isa PermUnstable{DirectOrdering}) &&
        nonmissingtype(eltype(o.data)) != eltype(o.data)
    # The `o isa ...` test must come before any read of `o.data`: this clause is also
    # evaluated for orderings that have no `data` field, where the access would throw.
578+
PermT = o isa Perm{DirectOrdering} ? Perm : PermUnstable
577579
lo, hi = send_to_end!(i -> ismissing(@inbounds o.data[i]), v, o)
578-
_sort!(v, a.next, Perm(o.order, WithoutMissingVector(o.data, unsafe=true)), (;kw..., lo, hi))
580+
_sort!(v, a.next, PermT(o.order, WithoutMissingVector(o.data, unsafe=true)), (;kw..., lo, hi))
579581
else
580582
_sort!(v, a.next, o, kw)
581583
end
@@ -614,15 +616,17 @@ function _sort!(v::AbstractVector, a::IEEEFloatOptimization, o::Ordering, kw)
614616
else
615617
_sort!(iv, a.next, Forward, (;kw..., lo=j+1, hi, scratch))
616618
end
617-
elseif eltype(v) <: Integer && o isa Perm && o.order isa DirectOrdering && is_concrete_IEEEFloat(eltype(o.data))
619+
elseif eltype(v) <: Integer && (o isa Perm || o isa PermUnstable) &&
620+
o.order isa DirectOrdering && is_concrete_IEEEFloat(eltype(o.data))
618621
lo, hi = send_to_end!(i -> isnan(@inbounds o.data[i]), v, o.order, true; lo, hi)
619622
ip = reinterpret(UIntType(eltype(o.data)), o.data)
620623
j = send_to_end!(i -> after_zero(o.order, @inbounds o.data[i]), v; lo, hi)
621-
scratch = _sort!(v, a.next, Perm(Reverse, ip), (;kw..., lo, hi=j))
624+
PermT = o isa Perm ? Perm : PermUnstable
625+
scratch = _sort!(v, a.next, PermT(Reverse, ip), (;kw..., lo, hi=j))
622626
if scratch === nothing # Union split
623-
_sort!(v, a.next, Perm(Forward, ip), (;kw..., lo=j+1, hi, scratch))
627+
_sort!(v, a.next, PermT(Forward, ip), (;kw..., lo=j+1, hi, scratch))
624628
else
625-
_sort!(v, a.next, Perm(Forward, ip), (;kw..., lo=j+1, hi, scratch))
629+
_sort!(v, a.next, PermT(Forward, ip), (;kw..., lo=j+1, hi, scratch))
626630
end
627631
else
628632
_sort!(v, a.next, o, kw)
@@ -1482,7 +1486,7 @@ function partialsortperm!(ix::AbstractVector{<:Integer}, v::AbstractVector,
14821486
end
14831487

14841488
# do partial quicksort
1485-
_sort!(ix, _PartialQuickSort(k), Perm(ord(lt, by, rev, order), v), (;))
1489+
_sort!(ix, _PartialQuickSort(k), PermUnstable(ord(lt, by, rev, order), v), (;))
14861490

14871491
maybeview(ix, k)
14881492
end
@@ -1554,7 +1558,8 @@ function sortperm(A::AbstractArray;
15541558
end
15551559
end
15561560
ix = copymutable(LinearIndices(A))
1557-
sort!(ix; alg, order = Perm(ordr, vec(A)), scratch, dims...)
1561+
PermT = alg == DEFAULT_STABLE ? PermUnstable : Perm
1562+
sort!(ix; alg, order = PermT(ordr, vec(A)), scratch, dims...)
15581563
end
15591564

15601565

@@ -1608,7 +1613,8 @@ function sortperm!(ix::AbstractArray{T}, A::AbstractArray;
16081613
if !initialized
16091614
ix .= LinearIndices(A)
16101615
end
1611-
sort!(ix; alg, order = Perm(ord(lt, by, rev, order), vec(A)), scratch, dims...)
1616+
# A caller-supplied `ix` (`initialized == true`) need not be in ascending order; without
# the index tie-break of `Perm`, tied indices would stay in the caller's order. Only use
# `PermUnstable` when `ix` has just been filled from `LinearIndices(A)`.
PermT = (alg == DEFAULT_STABLE && !initialized) ? PermUnstable : Perm
1617+
sort!(ix; alg, order = PermT(ord(lt, by, rev, order), vec(A)), scratch, dims...)
16121618
end
16131619

16141620
# sortperm for vectors of few unique integers

0 commit comments

Comments
 (0)