diff --git a/base/ordering.jl b/base/ordering.jl
index d0c9cb99f9c72..3343a4ae4c325 100644
--- a/base/ordering.jl
+++ b/base/ordering.jl
@@ -13,7 +13,7 @@ import .Base:
 
 export # not exported by Base
     Ordering, Forward, Reverse,
-    By, Lt, Perm,
+    By, Lt, Perm, PermFast,
     ReverseOrdering, ForwardOrdering,
     DirectOrdering,
     lt, ord, ordtype
@@ -106,8 +106,22 @@ struct Perm{O<:Ordering,V<:AbstractVector} <: Ordering
     data::V
 end
 
+"""
+    PermFast(order::Ordering, data::AbstractVector)
+
+`Ordering` on the indices of `data` where `i` is less than `j` if `data[i]` is
+less than `data[j]` according to `order`. In the case that `data[i]` and
+`data[j]` are equal, the ordering is undefined. Thus, it is designed to be
+faster than `Perm` when a stable sorting algorithm is used.
+"""
+struct PermFast{O<:Ordering,V<:AbstractVector} <: Ordering
+    order::O
+    data::V
+end
+
 ReverseOrdering(by::By) = By(by.by, ReverseOrdering(by.order))
 ReverseOrdering(perm::Perm) = Perm(ReverseOrdering(perm.order), perm.data)
