From e29d396593768f186c366160c9c96daad5009777 Mon Sep 17 00:00:00 2001 From: Lilith Hafner Date: Sun, 5 Nov 2023 10:30:22 -0600 Subject: [PATCH 1/5] Add internal order preparation API --- base/ordering.jl | 97 +++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 88 insertions(+), 9 deletions(-) diff --git a/base/ordering.jl b/base/ordering.jl index 36d8e90064eba..f8e43ef21ec8e 100644 --- a/base/ordering.jl +++ b/base/ordering.jl @@ -114,17 +114,96 @@ ReverseOrdering(perm::Perm) = Perm(ReverseOrdering(perm.order), perm.data) lt(o::Ordering, a, b) Test whether `a` is less than `b` according to the ordering `o`. +""" # No see-also because the prepared ordering system is experimental. +function lt end + """ -lt(o::ForwardOrdering, a, b) = isless(a,b) -lt(o::ReverseOrdering, a, b) = lt(o.fwd,b,a) -lt(o::By, a, b) = lt(o.order,o.by(a),o.by(b)) -lt(o::Lt, a, b) = o.lt(a,b) + lt_prepared(o::Ordering, a, b) -@propagate_inbounds function lt(p::Perm, a::Integer, b::Integer) - da = p.data[a] - db = p.data[b] - (lt(p.order, da, db)::Bool) | (!(lt(p.order, db, da)::Bool) & (a < b)) -end +Test whether `a` is less than `b` according to the ordering `o`, assuming both `a` and `b` +have been prepared with `prepare`. + +`lt_prepared(o, prepare(o, a), prepare(o, b))` is equivalent to `lt(o, a, b)`. + +!!! warning + Comparing a prepared element `prepare(o1, x)` under a different ordering `o2` + is undefined behavior and may, for example, result in segfaults. + +See also `lt_prepared_1`, `lt_prepared_2`. +""" +function lt_prepared end + +""" + lt_prepared_1(o::Ordering, a, b) + +Test whether `a` is less than `b` according to the ordering `o`, assuming `a` has been +prepared with `prepare`. + +`lt_prepared_1(o, prepare(o, a), b)` is equivalent to `lt(o, a, b)`. + +!!! warning + Comparing a prepared element `prepare(o1, x)` under a different ordering `o2` + is undefined behavior and may, for example, result in segfaults. + +See also `lt`, `lt_prepared`. 
+""" +@propagate_inbounds lt_prepared_1(o::Ordering, a, b) = lt_prepared(o, a, prepare(o, b)) + +""" + lt_prepared_2(o::Ordering, a, b) + +Test whether `a` is less than `b` according to the ordering `o`, assuming `b` has been +prepared with `prepare`. + +!!! warning + Comparing a prepared element `prepare(o1, x)` under a different ordering `o2` + is undefined behavior and may, for example, result in segfaults. + +See also `lt`, `lt_prepared`. +""" +@propagate_inbounds lt_prepared_2(o::Ordering, a, b) = lt_prepared(o, prepare(o, a), b) + +""" + prepare(o::Ordering, x) + +Prepare an element `x` for efficient comparison with `lt_prepared`. + +`lt(o::MyOrdering, a, b)` and `lt_prepared(o, prepare(o, a), prepare(o, b))` are +equivalent. They must have indistinguishable behavior and have the same performance +characteristics. + +If you define `prepare` on a custom `Ordering`, you should also define `lt_prepared` and +should not define `lt` for that order. + +!!! warning + Comparing a prepared element `prepare(o1, x)` under a different ordering `o2` + is undefined behavior and may, for example, result in segfaults. 
+""" +function prepare end + +# Fallbacks +@propagate_inbounds lt(o::Ordering, a, b) = lt_prepared(o, prepare(o, a), prepare(o, b)) +@propagate_inbounds lt_prepared(o::Ordering, a, b) = lt(o, a, b) # TODO: remove this in Julia 2.0 +prepare(o::Ordering, x) = x + +# Forward +lt(o::ForwardOrdering, a, b) = isless(a, b) + +# Reverse +prepare(o::ReverseOrdering, x) = prepare(o.fwd, x) +lt_prepared(o::ReverseOrdering, a, b) = lt_prepared(o.fwd, b, a) + +# By +prepare(o::By, x) = prepare(o.order, o.by(x)) +lt_prepared(o::By, a, b) = lt_prepared(o.order, a, b) + +# Perm +@propagate_inbounds prepare(o::Perm, i) = (prepare(o.order, o.data[i]), i) +lt_prepared(p::Perm, (da, a), (db, b)) = + (lt_prepared(p.order, da, db)::Bool) | (!(lt_prepared(p.order, db, da)::Bool) & (a < b)) + +## Lt +lt(o::Lt, a, b) = o.lt(a, b) _ord(lt::typeof(isless), by, order::Ordering) = _by(by, order) From 58c747f52ed59578833a437ba6acfb38aaa0eb72 Mon Sep 17 00:00:00 2001 From: Lilith Hafner Date: Sun, 5 Nov 2023 10:30:34 -0600 Subject: [PATCH 2/5] Use the API --- base/sort.jl | 34 +++++++++++++++++++++++----------- 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/base/sort.jl b/base/sort.jl index a4616a2e65b9e..e23344e4e0d13 100644 --- a/base/sort.jl +++ b/base/sort.jl @@ -4,6 +4,8 @@ module Sort using Base.Order +using Base.Order: prepare, lt_prepared, lt_prepared_1, lt_prepared_2 + using Base: copymutable, midpoint, require_one_based_indexing, uinttype, sub_with_overflow, add_with_overflow, OneTo, BitSigned, BitIntegerType, top_set_bit, IteratorSize, HasShape, IsInfinite, tail @@ -52,11 +54,13 @@ function issorted(itr, order::Ordering) y = iterate(itr) y === nothing && return true prev, state = y + prev_p = prepare(order, prev) y = iterate(itr, state) while y !== nothing this, state = y - lt(order, this, prev) && return false - prev = this + this_p = prepare(order, this) + lt_prepared(order, this_p, prev_p) && return false + prev_p = this_p y = iterate(itr, state) end return true @@ 
-172,10 +176,11 @@ partialsort(v::AbstractVector, k::Union{Integer,OrdinalRange}; kws...) = function searchsortedfirst(v::AbstractVector, x, lo::T, hi::T, o::Ordering)::keytype(v) where T<:Integer hi = hi + T(1) len = hi - lo + x_p = prepare(o, x) @inbounds while len != 0 half_len = len >>> 0x01 m = lo + half_len - if lt(o, v[m], x) + if lt_prepared_2(o, v[m], x_p) lo = m + 1 len -= half_len + 1 else @@ -192,9 +197,10 @@ function searchsortedlast(v::AbstractVector, x, lo::T, hi::T, o::Ordering)::keyt u = T(1) lo = lo - u hi = hi + u + x_p = prepare(o, x) @inbounds while lo < hi - u m = midpoint(lo, hi) - if lt(o, x, v[m]) + if lt_prepared_1(o, x_p, v[m]) hi = m else lo = m @@ -210,13 +216,15 @@ function searchsorted(v::AbstractVector, x, ilo::T, ihi::T, o::Ordering)::UnitRa u = T(1) lo = ilo - u hi = ihi + u + x_p = prepare(o, x) @inbounds while lo < hi - u m = midpoint(lo, hi) - if lt(o, v[m], x) + if lt_prepared_2(o, v[m], x_p) lo = m - elseif lt(o, x, v[m]) + elseif lt_prepared_1(o, x_p, v[m]) hi = m else + # TODO for further optimization: perform recursive calls with prepared inputs a = searchsortedfirst(v, x, max(lo,ilo), m, o) b = searchsortedlast(v, x, m, min(hi,ihi), o) return a : b @@ -785,9 +793,10 @@ function _sort!(v::AbstractVector, ::InsertionSortAlg, o::Ordering, kw) @inbounds for i = lo_plus_1:hi j = i x = v[i] + x_p = prepare(o, x) while j > lo y = v[j-1] - if !(lt(o, x, y)::Bool) + if !(lt_prepared_1(o, x_p, y)::Bool) break end v[j] = y @@ -1039,16 +1048,17 @@ function partition!(t::AbstractVector, lo::Integer, hi::Integer, offset::Integer pivot_index = mod(hash(lo), lo:hi) @inbounds begin pivot = v[pivot_index] + pivot_p = prepare(o, pivot) while lo < pivot_index x = v[lo] - fx = rev ? !lt(o, x, pivot) : lt(o, pivot, x) + fx = rev ? !lt_prepared_2(o, x, pivot_p) : lt_prepared_1(o, pivot_p, x) t[(fx ? hi : lo) - offset] = x offset += fx lo += 1 end while lo < hi x = v[lo+1] - fx = rev ? lt(o, pivot, x) : !lt(o, x, pivot) + fx = rev ? 
lt_prepared_1(o, pivot_p, x) : !lt_prepared_2(o, x, pivot_p) t[(fx ? hi : lo) - offset] = x offset += fx lo += 1 @@ -1201,6 +1211,7 @@ end maybe_unsigned(x::Integer) = x # this is necessary to avoid calling unsigned on BigInt maybe_unsigned(x::BitSigned) = unsigned(x) function _issorted(v::AbstractVector, lo::Integer, hi::Integer, o::Ordering) + # TODO: replace this with `issorted(view(v, lo:hi), order=o)` once views are fast. @boundscheck checkbounds(v, lo:hi) @inbounds for i in (lo+1):hi lt(o, v[i], v[i-1]) && return false @@ -2134,12 +2145,13 @@ end function partition!(v::AbstractVector, lo::Integer, hi::Integer, o::Ordering) pivot = selectpivot!(v, lo, hi, o) + pivot_p = prepare(o, pivot) # pivot == v[lo], v[hi] > pivot i, j = lo, hi @inbounds while true i += 1; j -= 1 - while lt(o, v[i], pivot); i += 1; end; - while lt(o, pivot, v[j]); j -= 1; end; + while lt_prepared_2(o, v[i], pivot_p); i += 1; end; + while lt_prepared_1(o, pivot_p, v[j]); j -= 1; end; i >= j && break v[i], v[j] = v[j], v[i] end From 0762aecf6543bd5679501495dc514e76059fb060 Mon Sep 17 00:00:00 2001 From: Lilith Hafner Date: Sun, 5 Nov 2023 10:30:42 -0600 Subject: [PATCH 3/5] Add performance tests --- test/sorting.jl | 44 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/test/sorting.jl b/test/sorting.jl index 1164f2932d880..692c5a6bac312 100644 --- a/test/sorting.jl +++ b/test/sorting.jl @@ -1065,6 +1065,50 @@ end @test issorted(sort!(rand(100), Base.Sort.InitialOptimizations(DispatchLoopTestAlg()), Base.Order.Forward)) end +@testset "Performance (how many times By is called)" begin + # Intentional regressions are acceptable, accidental regressions are not. + cnt = Ref(0) + incr_identity = x -> (cnt[] += 1; x) + x = 1:50 + + cnt[] = 0 + @test issorted(x; by=incr_identity) + @test cnt[] == 50 # Any less would be buggy. + + cnt[] = 0 + @test !issorted(x; by=incr_identity, rev=true) + @test cnt[] == 2 # Any less would be buggy. 
+ + cnt[] = 0 + @test searchsortedfirst(x, 1; by=incr_identity) == 1 + @test cnt[] <= 7 + + cnt[] = 0 + @test searchsorted(repeat(1:10, inner=10), 3; by=incr_identity) == 21:30 + @test cnt[] <= 16 + + cnt[] = 0 + @test sort(x; by=incr_identity) == x + @test cnt[] <= 98 + + cnt[] = 0 + @test sort(1:1000; by=incr_identity) == 1:1000 + @test cnt[] <= 1998 + + cnt[] = 0 + Random.seed!(1729) + x = randperm(1000) + @test sort!(x; by=incr_identity) == 1:1000 + # This should succeed at least 99.99% of the time on random inputs + # and therefore should not be broken by changes to the rng + @test cnt[] <= 17203 + + cnt[] = 0 + x = hash.(1:1000) + @test sort(x; by=incr_identity) == sort(x) + @test cnt[] <= 12999 +end + # This testset is at the end of the file because it is slow. @testset "searchsorted" begin numTypes = [ Int8, Int16, Int32, Int64, Int128, From d010c6810326f867da28e7e694240ae920faebf6 Mon Sep 17 00:00:00 2001 From: Lilith Orion Hafner Date: Thu, 9 Nov 2023 07:06:58 -0600 Subject: [PATCH 4/5] Avoid stack overflow by removing generic `lt` fallback --- base/ordering.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/base/ordering.jl b/base/ordering.jl index f8e43ef21ec8e..bcb78a2dc6b39 100644 --- a/base/ordering.jl +++ b/base/ordering.jl @@ -182,9 +182,10 @@ should not define `lt` for that order. 
function prepare end # Fallbacks -@propagate_inbounds lt(o::Ordering, a, b) = lt_prepared(o, prepare(o, a), prepare(o, b)) @propagate_inbounds lt_prepared(o::Ordering, a, b) = lt(o, a, b) # TODO: remove this in Julia 2.0 prepare(o::Ordering, x) = x +# Not defining this because it would cause a stack overflow for invalid `Ordering`s: +# lt(o::Ordering, a, b) = lt_prepared(o, prepare(o, a), prepare(o, b)) # Forward lt(o::ForwardOrdering, a, b) = isless(a, b) @@ -192,15 +193,18 @@ lt(o::ForwardOrdering, a, b) = isless(a, b) # Reverse prepare(o::ReverseOrdering, x) = prepare(o.fwd, x) lt_prepared(o::ReverseOrdering, a, b) = lt_prepared(o.fwd, b, a) +lt(::ReverseOrdering, a, b) = lt_prepared(o, prepare(o, a), prepare(o, b)) # By prepare(o::By, x) = prepare(o.order, o.by(x)) lt_prepared(o::By, a, b) = lt_prepared(o.order, a, b) +lt(::By, a, b) = lt_prepared(o, prepare(o, a), prepare(o, b)) # Perm @propagate_inbounds prepare(o::Perm, i) = (prepare(o.order, o.data[i]), i) lt_prepared(p::Perm, (da, a), (db, b)) = (lt_prepared(p.order, da, db)::Bool) | (!(lt_prepared(p.order, db, da)::Bool) & (a < b)) +@propagate_inbounds lt(::Perm, a, b) = lt_prepared(o, prepare(o, a), prepare(o, b)) ## Lt lt(o::Lt, a, b) = o.lt(a, b) From db70e1de9dfbba9c7cf2c439def399b9f51e1fe8 Mon Sep 17 00:00:00 2001 From: Lilith Orion Hafner Date: Thu, 9 Nov 2023 07:07:46 -0600 Subject: [PATCH 5/5] fixup --- base/ordering.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/base/ordering.jl b/base/ordering.jl index bcb78a2dc6b39..155ce8958c216 100644 --- a/base/ordering.jl +++ b/base/ordering.jl @@ -193,18 +193,18 @@ lt(o::ForwardOrdering, a, b) = isless(a, b) # Reverse prepare(o::ReverseOrdering, x) = prepare(o.fwd, x) lt_prepared(o::ReverseOrdering, a, b) = lt_prepared(o.fwd, b, a) -lt(::ReverseOrdering, a, b) = lt_prepared(o, prepare(o, a), prepare(o, b)) +lt(o::ReverseOrdering, a, b) = lt_prepared(o, prepare(o, a), prepare(o, b)) # By prepare(o::By, x) = prepare(o.order, 
o.by(x)) lt_prepared(o::By, a, b) = lt_prepared(o.order, a, b) -lt(::By, a, b) = lt_prepared(o, prepare(o, a), prepare(o, b)) +lt(o::By, a, b) = lt_prepared(o, prepare(o, a), prepare(o, b)) # Perm @propagate_inbounds prepare(o::Perm, i) = (prepare(o.order, o.data[i]), i) lt_prepared(p::Perm, (da, a), (db, b)) = (lt_prepared(p.order, da, db)::Bool) | (!(lt_prepared(p.order, db, da)::Bool) & (a < b)) -@propagate_inbounds lt(::Perm, a, b) = lt_prepared(o, prepare(o, a), prepare(o, b)) +@propagate_inbounds lt(o::Perm, a, b) = lt_prepared(o, prepare(o, a), prepare(o, b)) ## Lt lt(o::Lt, a, b) = o.lt(a, b)