Skip to content

Commit 4b12853

Browse files
committed
Make levels return a CategoricalArray
Having `levels` preserve the eltype of the input is sometimes useful to write generic code. This is only slightly breaking as the result still compares equal to the previous behvior returning unwrapped values.
1 parent 11d43c1 commit 4b12853

15 files changed

+134
-88
lines changed

benchmark/benchmarks.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,11 @@ SUITE["many levels"]["CategoricalArray(::Vector{String})"] =
5555
a = rand([@sprintf("id%010d", k) for k in 1:1000], 10000)
5656
ca = CategoricalArray(a)
5757

58-
levs = levels(ca)
58+
levs = unwrap.(levels(ca))
5959
SUITE["many levels"]["levels! with original levels"] =
6060
@benchmarkable levels!(ca, levs)
6161

62-
levs = reverse(levels(ca))
62+
levs = reverse(unwrap.(levels(ca)))
6363
SUITE["many levels"]["levels! with resorted levels"] =
6464
@benchmarkable levels!(ca, levs)
6565

docs/src/using.md

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ By default, the levels are lexically sorted, which is clearly not correct in our
2020

2121
```jldoctest using
2222
julia> levels(x)
23-
3-element Vector{String}:
23+
3-element CategoricalArray{String,1,UInt32}:
2424
"Middle"
2525
"Old"
2626
"Young"
@@ -68,7 +68,7 @@ To get rid of the `"Old"` group, just call the [`droplevels!`](@ref) function:
6868

6969
```jldoctest using
7070
julia> levels(x)
71-
3-element Vector{String}:
71+
3-element CategoricalArray{String,1,UInt32}:
7272
"Young"
7373
"Middle"
7474
"Old"
@@ -81,7 +81,7 @@ julia> droplevels!(x)
8181
"Young"
8282
8383
julia> levels(x)
84-
2-element Vector{String}:
84+
2-element CategoricalArray{String,1,UInt32}:
8585
"Young"
8686
"Middle"
8787
@@ -139,7 +139,7 @@ Levels still need to be reordered manually:
139139

140140
```jldoctest using
141141
julia> levels(y)
142-
3-element Vector{String}:
142+
3-element CategoricalArray{String,1,UInt32}:
143143
"Middle"
144144
"Old"
145145
"Young"
@@ -263,15 +263,15 @@ true
263263
Likewise, assigning a `CategoricalValue` from `y` to an entry in `x` expands the levels of `x` with all levels from `y`, *respecting the ordering of levels of both vectors if possible*:
264264
```jldoctest using
265265
julia> levels(x)
266-
2-element Vector{String}:
266+
2-element CategoricalArray{String,1,UInt32}:
267267
"Middle"
268268
"Old"
269269
270270
julia> x[1] = y[1]
271271
CategoricalValue{String, UInt32} "Young" (1/2)
272272
273273
julia> levels(x)
274-
3-element Vector{String}:
274+
3-element CategoricalArray{String,1,UInt32}:
275275
"Young"
276276
"Middle"
277277
"Old"
@@ -296,7 +296,7 @@ julia> ab = vcat(a, b)
296296
"c"
297297
298298
julia> levels(ab)
299-
3-element Vector{String}:
299+
3-element CategoricalArray{String,1,UInt32}:
300300
"a"
301301
"b"
302302
"c"
@@ -320,7 +320,7 @@ julia> ab2 = vcat(a, b)
320320
"c"
321321
322322
julia> levels(ab2)
323-
3-element Vector{String}:
323+
3-element CategoricalArray{String,1,UInt32}:
324324
"a"
325325
"b"
326326
"c"

ext/CategoricalArraysArrowExt.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ import Arrow: ArrowTypes
77
const CATARRAY_ARROWNAME = Symbol("JuliaLang.CategoricalArrays.CategoricalArray")
88
ArrowTypes.arrowname(::Type{<:CategoricalValue}) = CATARRAY_ARROWNAME
99
ArrowTypes.arrowmetadata(::Type{CategoricalValue{T, R}}) where {T, R} = string(R)
10+
ArrowTypes.ArrowType(::Type{<:CategoricalValue{T}}) where {T} = T
11+
ArrowTypes.toarrow(x::CategoricalValue) = unwrap(x)
1012