+ReverseOrdering(perm::PermFast) = PermFast(ReverseOrdering(perm.order), perm.data)
 
 """
     lt(o::Ordering, a, b)
@@ -125,5 +139,11 @@ lt(o::Lt, a, b) = o.lt(a,b)
     (lt(p.order, da, db)::Bool) | (!(lt(p.order, db, da)::Bool) & (a < b))
 end
 
+@propagate_inbounds function lt(p::PermFast, a::Integer, b::Integer)
+    da = p.data[a]
+    db = p.data[b]
+    lt(p.order, da, db)::Bool
+end
+
 _ord(lt::typeof(isless), by::typeof(identity), order::Ordering) = order
 _ord(lt::typeof(isless), by, order::Ordering) = By(by, order)
diff --git a/base/sort.jl b/base/sort.jl
index 669d2d97b2ac1..d581eeb117aa9 100644
--- a/base/sort.jl
+++ b/base/sort.jl
@@ -560,13 +560,49 @@ elements that are not
 @inline send_to_end!(f::F, v::AbstractVector, ::ReverseOrdering, end_stable=false; lo, hi) where F <: Function =
     end_stable ? (send_to_end!(!f, v; lo, hi)+1, hi) : (hi-send_to_end!(f, view(v, hi:-1:lo))+1, hi)
 
+"""
+    send_to_end_stable!(f::Function, v::AbstractVector, out::AbstractVector; [lo, hi])
+
+Copy `v[lo:hi]` into `out[lo:hi]`, placing every element for which `f` returns `true` after
+those for which it returns `false`; return the index of the last element for which `f` returns `false`.
+
+`send_to_end_stable!(f, v, out; lo, hi)` is equivalent to `send_to_end_stable!(f, view(v, lo:hi), view(out, lo:hi))+lo-1`
+
+Preserves the order of the elements.
+"""
+function send_to_end_stable!(f::F, v::AbstractVector, out::AbstractVector; lo=firstindex(v), hi=lastindex(v)) where F <: Function
+    offset = 0
+    @inbounds begin
+        while lo <= hi
+            x = v[lo]
+            fx = f(x)::Bool
+            out[(fx ? hi : lo) - offset] = x
+            offset += fx
+            lo += 1
+        end
+    end
+
+    # `offset` counts the `true` elements; `lo` is now one past the original `hi`
+    pivot_index = lo-offset-1
+    # out[i] for i <= pivot_index: f(x) is false, original order kept
+    # out[i] for i > pivot_index: f(x) is true, written back-to-front (reversed)
+
+    # Undo the reversal of the `true` block so the result is stable
+    reverse!(out, pivot_index+1, hi)
+    return pivot_index
+end
+
+@inline send_to_end_stable!(f::F, v::AbstractVector, out::AbstractVector, ::ForwardOrdering; lo, hi) where F <: Function =
+    (lo, send_to_end_stable!(f, v, out; lo, hi))
+@inline send_to_end_stable!(f::F, v::AbstractVector, out::AbstractVector, ::ReverseOrdering; lo, hi) where F <: Function =
+    (hi-send_to_end_stable!(f, view(v, hi:-1:lo), view(out, hi:-1:lo))+1, hi)
 
 function _sort!(v::AbstractVector, a::MissingOptimization, o::Ordering, kw)
     @getkw lo hi
     if nonmissingtype(eltype(v)) != eltype(v) && o isa DirectOrdering
         lo, hi = send_to_end!(ismissing, v, o; lo, hi)
         _sort!(WithoutMissingVector(v, unsafe=true), a.next, o, (;kw..., lo, hi))
-    elseif eltype(v) <: Integer && o isa Perm && o.order isa DirectOrdering &&
+    elseif eltype(v) <: Integer && (o isa Perm || o isa PermFast) && o.order isa DirectOrdering &&
                 nonmissingtype(eltype(o.data)) != eltype(o.data) &&
                 all(i === j for (i,j) in zip(v, eachindex(o.data)))
         # TODO make this branch known at compile time
@@ -590,7 +626,8 @@ function _sort!(v::AbstractVector, a::MissingOptimization, o::Ordering, kw)
             hi = hi_i
         end
 
-        _sort!(v, a.next, Perm(o.order, WithoutMissingVector(o.data, unsafe=true)), (;kw..., lo, hi))
+        PermT = o isa PermFast ? PermFast : Perm
+        _sort!(v, a.next, PermT(o.order, WithoutMissingVector(o.data, unsafe=true)), (;kw..., lo, hi))
     else
         _sort!(v, a.next, o, kw)
     end
@@ -618,26 +655,39 @@ after_zero(::ForwardOrdering, x) = !signbit(x)
 after_zero(::ReverseOrdering, x) = signbit(x)
 is_concrete_IEEEFloat(T::Type) = T <: Base.IEEEFloat && isconcretetype(T)
 function _sort!(v::AbstractVector, a::IEEEFloatOptimization, o::Ordering, kw)
-    @getkw lo hi
+    @getkw lo hi scratch
     if is_concrete_IEEEFloat(eltype(v)) && o isa DirectOrdering
         lo, hi = send_to_end!(isnan, v, o, true; lo, hi)
         iv = reinterpret(UIntType(eltype(v)), v)
         j = send_to_end!(x -> after_zero(o, x), v; lo, hi)
         scratch = _sort!(iv, a.next, Reverse, (;kw..., lo, hi=j))
         if scratch === nothing # Union split
-            _sort!(iv, a.next, Forward, (;kw..., lo=j+1, hi, scratch))
+            _sort!(iv, a.next, Forward, (;kw..., lo=j+1, hi, scratch=nothing))
         else
             _sort!(iv, a.next, Forward, (;kw..., lo=j+1, hi, scratch))
         end
-    elseif eltype(v) <: Integer && o isa Perm && o.order isa DirectOrdering && is_concrete_IEEEFloat(eltype(o.data))
-        lo, hi = send_to_end!(i -> isnan(@inbounds o.data[i]), v, o.order, true; lo, hi)
+    elseif eltype(v) <: Integer && (o isa Perm || o isa PermFast) && o.order isa DirectOrdering && is_concrete_IEEEFloat(eltype(o.data))
+        if o isa Perm
+            lo, hi = send_to_end!(i -> isnan(@inbounds o.data[i]), v, o.order, true; lo, hi)
+            j = send_to_end!(i -> after_zero(o.order, @inbounds o.data[i]), v; lo, hi)
+            PermT = Perm
+            kw = (;kw..., lo, hi=j)
+        else
+            scratch, _ = make_scratch(scratch, eltype(v), hi)
+            lo2, hi2 = send_to_end_stable!(i -> isnan(@inbounds o.data[i]), v, scratch, o.order; lo, hi)
+            ran = lo < lo2 ? (lo:lo2-1) : (hi2+1:hi)
+            v[ran] = view(scratch, ran)
+            lo, hi = lo2, hi2
+            j = send_to_end_stable!(i -> after_zero(o.order, @inbounds o.data[i]), scratch, v; lo, hi)
+            PermT = PermFast
+            kw = (;kw..., lo, hi=j, scratch=scratch)
+        end
         ip = reinterpret(UIntType(eltype(o.data)), o.data)
-        j = send_to_end!(i -> after_zero(o.order, @inbounds o.data[i]), v; lo, hi)
-        scratch = _sort!(v, a.next, Perm(Reverse, ip), (;kw..., lo, hi=j))
+        scratch = _sort!(v, a.next, PermT(Reverse, ip), kw)
         if scratch === nothing # Union split
-            _sort!(v, a.next, Perm(Forward, ip), (;kw..., lo=j+1, hi, scratch))
+            _sort!(v, a.next, PermT(Forward, ip), (;kw..., lo=j+1, hi, scratch=nothing))
         else
-            _sort!(v, a.next, Perm(Forward, ip), (;kw..., lo=j+1, hi, scratch))
+            _sort!(v, a.next, PermT(Forward, ip), (;kw..., lo=j+1, hi, scratch))
         end
     else
         _sort!(v, a.next, o, kw)
@@ -1579,7 +1629,11 @@ function _sortperm(A::AbstractArray; alg, order, scratch, dims...)
         end
     end
     ix = copymutable(LinearIndices(A))
-    sort!(ix; alg, order = Perm(order, vec(A)), scratch, dims...)
+    if alg == DEFAULT_STABLE
+        sort!(ix; alg, order = PermFast(order, vec(A)), scratch, dims...)
+    else
+        sort!(ix; alg, order = Perm(order, vec(A)), scratch, dims...)
+    end
 end
 
 
@@ -1636,11 +1690,11 @@ julia> sortperm!(p, A; dims=2); p
     if !initialized
         ix .= LinearIndices(A)
     end
-
+    PermT = alg == DEFAULT_STABLE ? PermFast : Perm
     if rev === true
-        sort!(ix; alg, order=Perm(ord(lt, by, true, order), vec(A)), scratch, dims...)
+        sort!(ix; alg, order = PermT(ord(lt, by, true, order), vec(A)), scratch, dims...)
     else
-        sort!(ix; alg, order=Perm(ord(lt, by, nothing, order), vec(A)), scratch, dims...)
+        sort!(ix; alg, order = PermT(ord(lt, by, nothing, order), vec(A)), scratch, dims...)
     end
 end
 