Skip to content

Simplify default cut labels #422

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
May 21, 2025
Merged
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ CategoricalArraysStructTypesExt = "StructTypes"

[compat]
Arrow = "2"
Compat = "3.37, 4"
Compat = "3.47, 4.10"
DataAPI = "1.6"
JSON = "0.15, 0.16, 0.17, 0.18, 0.19, 0.20, 0.21"
JSON3 = "1.1.2"
Expand Down
4 changes: 3 additions & 1 deletion src/CategoricalArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@ module CategoricalArrays
import DataAPI: unwrap
export unwrap

using Compat
@compat public default_formatter, numbered_formatter

using DataAPI
using Missings
using Printf
import Compat

# JuliaLang/julia#36810
if VERSION < v"1.5.2"
Expand Down
213 changes: 166 additions & 47 deletions src/extras.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,37 +27,86 @@ function fill_refs!(refs::AbstractArray, X::AbstractArray,
end
end

const CUT_FMT = Printf.Format("%.*g")

"""
default_formatter(from, to, i; leftclosed, rightclosed)
CategoricalArrays.default_formatter(from, to, i::Integer;
leftclosed::Bool, rightclosed::Bool,
sigdigits::Integer)

Provide the default label format for the `cut(x, breaks)` method,
which is `"[from, to)"` if `leftclosed` is `true` and `"(from, to)"` otherwise.

Provide the default label format for the `cut(x, breaks)` method.
If they are floating point values, breaks are turned into strings using
`@sprintf("%.*g", sigdigits, break)`
(or `to` using `@sprintf("%.*g", sigdigits, break)` for the last break).
"""
default_formatter(from, to, i; leftclosed, rightclosed) =
string(leftclosed ? "[" : "(", from, ", ", to, rightclosed ? "]" : ")")
function default_formatter(from, to, i::Integer;
                           leftclosed::Bool, rightclosed::Bool,
                           sigdigits::Integer)
    # Floating point bounds are rendered through the shared `%.*g` format with
    # `sigdigits` as the precision; any other bound goes through `string`.
    render(bound) = bound isa AbstractFloat ?
        Printf.format(CUT_FMT, sigdigits, bound) :
        string(bound)
    # Square brackets mark a closed end, parentheses an open one.
    opening = leftclosed ? "[" : "("
    closing = rightclosed ? "]" : ")"
    return string(opening, render(from), ", ", render(to), closing)
end

"""
CategoricalArrays.numbered_formatter(from, to, i::Integer;
leftclosed::Bool, rightclosed::Bool,
sigdigits::Integer)

Provide the default label format for the `cut(x, ngroups)` method
when `allowempty=true`, which is `"i: [from, to)"` if `leftclosed`
is `true` and `"i: (from, to)"` otherwise.

If they are floating point values, breaks are turned into strings using
`@sprintf("%.*g", sigdigits, break)`
(or `to` using `@sprintf("%.*g", sigdigits, break)` for the last break).
"""
function numbered_formatter(from, to, i::Integer;
                            leftclosed::Bool, rightclosed::Bool,
                            sigdigits::Integer)
    # Build the plain interval label first, then prefix it with the group index.
    interval = default_formatter(from, to, i;
                                 leftclosed=leftclosed, rightclosed=rightclosed,
                                 sigdigits=sigdigits)
    return string(i, ": ", interval)
end

@doc raw"""
cut(x::AbstractArray, breaks::AbstractVector;
labels::Union{AbstractVector,Function},
sigdigits::Integer=3,
extend::Union{Bool,Missing}=false, allowempty::Bool=false)

Cut a numeric array into intervals at values `breaks`
and return an ordered `CategoricalArray` indicating
the interval into which each entry falls. Intervals are of the form `[lower, upper)`,
i.e. the lower bound is included and the upper bound is excluded, except
the interval into which each entry falls. Intervals are of the form `[lower, upper)`
(closed on the left), i.e. the lower bound is included and the upper bound is excluded, except
the last interval, which is closed on both ends, i.e. `[lower, upper]`.

If `x` accepts missing values (i.e. `eltype(x) >: Missing`) the returned array will
also accept them.

!!! note
For floating point data, breaks may be rounded to `sigdigits` significant digits
when generating interval labels, meaning that they may not reflect exactly the cutpoints
used.

# Keyword arguments
* `extend::Union{Bool, Missing}=false`: when `false`, an error is raised if some values
in `x` fall outside of the breaks; when `true`, breaks are automatically added to include
all values in `x`; when `missing`, values outside of the breaks generate `missing` entries.
* `labels::Union{AbstractVector, Function}`: a vector of strings, characters
or numbers giving the names to use for
the intervals; or a function `f(from, to, i; leftclosed, rightclosed)` that generates
or numbers giving the names to use for the intervals; or a function
`f(from, to, i::Integer; leftclosed::Bool, rightclosed::Bool, sigdigits::Integer)` that generates
the labels from the left and right interval boundaries and the group index. Defaults to
`"[from, to)"` (or `"[from, to]"` for the rightmost interval if `extend == true`).
[`CategoricalArrays.default_formatter`](@ref), giving `"[from, to)"` (or `"[from, to]"`
for the rightmost interval if `extend == true`).
* `sigdigits::Integer=3`: the minimum number of significant digits to use in labels.
This value is increased automatically if necessary so that rounded breaks are unique.
Only used for floating point types and when `labels` is a function, in which case it
is passed to it as a keyword argument.
* `allowempty::Bool=false`: when `false`, an error is raised if some breaks other than
the last one appear multiple times, generating empty intervals; when `true`,
duplicate breaks are allowed and the intervals they generate are kept as
Expand All @@ -69,27 +118,27 @@ julia> using CategoricalArrays

julia> cut(-1:0.5:1, [0, 1], extend=true)
5-element CategoricalArray{String,1,UInt32}:
"[-1.0, 0.0)"
"[-1.0, 0.0)"
"[0.0, 1.0]"
"[0.0, 1.0]"
"[0.0, 1.0]"
"[-1, 0)"
"[-1, 0)"
"[0, 1]"
"[0, 1]"
"[0, 1]"

julia> cut(-1:0.5:1, 2)
5-element CategoricalArray{String,1,UInt32}:
"Q1: [-1.0, 0.0)"
"Q1: [-1.0, 0.0)"
"Q2: [0.0, 1.0]"
"Q2: [0.0, 1.0]"
"Q2: [0.0, 1.0]"
"[-1, 0)"
"[-1, 0)"
"[0, 1]"
"[0, 1]"
"[0, 1]"

julia> cut(-1:0.5:1, 2, labels=["A", "B"])
5-element CategoricalArray{String,1,UInt32}:
"A"
"A"
"B"
"B"
"B"
"B"

julia> cut(-1:0.5:1, 2, labels=[-0.5, +0.5])
5-element CategoricalArray{Float64,1,UInt32}:
Expand All @@ -104,16 +153,17 @@ fmt (generic function with 1 method)

julia> cut(-1:0.5:1, 3, labels=fmt)
5-element CategoricalArray{String,1,UInt32}:
"grp 1 (-1.0//-0.3333333333333335)"
"grp 1 (-1.0//-0.3333333333333335)"
"grp 2 (-0.3333333333333335//0.33333333333333326)"
"grp 3 (0.33333333333333326//1.0)"
"grp 3 (0.33333333333333326//1.0)"
"grp 1 (-1.0//0.0)"
"grp 1 (-1.0//0.0)"
"grp 2 (0.0//0.5)"
"grp 3 (0.5//1.0)"
"grp 3 (0.5//1.0)"
```
"""
@inline function cut(x::AbstractArray, breaks::AbstractVector;
extend::Union{Bool, Missing}=false,
labels::Union{AbstractVector{<:SupportedTypes},Function}=default_formatter,
sigdigits::Integer=3,
allowmissing::Union{Bool, Nothing}=nothing,
allow_missing::Union{Bool, Nothing}=nothing,
allowempty::Bool=false)
Expand All @@ -127,14 +177,15 @@ julia> cut(-1:0.5:1, 3, labels=fmt)
:cut)
extend = missing
end
return _cut(x, breaks, extend, labels, allowempty)
return _cut(x, breaks, extend, labels, sigdigits, allowempty)
end