1113
ArrowTypes.arrowname(::Type{Union{<:CategoricalValue, Missing}}) = CATARRAY_ARROWNAME
1214
ArrowTypes.arrowmetadata(::Type{Union{CategoricalValue{T, R}, Missing}}) where {T, R} =

ext/CategoricalArraysRecipesBaseExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ else
99
end
1010

1111
RecipesBase.@recipe function f(::Type{T}, v::T) where T <: CategoricalValue
12-
level_strings = [map(string, levels(v)); missing]
12+
level_strings = [map(string, CategoricalArrays._levels(v)); missing]
1313
ticks --> eachindex(level_strings)
1414
v -> ismissing(v) ? length(level_strings) : Int(CategoricalArrays.refcode(v)),
1515
i -> level_strings[Int(i)]

src/array.jl

Lines changed: 29 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ function CategoricalArray{T, N, R}(A::CategoricalArray{S, N, Q};
239239
catch err
240240
err isa LevelsException || rethrow(err)
241241
throw(ArgumentError("encountered value(s) not in specified `levels`: " *
242-
"$(setdiff(CategoricalArrays.levels(res), levels))"))
242+
"$(setdiff(_levels(res), levels))"))
243243
end
244244
end
245245
return res
@@ -358,18 +358,18 @@ function _convert(::Type{CategoricalArray{T, N, R}}, A::AbstractArray{S, N};
358358
copyto!(res, A)
359359

360360
if levels !== nothing
361-
CategoricalArrays.levels(res) == levels ||
361+
_levels(res) == levels ||
362362
throw(ArgumentError("encountered value(s) not in specified `levels`: " *
363-
"$(setdiff(CategoricalArrays.levels(res), levels))"))
363+
"$(setdiff(_levels(res), levels))"))
364364
else
365365
# if order is defined for level type, automatically apply it
366366
L = leveltype(res)
367367
if Base.OrderStyle(L) isa Base.Ordered
368-
levels!(res, sort(CategoricalArrays.levels(res)))
368+
levels!(res, sort(_levels(res)))
369369
elseif hasmethod(isless, (L, L))
370370
# isless may throw an error, e.g. for AbstractArray{T} of unordered T
371371
try
372-
levels!(res, sort(CategoricalArrays.levels(res)))
372+
levels!(res, sort(_levels(res)))
373373
catch e
374374
e isa MethodError || rethrow(e)
375375
end
@@ -382,7 +382,7 @@ end
382382
# From CategoricalArray (preserve levels, ordering and R)
383383
function convert(::Type{CategoricalArray{T, N, R}}, A::CategoricalArray{S, N}) where {S, T, N, R}
384384
if length(A.pool) > typemax(R)
385-
throw(LevelsException{T, R}(levels(A)[typemax(R)+1:end]))
385+
throw(LevelsException{T, R}(_levels(A)[typemax(R)+1:end]))
386386
end
387387

388388
if !(T >: Missing) && S >: Missing && any(iszero, A.refs)
@@ -460,7 +460,7 @@ size(A::CategoricalArray) = size(A.refs)
460460
Base.IndexStyle(::Type{<:CategoricalArray}) = IndexLinear()
461461

