Skip to content

Commit 1e5c581

Browse files
author
Pietro Vertechi
authored
Merge pull request #34 from piever/pv/ties
clean up tied indices
2 parents f79f9a8 + 4cec43c commit 1e5c581

File tree

2 files changed

+19
-10
lines changed

2 files changed

+19
-10
lines changed

src/sort.jl

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,23 +14,25 @@ function Base.permute!(c::StructVector, p::AbstractVector)
1414
return c
1515
end
1616

17-
struct TiedIndices{T <: AbstractVector}
17+
struct TiedIndices{T<:AbstractVector, U<:AbstractUnitRange}
1818
vec::T
1919
perm::Vector{Int}
20-
within::Tuple{Int, Int}
20+
within::U
2121
end
2222

2323
TiedIndices(vec::AbstractVector, perm=sortperm(vec)) =
24-
TiedIndices(vec, perm, extrema(axes(vec, 1)))
24+
TiedIndices(vec, perm, axes(vec, 1))
2525

2626
Base.IteratorSize(::Type{<:TiedIndices}) = Base.SizeUnknown()
2727

2828
Base.eltype(::Type{<:TiedIndices{T}}) where {T} =
2929
Pair{eltype(T), UnitRange{Int}}
3030

31-
function Base.iterate(n::TiedIndices, i = n.within[1])
31+
Base.sortperm(t::TiedIndices) = t.perm
32+
33+
function Base.iterate(n::TiedIndices, i = first(n.within))
3234
vec, perm = n.vec, n.perm
33-
l = n.within[2]
35+
l = last(n.within)
3436
i > l && return nothing
3537
row = vec[perm[i]]
3638
i1 = i
@@ -42,10 +44,15 @@ end
4244

4345
tiedindices(args...) = TiedIndices(args...)
4446

45-
function groupindices(args...)
47+
function uniquesorted(args...)
48+
t = tiedindices(args...)
49+
(row for (row, _) in t)
50+
end
51+
52+
function finduniquesorted(args...)
4653
t = tiedindices(args...)
47-
p = t.perm
48-
((row => t.perm[idxs]) for (row, idxs) in t)
54+
p = sortperm(t)
55+
(row => p[idxs] for (row, idxs) in t)
4956
end
5057

5158
function Base.sortperm(c::StructVector{T};
@@ -70,7 +77,7 @@ function refine_perm!(p, cols, c, x, y, lo, hi)
7077
temp = similar(p, 0)
7178
order = Base.Order.By(j->(@inbounds k=y[j]; k))
7279
nc = length(cols)
73-
for (_, idxs) in TiedIndices(x, p, (lo, hi))
80+
for (_, idxs) in TiedIndices(x, p, lo:hi)
7481
i, i1 = extrema(idxs)
7582
if i1 > i
7683
sort_sub_by!(p, i, i1, y, order, temp)

test/runtests.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,11 @@ end
7373
s = collect(d)
7474
@test first.(s) == [1, 2, 3]
7575
@test last.(s) == [1:3, 4:4, 5:5]
76-
t = collect(StructArrays.groupindices(c))
76+
t = collect(StructArrays.finduniquesorted(c))
7777
@test first.(t) == [1, 2, 3]
7878
@test last.(t) == [[1, 4, 5], [2], [3]]
79+
u = collect(StructArrays.uniquesorted(c))
80+
@test u == [1, 2, 3]
7981
end
8082

8183
@testset "similar" begin

0 commit comments

Comments
 (0)