# Separate function for inferability (thanks to inlining of cut)
function _cut(x::AbstractArray{T, N}, breaks::AbstractVector,
extend::Union{Bool, Missing},
labels::Union{AbstractVector{<:SupportedTypes},Function},
allowempty::Bool=false) where {T, N}
sigdigits::Integer,
allowempty::Bool) where {T, N}
if !issorted(breaks)
breaks = sort(breaks)
end
Expand Down Expand Up @@ -191,21 +242,55 @@ function _cut(x::AbstractArray{T, N}, breaks::AbstractVector,
end
end

# Find minimal number of digits so that distinct breaks remain so
if eltype(breaks) <: AbstractFloat
while true
local i
for outer i in 2:lastindex(breaks)
b1 = breaks[i-1]
b2 = breaks[i]
isequal(b1, b2) && continue

b1_str = Printf.format(CUT_FMT, sigdigits, b1)
b2_str = Printf.format(CUT_FMT, sigdigits, b2)
if b1_str == b2_str
sigdigits += 1
break
end
end
i == lastindex(breaks) && break
end
end
n = length(breaks)
n >= 2 || throw(ArgumentError("at least two breaks must be provided when extend is not true"))
if labels isa Function
from = breaks[1:n-1]
to = breaks[2:n]
firstlevel = labels(from[1], to[1], 1,
leftclosed=!isequal(breaks[1], breaks[2]), rightclosed=false)
local firstlevel
try
firstlevel = labels(from[1], to[1], 1,
leftclosed=!isequal(breaks[1], breaks[2]), rightclosed=false,
sigdigits=sigdigits)
catch
# Support functions defined before v1.0, where sigdigits did not exist
Base.depwarn("`labels` function is now required to accept a `sigdigits` keyword argument",
:cut)
labels_orig = labels
labels = (from, to, i; leftclosed, rightclosed, sigdigits) ->
labels_orig(from, to, i; leftclosed, rightclosed)
firstlevel = labels_orig(from[1], to[1], 1,
leftclosed=!isequal(breaks[1], breaks[2]), rightclosed=false)
end
levs = Vector{typeof(firstlevel)}(undef, n-1)
levs[1] = firstlevel
for i in 2:n-2
levs[i] = labels(from[i], to[i], i,
leftclosed=!isequal(breaks[i], breaks[i+1]), rightclosed=false)
leftclosed=!isequal(breaks[i], breaks[i+1]), rightclosed=false,
sigdigits=sigdigits)
end
levs[end] = labels(from[end], to[end], n-1,
leftclosed=true, rightclosed=true)
leftclosed=true, rightclosed=true,
sigdigits=sigdigits)
else
length(labels) == n-1 ||
throw(ArgumentError("labels must be of length $(n-1), but got length $(length(labels))"))
Expand All @@ -226,52 +311,86 @@ function _cut(x::AbstractArray{T, N}, breaks::AbstractVector,
end

"""
quantile_formatter(from, to, i; leftclosed, rightclosed)

Provide the default label format for the `cut(x, ngroups)` method.
Find first value in (sorted) `v` which is greater than or equal to each quantile
in (sorted) `qs`.
"""
quantile_formatter(from, to, i; leftclosed, rightclosed) =
string("Q", i, ": ", leftclosed ? "[" : "(", from, ", ", to, rightclosed ? "]" : ")")
function find_breaks(v::AbstractVector, qs::AbstractVector)
    # Both `v` and `qs` are expected to be sorted in increasing order.
    # The result has one entry per quantile, allocated with `similar(v, n)`.
    n = length(qs)
    breaks = similar(v, n)
    n == 0 && return breaks

    i = 1
    @inbounds for x in v
        # A single data value may be the first one greater than or equal to
        # several consecutive quantiles (e.g. when there are more groups than
        # distinct values), so keep consuming quantiles for the same `x`
        # instead of advancing by at most one quantile per element.
        # Use isless and isequal to differentiate -0.0 from 0.0
        while isless(qs[i], x) || isequal(qs[i], x)
            breaks[i] = x
            i += 1
            i > n && return breaks
        end
    end
    # Only reached if some quantile is greater than every element of `v`;
    # the corresponding trailing entries of `breaks` are left unset.
    return breaks
end

"""
cut(x::AbstractArray, ngroups::Integer;
labels::Union{AbstractVector{<:AbstractString},Function},
sigdigits::Integer=3,
allowempty::Bool=false)

Cut a numeric array into `ngroups` quantiles, determined using `quantile`.
Cut a numeric array into `ngroups` quantiles.

This is equivalent to `cut(x, quantile(x, (0:ngroups)/ngroups))`,
but breaks are taken from actual data values instead of estimated quantiles.

If `x` contains `missing` values, they are automatically skipped when computing
quantiles.

!!! note
For floating point data, breaks may be rounded to `sigdigits` significant digits
when generating interval labels, meaning that they may not reflect exactly the cutpoints
used.

# Keyword arguments
* `labels::Union{AbstractVector, Function}`: a vector of strings, characters
or numbers giving the names to use for
the intervals; or a function `f(from, to, i; leftclosed, rightclosed)` that generates
or numbers giving the names to use for the intervals; or a function
`f(from, to, i::Integer; leftclosed::Bool, rightclosed::Bool, sigdigits::Integer)` that generates
the labels from the left and right interval boundaries and the group index. Defaults to
`"Qi: [from, to)"` (or `"Qi: [from, to]"` for the rightmost interval).
[`CategoricalArrays.default_formatter`](@ref), giving `"[from, to)"` (or `"[from, to]"`
for the rightmost interval) if `allowempty=false`, otherwise to
[`CategoricalArrays.numbered_formatter`](@ref), which prefixes the label with the quantile
number to ensure uniqueness.
* `sigdigits::Integer=3`: the minimum number of significant digits to use when rounding
breaks for inclusion in generated labels. This value is increased automatically if necessary
so that rounded breaks are unique. Only used for floating point types and when `labels` is a
function, in which case it is passed to it as a keyword argument.
* `allowempty::Bool=false`: when `false`, an error is raised if some quantile breakpoints
other than the last one are equal, generating empty intervals;
when `true`, duplicate breaks are allowed and the intervals they generate are kept as
unused levels (but duplicate labels are not allowed).
"""
function cut(x::AbstractArray, ngroups::Integer;
labels::Union{AbstractVector{<:SupportedTypes},Function}=quantile_formatter,
labels::Union{AbstractVector{<:SupportedTypes},Function,Nothing}=nothing,
sigdigits::Integer=3,
allowempty::Bool=false)
ngroups >= 1 || throw(ArgumentError("ngroups must be strictly positive (got $ngroups)"))
xnm = eltype(x) >: Missing ? skipmissing(x) : x
# Computing extrema is faster than taking 0 and 1 quantiles
min_x, max_x = extrema(xnm)
sorted_x = eltype(x) >: Missing ? sort!(collect(skipmissing(x))) : sort(x)
min_x, max_x = first(sorted_x), last(sorted_x)
if (min_x isa Number && isnan(min_x)) ||
(max_x isa Number && isnan(max_x))
throw(ArgumentError("NaN values are not allowed in input vector"))
end
breaks = quantile(xnm, (1:ngroups-1)/ngroups)
breaks = [min_x; breaks; max_x]
qs = quantile!(sorted_x, (1:(ngroups-1))/ngroups, sorted=true)
breaks = [min_x; find_breaks(sorted_x, qs); max_x]
if !allowempty && !allunique(@view breaks[1:end-1])
throw(ArgumentError("cannot compute $ngroups quantiles due to " *
"too many duplicated values in `x`. " *
"Pass `allowempty=true` to allow empty quantiles or " *
"choose a lower value for `ngroups`."))
end
cut(x, breaks; labels=labels, allowempty=allowempty)
if labels === nothing
labels = allowempty ? numbered_formatter : default_formatter
end
return cut(x, breaks; labels=labels, sigdigits=sigdigits, allowempty=allowempty)
end
Loading
Loading