462462
function update_refs!(A::CategoricalArray, newlevels::AbstractVector)
463-
oldlevels = levels(A)
463+
oldlevels = _levels(A)
464464
levelsmap = similar(A.refs, length(oldlevels)+1)
465465
# 0 maps to a missing value
466466
levelsmap[1] = 0
@@ -478,7 +478,7 @@ function merge_pools!(A::CatArrOrSub,
478478
updaterefs::Bool=true,
479479
updatepool::Bool=true)
480480
newlevels, ordered = merge_pools(pool(A), pool(B))
481-
oldlevels = levels(A)
481+
oldlevels = _levels(A)
482482
pA = A isa SubArray ? parent(A) : A
483483
ordered!(pA, ordered)
484484
# If A's levels are an ordered superset of new (merged) pool, no need to recompute refs
@@ -537,8 +537,8 @@ function copyto!(dest::CatArrOrSub{T, N, R}, dstart::Integer,
537537

538538
# try converting src to dest type to avoid partial copy corruption of dest
539539
# in the event that the src cannot be copied into dest
540-
slevs = convert(Vector{T}, levels(src))
541-
dlevs = levels(dest)
540+
slevs = convert(Vector{T}, _levels(src))
541+
dlevs = _levels(dest)
542542
if eltype(src) >: Missing && !(eltype(dest) >: Missing) && !all(x -> x > 0, srefs)
543543
throw(MissingException("cannot copy array with missing values to an array with element type $T"))
544544
end
@@ -591,7 +591,7 @@ function copyto!(dest::CatArrOrSub{T1, N, R}, dstart::Integer,
591591
return invoke(copyto!, Tuple{AbstractArray, Integer, AbstractArray, Integer, Integer},
592592
dest, dstart, src, sstart, n)
593593
end
594-
newdestlevs = destlevs = copy(levels(dest)) # copy since we need original levels below
594+
newdestlevs = destlevs = copy(_levels(dest)) # copy since we need original levels below
595595
srclevsnm = T2 >: Missing ? setdiff(srclevs, [missing]) : srclevs
596596
if !(srclevsnm destlevs)
597597
# if order is defined for level type, automatically apply it
@@ -701,7 +701,7 @@ While this will reduce memory use, this function is type-unstable, which can aff
701701
performance inside the function where the call is made. Therefore, use it with caution.
702702
"""
703703
function compress(A::CategoricalArray{T, N}) where {T, N}
704-
R = reftype(length(levels(A.pool)))
704+
R = reftype(length(_levels(A.pool)))
705705
convert(CategoricalArray{T, N, R}, A)
706706
end
707707

@@ -719,11 +719,11 @@ decompress(A::CategoricalArray{T, N}) where {T, N} =
719719
convert(CategoricalArray{T, N, DefaultRefType}, A)
720720

721721
function vcat(A::CategoricalArray...)
722-
ordered = any(isordered, A) && all(a->isordered(a) || isempty(levels(a)), A)
723-
newlevels, ordered = mergelevels(ordered, map(levels, A)...)
722+
ordered = any(isordered, A) && all(a->isordered(a) || isempty(_levels(a)), A)
723+
newlevels, ordered = mergelevels(ordered, map(_levels, A)...)
724724

725725
refsvec = map(A) do a
726-
ii = convert(Vector{Int}, indexin(levels(a.pool), newlevels))
726+
ii = convert(Vector{Int}, indexin(_levels(a.pool), newlevels))
727727
[x==0 ? 0 : ii[x] for x in a.refs]::Array{Int,ndims(a)}
728728
end
729729

@@ -761,23 +761,25 @@ This may include levels which do not actually appear in the data
761761
`missing` will be included only if it appears in the data and
762762
`skipmissing=false` is passed.
763763
764-
The returned vector is an internal field of `x` which must not be mutated
764+
The returned vector is owned by `x` and must not be mutated
765765
as doing so would corrupt it.
766766
"""
767-
@inline function DataAPI.levels(A::CatArrOrSub{T}; skipmissing::Bool=true) where T
767+
@inline function DataAPI.levels(A::CatArrOrSub; skipmissing::Bool=true)
768768
if eltype(A) >: Missing && !skipmissing
769769
if any(==(0), refs(A))
770-
T[levels(pool(A)); missing]
770+
eltype(A)[levels(pool(A)); missing]
771771
else
772-
convert(Vector{T}, levels(pool(A)))
772+
levels_missing(pool(A))
773773
end
774774
else
775775
levels(pool(A))
776776
end
777777
end
778778

779+
_levels(A::CatArrOrSub) = _levels(pool(A))
780+
779781
"""
780-
levels!(A::CategoricalArray, newlevels::Vector; allowmissing::Bool=false)
782+
levels!(A::CategoricalArray, newlevels::AbstractVector; allowmissing::Bool=false)
781783
782784
Set the levels categorical array `A`. The order of appearance of levels will be respected
783785
by [`levels`](@ref DataAPI.levels), which may affect display of results in some operations; if `A` is
@@ -791,7 +793,7 @@ Else, `newlevels` must include all levels which appear in the data.
791793
"""
792794
function levels!(A::CategoricalArray{T, N, R}, newlevels::AbstractVector;
793795
allowmissing::Bool=false) where {T, N, R}
794-
(levels(A) == newlevels) && return A # nothing to do
796+
(_levels(A) == newlevels) && return A # nothing to do
795797

796798
# map each new level to its ref code
797799
newlv2ref = Dict{eltype(newlevels), Int}()
@@ -806,7 +808,7 @@ function levels!(A::CategoricalArray{T, N, R}, newlevels::AbstractVector;
806808
end
807809

808810
# map each old ref code to new ref code (or 0 if no such level)
809-
oldlevels = levels(pool(A))
811+
oldlevels = _levels(pool(A))
810812
oldref2newref = fill(0, length(oldlevels) + 1)
811813
for (i, lv) in enumerate(oldlevels)
812814
oldref2newref[i + 1] = get(newlv2ref, lv, 0)
@@ -867,7 +869,7 @@ end
867869
function _uniquerefs(A::CatArrOrSub{T}) where T
868870
arefs = refs(A)
869871
res = similar(arefs, 0)
870-
nlevels = length(levels(A))
872+
nlevels = length(_levels(A))
871873
maxunique = nlevels + (T >: Missing ? 1 : 0)
872874
seen = fill(false, nlevels + 1) # always +1 for 0 (missing ref)
873875
@inbounds for ref in arefs
@@ -900,7 +902,7 @@ returned by [`levels`](@ref DataAPI.levels)).
900902
"""
901903
function droplevels!(A::CategoricalArray)
902904
arefs = refs(A)
903-
nlevels = length(levels(A)) + 1 # +1 for missing
905+
nlevels = length(_levels(A)) + 1 # +1 for missing
904906
seen = fill(false, nlevels)
905907
seen[1] = true # assume that missing is always observed to simplify checks
906908
nseen = 1
@@ -913,7 +915,7 @@ function droplevels!(A::CategoricalArray)
913915
end
914916

915917
# replace the pool
916-
A.pool = typeof(pool(A))(@inbounds(levels(A)[view(seen, 2:nlevels)]), isordered(A))
918+
A.pool = typeof(pool(A))(@inbounds(_levels(A)[view(seen, 2:nlevels)]), isordered(A))
917919
# recode refs to keep only the seen ones (optimized version of update_refs!())
918920
seen[1] = false # to start levelsmap from 0
919921
levelsmap = cumsum(seen)
@@ -1030,7 +1032,7 @@ end
10301032
ordered=_isordered(A),
10311033
compress::Bool=false) where {T, N, R}
10321034
# @inline is needed so that return type is inferred when compress is not provided
1033-
RefType = compress ? reftype(length(CategoricalArrays.levels(A))) : R
1035+
RefType = compress ? reftype(length(_levels(A))) : R
10341036
CategoricalArray{T, N, RefType}(A, levels=levels, ordered=ordered)
10351037
end
10361038

@@ -1043,7 +1045,7 @@ function in(x::CategoricalValue, y::CategoricalArray{T, N, R}) where {T, N, R}
10431045
if x.pool === y.pool
10441046
return refcode(x) in y.refs
10451047
else
1046-
ref = get(y.pool, levels(x.pool)[refcode(x)], zero(R))
1048+
ref = get(y.pool, _levels(x.pool)[refcode(x)], zero(R))
10471049
return ref != 0 ? ref in y.refs : false
10481050
end
10491051
end

src/pool.jl

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ Base.convert(::Type{CategoricalPool{S}}, pool::CategoricalPool{T, R}) where {S,
2121
convert(CategoricalPool{S, R}, pool)
2222

2323
function Base.convert(::Type{CategoricalPool{T, R}}, pool::CategoricalPool) where {T, R <: Integer}
24-
if length(levels(pool)) > typemax(R)
25-
throw(LevelsException{T, R}(levels(pool)[typemax(R)+1:end]))
24+
if length(pool.levels) > typemax(R)
25+
throw(LevelsException{T, R}(pool.levels[typemax(R)+1:end]))
2626
end
2727

2828
levelsT = convert(Vector{T}, pool.levels)
@@ -37,10 +37,10 @@ Base.copy(pool::CategoricalPool{T, R}) where {T, R} =
3737
function Base.show(io::IO, pool::CategoricalPool{T, R}) where {T, R}
3838
@static if VERSION >= v"1.6.0"
3939
@printf(io, "%s{%s, %s}([%s])", CategoricalPool, T, R,
40-
join(map(repr, levels(pool)), ", "))
40+
join(map(repr, pool.levels), ", "))
4141
else
4242
@printf(io, "%s{%s,%s}([%s])", CategoricalPool, T, R,
43-
join(map(repr, levels(pool)), ", "))
43+
join(map(repr, pool.levels), ", "))
4444
end
4545

4646
pool.ordered && print(io, " with ordered levels")
@@ -65,6 +65,7 @@ it doesn't do this itself to avoid doing a dict lookup twice
6565

6666
i = R(n + 1)
6767
push!(pool.levels, x)
68+
push!(pool.levelsinds, i)
6869
pool_hash = pool.hash
6970
if pool_hash !== nothing
7071
pool.hash = hash(x, pool_hash)
@@ -185,10 +186,10 @@ function merge_pools(a::CategoricalPool{T}, b::CategoricalPool) where {T}
185186
newlevs = T[]
186187
ordered = isordered(a)
187188
elseif length(a) == 0
188-
newlevs = Vector{T}(levels(b))
189+
newlevs = Vector{T}(b.levels)
189190
ordered = isordered(b)
190191
elseif length(b) == 0
191-
newlevs = copy(levels(a))
192+
newlevs = copy(a.levels)
192193
ordered = isordered(a)
193194
else
194195
ordered = isordered(a) && (isordered(b) || b a)
@@ -200,7 +201,7 @@ end
200201

201202
@inline function Base.hash(pool::CategoricalPool, h::UInt)
202203
if pool.hash === nothing
203-
pool.hash = hashlevels(levels(pool))
204+
pool.hash = hashlevels(pool.levels)
204205
end
205206
hash(pool.hash, h)
206207
end
@@ -246,9 +247,9 @@ end
246247

247248
# Contrary to the CategoricalArray one, this method only allows adding new levels at the end
248249
# so that existing CategoricalValue objects still point to the same value
249-
function levels!(pool::CategoricalPool{S, R}, newlevels::Vector;
250+
function levels!(pool::CategoricalPool{S, R}, newlevels::AbstractVector;
250251
checkunique::Bool=true) where {S, R}
251-
levs = convert(Vector{S}, newlevels)
252+
levs = newlevels isa CategoricalVector{S} ? newlevels : convert(Vector{S}, newlevels)
252253
if checkunique && !allunique(levs)
253254
throw(ArgumentError(string("duplicated levels found in levs: ",
254255
join(unique(filter(x->sum(levs.==x)>1, levs)), ", "))))
@@ -259,24 +260,30 @@ function levels!(pool::CategoricalPool{S, R}, newlevels::Vector;
259260
n = length(levs)
260261

261262
if n > typemax(R)
262-
throw(LevelsException{S, R}(setdiff(levs, levels(pool))[typemax(R)-length(levels(pool))+1:end]))
263+
throw(LevelsException{S, R}(setdiff(levs, pool.levels)[typemax(R)-length(pool.levels)+1:end]))
263264
end
264265

265266
empty!(pool.invindex)
266267
resize!(pool.levels, n)
268+
resize!(pool.levelsinds, n)
267269
pool.hash = nothing
268270
pool.equalto = C_NULL
269271
pool.subsetof = C_NULL
270272
for i in 1:n
271273
v = levs[i]
272274
pool.levels[i] = v
275+
pool.levelsinds[i] = i
273276
pool.invindex[v] = i
274277
end
275278

276279
return pool
277280
end
278281

279-
DataAPI.levels(pool::CategoricalPool) = pool.levels
282+
DataAPI.levels(pool::CategoricalPool{T}) where {T} =
283+
CategoricalVector{T}(pool.levelsinds, pool)
284+
levels_missing(pool::CategoricalPool{T}) where {T} =
285+
CategoricalVector{Union{T, Missing}}(pool.levelsinds, pool)
286+
_levels(pool::CategoricalPool) = pool.levels
280287

281288
isordered(pool::CategoricalPool) = pool.ordered
282289
ordered!(pool::CategoricalPool, ordered) = (pool.ordered = ordered; pool)

0 commit comments

Comments
 (0)