Skip to content

Commit 8bfc647

Browse files
authored
fix unique() behaviour, add unique!() (#358)
so it conforms to the semantics of Base.unique(). This is a breaking change that requires a new minor release.
1 parent f313cb0 commit 8bfc647

File tree

4 files changed

+59
-30
lines changed

4 files changed

+59
-30
lines changed

src/array.jl

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
## Code for CategoricalArray
22

33
import Base: Array, convert, collect, copy, getindex, setindex!, similar, size,
4-
unique, vcat, in, summary, float, complex, copyto!
4+
unique, unique!, vcat, in, summary, float, complex, copyto!
55

66
# Used for keyword argument default value
77
_isordered(x::AbstractCategoricalArray) = isordered(x)
@@ -867,31 +867,36 @@ function levels!(A::CategoricalArray{T, N, R}, newlevels::AbstractVector;
867867
return A
868868
end
869869

870-
function _unique(::Type{S},
871-
refs::AbstractArray{T},
872-
pool::CategoricalPool) where {S, T<:Integer}
873-
nlevels = length(levels(pool)) + 1
874-
order = fill(0, nlevels) # 0 indicates not seen
875-
# If we don't track missings, short-circuit even if none has been seen
876-
count = S >: Missing ? 0 : 1
877-
@inbounds for i in refs
878-
if order[i + 1] == 0
879-
count += 1
880-
order[i + 1] = count
881-
count == nlevels && break
870+
# return unique refs (each value is unique) in the order of appearance in `refs`
871+
# equivalent to fallback Base.unique() implementation,
872+
# but short-circuits once references to all levels are encountered
873+
function _uniquerefs(A::CatArrOrSub{T}) where T
874+
arefs = refs(A)
875+
res = similar(arefs, 0)
876+
nlevels = length(levels(A))
877+
maxunique = nlevels + (T >: Missing ? 1 : 0)
878+
seen = fill(false, nlevels + 1) # always +1 for 0 (missing ref)
879+
@inbounds for ref in arefs
880+
if !seen[ref + 1]
881+
push!(res, ref)
882+
seen[ref + 1] = true
883+
(length(res) == maxunique) && break
882884
end
883885
end
884-
S[i == 1 ? missing : levels(pool)[i - 1] for i in sortperm(order) if order[i] != 0]
886+
return res
885887
end
886888

887-
"""
888-
unique(A::CategoricalArray)
889+
unique(A::CatArrOrSub{T}) where T =
890+
CategoricalVector{T}(_uniquerefs(A), copy(pool(A)))
889891

890-
Return levels which appear in `A` in their order of appearance.
891-
This function is significantly slower than [`levels`](@ref DataAPI.levels)
892-
since it needs to check whether levels are used or not.
893-
"""
894-
unique(A::CategoricalArray{T}) where {T} = _unique(T, A.refs, A.pool)
892+
function unique!(A::CategoricalVector)
893+
urefs = _uniquerefs(A)
894+
if length(urefs) != length(A)
895+
resize!(A.refs, length(urefs))
896+
copyto!(A.refs, urefs)
897+
end
898+
return A
899+
end
895900

896901
"""
897902
droplevels!(A::CategoricalArray)

src/subarray.jl

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,6 @@ isordered(sa::SubArray{T,N,P}) where {T,N,P<:CategoricalArray} = isordered(paren
55
levels!(sa::SubArray{T,N,P}, newlevels::Vector) where {T,N,P<:CategoricalArray} =
66
levels!(parent(sa), newlevels)
77

8-
function unique(sa::SubArray{T,N,P}) where {T,N,P<:CategoricalArray}
9-
A = parent(sa)
10-
refs = view(A.refs, sa.indices...)
11-
S = eltype(P) >: Missing ? Union{eltype(levels(A.pool)), Missing} : eltype(levels(A.pool))
12-
_unique(S, refs, A.pool)
13-
end
14-
158
refs(A::SubArray{<:Any, <:Any, <:CategoricalArray}) =
169
view(parent(A).refs, parentindices(A)...)
1710

test/11_array.jl

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ using CategoricalArrays: DefaultRefType, leveltype
1616
@test isordered(x) === ordered
1717
@test levels(x) == sort(unique(a))
1818
@test unique(x) == unique(a)
19+
@test typeof(unique(x)) === typeof(x)
1920
@test size(x) === (3,)
2021
@test length(x) === 3
2122

@@ -272,6 +273,7 @@ using CategoricalArrays: DefaultRefType, leveltype
272273
@test x == collect(a)
273274
@test isordered(x) === ordered
274275
@test levels(x) == unique(x) == unique(a)
276+
@test typeof(unique(x)) === typeof(x)
275277
@test size(x) === (4,)
276278
@test length(x) === 4
277279
@test leveltype(x) === Float64
@@ -437,6 +439,7 @@ using CategoricalArrays: DefaultRefType, leveltype
437439
@test x[4] === CategoricalValue(x.pool, 4)
438440
@test levels(x) == unique(a)
439441
@test unique(x) == unique(collect(x))
442+
@test typeof(unique(x)) === typeof(x)
440443

441444
x[1:2] .= -1
442445
@test x[1] === CategoricalValue(x.pool, 5)
@@ -473,6 +476,7 @@ using CategoricalArrays: DefaultRefType, leveltype
473476
@test x == a
474477
@test isordered(x) === ordered
475478
@test levels(x) == unique(x) == unique(a)
479+
@test unique(x) isa CategoricalVector{String, R}
476480
@test size(x) === (2, 3)
477481
@test length(x) === 6
478482

@@ -729,27 +733,42 @@ end
729733
@test levels!(x, ["Young", "Middle", "Old"]) === x
730734
@test levels(x) == ["Young", "Middle", "Old"]
731735
@test unique(x) == ["Old", "Young", "Middle"]
736+
@test typeof(unique(x)) === typeof(x)
732737
@test levels!(x, ["Young", "Middle", "Old", "Unused"]) === x
733738
@test levels(x) == ["Young", "Middle", "Old", "Unused"]
734739
@test unique(x) == ["Old", "Young", "Middle"]
735740
@test levels!(x, ["Unused1", "Young", "Middle", "Old", "Unused2"]) === x
736741
@test levels(x) == ["Unused1", "Young", "Middle", "Old", "Unused2"]
737742
@test unique(x) == ["Old", "Young", "Middle"]
738743

744+
y = copy(x)
745+
@test unique!(y) === y
746+
@test y == unique(x)
747+
739748
x = CategoricalArray(String[])
740749
@test isa(levels(x), Vector{String}) && isempty(levels(x))
741-
@test isa(unique(x), Vector{String}) && isempty(unique(x))
750+
@test isa(unique(x), typeof(x)) && isempty(unique(x))
742751
@test levels!(x, ["Young", "Middle", "Old"]) === x
743752
@test levels(x) == ["Young", "Middle", "Old"]
744-
@test isa(unique(x), Vector{String}) && isempty(unique(x))
753+
@test isa(unique(x), typeof(x)) && isempty(unique(x))
754+
755+
y = copy(x)
756+
@test unique!(y) === y
757+
@test y == unique(x)
745758

746759
# To test short-circuiting
747760
x = CategoricalArray(repeat(1:10, inner=10))
748761
@test levels(x) == collect(1:10)
749762
@test unique(x) == collect(1:10)
763+
@test unique(x) isa typeof(x)
750764
@test levels!(x, [19:-1:1; 20]) === x
751765
@test levels(x) == [19:-1:1; 20]
752766
@test unique(x) == collect(1:10)
767+
@test unique(x) isa typeof(x)
768+
769+
y = copy(x)
770+
@test unique!(y) === y
771+
@test y == 1:10
753772
end
754773

755774
end

test/12_missingarray.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,14 @@ const ≅ = isequal
1919
@test isordered(x) === ordered
2020
@test levels(x) == sort(unique(a))
2121
@test unique(x) == unique(a)
22+
@test typeof(unique(x)) === typeof(x)
2223
@test size(x) === (3,)
2324
@test length(x) === 3
2425

26+
y = copy(x)
27+
@test y === unique!(y)
28+
@test y == unique(x)
29+
2530
@test convert(CategoricalArray, x) === x
2631
@test convert(CategoricalArray{Union{String, Missing}}, x) === x
2732
@test convert(CategoricalArray{Union{String, Missing}, 1}, x) === x
@@ -296,6 +301,7 @@ const ≅ = isequal
296301
@test x ≅ a
297302
@test levels(x) == filter(x->!ismissing(x), unique(a))
298303
@test unique(x) ≅ unique(a)
304+
@test typeof(unique(x)) === typeof(x)
299305
@test size(x) === (3,)
300306
@test length(x) === 3
301307

@@ -440,6 +446,7 @@ const ≅ = isequal
440446
@test x == collect(a)
441447
@test isordered(x) === ordered
442448
@test levels(x) == unique(x) == unique(a)
449+
@test typeof(unique(x)) === typeof(x)
443450
@test size(x) === (4,)
444451
@test length(x) === 4
445452
@test leveltype(x) === Float64
@@ -616,6 +623,7 @@ const ≅ = isequal
616623
@test x[4] === CategoricalValue(x.pool, 4)
617624
@test levels(x) == unique(a)
618625
@test unique(x) == unique(collect(x))
626+
@test typeof(unique(x)) === typeof(x)
619627

620628
x[1:2] .= -1
621629
@test x[1] === CategoricalValue(x.pool, 5)
@@ -625,6 +633,7 @@ const ≅ = isequal
625633
@test isordered(x) === false
626634
@test levels(x) == vcat(unique(a), -1)
627635
@test unique(x) == unique(collect(x))
636+
@test typeof(unique(x)) === typeof(x)
628637

629638

630639
ordered!(x, ordered)
@@ -656,6 +665,7 @@ const ≅ = isequal
656665
@test x == a
657666
@test isordered(x) === ordered
658667
@test levels(x) == unique(x) == unique(a)
668+
@test unique(x) isa CategoricalVector{Union{String, Missing}, R}
659669
@test size(x) === (2, 3)
660670
@test length(x) === 6
661671

@@ -816,6 +826,7 @@ const ≅ = isequal
816826
@test isordered(x) === ordered
817827
@test levels(x) == filter(x->!ismissing(x), unique(a))
818828
@test unique(x) ≅ unique(a)
829+
@test unique(x) isa CategoricalVector{Union{String, Missing}, R}
819830
@test size(x) === (2, 3)
820831
@test length(x) === 6
821832

@@ -1137,6 +1148,7 @@ end
11371148
x = CategoricalArray(["Old", "Young", "Middle", missing, "Young"])
11381149
@test levels(x) == ["Middle", "Old", "Young"]
11391150
@test unique(x) ≅ ["Old", "Young", "Middle", missing]
1151+
@test typeof(unique(x)) === typeof(x)
11401152
@test levels!(x, ["Young", "Middle", "Old"]) === x
11411153
@test levels(x) == ["Young", "Middle", "Old"]
11421154
@test unique(x) ≅ ["Old", "Young", "Middle", missing]

0 commit comments

Comments
 (0)