Skip to content

Commit 11d43c1

Browse files
authored
Support weighted quantiles in cut (#423)
This requires adding an extension point for StatsBase. Unfortunately, more copies of the data and weights are made than necessary, as StatsBase supports neither an in-place weighted `quantile!` on pre-sorted data nor taking a view of weights vectors (JuliaStats/StatsBase.jl#723).
1 parent e4a13b1 commit 11d43c1

File tree

5 files changed

+79
-9
lines changed

5 files changed

+79
-9
lines changed

Project.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,15 @@ Unicode = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
1616
Arrow = "69666777-d1a9-59fb-9406-91d4454c9d45"
1717
JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
1818
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
19+
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1920
SentinelArrays = "91c51154-3ec4-41a3-a24f-3f23e20d615c"
2021
StructTypes = "856f2bd8-1eba-4b0a-8007-ebc267875bd4"
2122

2223
[extensions]
2324
CategoricalArraysArrowExt = "Arrow"
2425
CategoricalArraysJSONExt = "JSON"
2526
CategoricalArraysRecipesBaseExt = "RecipesBase"
27+
CategoricalArraysStatsBaseExt = "StatsBase"
2628
CategoricalArraysSentinelArraysExt = "SentinelArrays"
2729
CategoricalArraysStructTypesExt = "StructTypes"
2830

@@ -37,6 +39,7 @@ RecipesBase = "1.1"
3739
Requires = "1"
3840
SentinelArrays = "1"
3941
Statistics = "1"
42+
StatsBase = "0.15, 0.16, 0.17, 0.18, 0.19, 0.20, 0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 0.28, 0.29, 0.30, 0.31, 0.32, 0.33, 0.34"
4043
StructTypes = "1"
4144
julia = "1.6"
4245

@@ -49,8 +52,9 @@ PooledArrays = "2dfb63ee-cc39-5dd5-95bd-886bf059d720"
4952
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
5053
RecipesPipeline = "01d81517-befc-4cb6-b9ec-a95719d0359c"
5154
SentinelArrays = "91c51154-3ec4-41a3-a24f-3f23e20d615c"
55+
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
5256
StructTypes = "856f2bd8-1eba-4b0a-8007-ebc267875bd4"
5357
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
5458

5559
[targets]
56-
test = ["Arrow", "Dates", "JSON", "JSON3", "PooledArrays", "RecipesBase", "RecipesPipeline", "SentinelArrays", "StructTypes", "Test"]
60+
test = ["Arrow", "Dates", "JSON", "JSON3", "PooledArrays", "RecipesBase", "RecipesPipeline", "SentinelArrays", "StatsBase", "StructTypes", "Test"]

ext/CategoricalArraysStatsBaseExt.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Glue code that gives `CategoricalArrays._wquantile` a working method for
# StatsBase weight vectors.
module CategoricalArraysStatsBaseExt

# When `Base.get_extension` is not defined, the parent packages are reached
# through relative module paths; otherwise they are loaded by name.
if !isdefined(Base, :get_extension)
    import ..CategoricalArrays: _wquantile
    using ..StatsBase
else
    import CategoricalArrays: _wquantile
    using StatsBase
end

# Forward to the weighted `quantile` method for `AbstractWeights` vectors.
function _wquantile(x::AbstractArray, w::AbstractWeights, p::AbstractVector)
    return quantile(x, w, p)
end

end

src/CategoricalArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ module CategoricalArrays
4545
@require JSON="682c06a0-de6a-54ab-a142-c8b1cf79cde6" include("../ext/CategoricalArraysJSONExt.jl")
4646
@require RecipesBase="3cdcf5f2-1ef4-517c-9805-6587b60abb01" include("../ext/CategoricalArraysRecipesBaseExt.jl")
4747
@require SentinelArrays="91c51154-3ec4-41a3-a24f-3f23e20d615c" include("../ext/CategoricalArraysSentinelArraysExt.jl")
48+
@require StatsBase="2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" include("../ext/CategoricalArraysStatsBaseExt.jl")
4849
@require StructTypes="856f2bd8-1eba-4b0a-8007-ebc267875bd4" include("../ext/CategoricalArraysStructTypesExt.jl")
4950
end
5051
end

src/extras.jl

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -337,11 +337,17 @@ function find_breaks(v::AbstractVector, qs::AbstractVector)
337337
return breaks
338338
end
339339

340+
# AbstractWeights method is defined in StatsBase extension
341+
# There is no in-place weighted quantile method in StatsBase
342+
# Generic fallback: any weights vector reaching this method is rejected with an
# `ArgumentError`, whatever `x` and `p` are.
function _wquantile(x::AbstractArray, w::AbstractVector, p::AbstractVector)
    throw(ArgumentError("`weights` must be an `AbstractWeights` vector from StatsBase.jl"))
end
344+
340345
"""
341346
cut(x::AbstractArray, ngroups::Integer;
342347
labels::Union{AbstractVector{<:AbstractString},Function},
343348
sigdigits::Integer=3,
344-
allowempty::Bool=false)
349+
allowempty::Bool=false,
350+
weights::Union{AbstractWeights, Nothing}=nothing)
345351
346352
Cut a numeric array into `ngroups` quantiles.
347353
@@ -373,19 +379,41 @@ quantiles.
373379
other than the last one are equal, generating empty intervals;
374380
when `true`, duplicate breaks are allowed and the intervals they generate are kept as
375381
unused levels (but duplicate labels are not allowed).
382+
* `weights::Union{AbstractWeights, Nothing}=nothing`: observation weights to use when
383+
computing quantiles (see `quantile` documentation in StatsBase).
376384
"""
377385
function cut(x::AbstractArray, ngroups::Integer;
378386
labels::Union{AbstractVector{<:SupportedTypes},Function,Nothing}=nothing,
379387
sigdigits::Integer=3,
380-
allowempty::Bool=false)
388+
allowempty::Bool=false,
389+
weights::Union{AbstractVector, Nothing}=nothing)
381390
ngroups >= 1 || throw(ArgumentError("ngroups must be strictly positive (got $ngroups)"))
382-
sorted_x = eltype(x) >: Missing ? sort!(collect(skipmissing(x))) : sort(x)
383-
min_x, max_x = first(sorted_x), last(sorted_x)
384-
if (min_x isa Number && isnan(min_x)) ||
385-
(max_x isa Number && isnan(max_x))
386-
throw(ArgumentError("NaN values are not allowed in input vector"))
391+
if weights === nothing
392+
sorted_x = eltype(x) >: Missing ? sort!(collect(skipmissing(x))) : sort(x)
393+
min_x, max_x = first(sorted_x), last(sorted_x)
394+
if (min_x isa Number && isnan(min_x)) ||
395+
(max_x isa Number && isnan(max_x))
396+
throw(ArgumentError("NaN values are not allowed in input vector"))
397+
end
398+
qs = quantile!(sorted_x, (1:(ngroups-1))/ngroups, sorted=true)
399+
else
400+
if eltype(x) >: Missing
401+
nm_inds = findall(!ismissing, x)
402+
nm_x = view(x, nm_inds)
403+
# TODO: use a view once this is supported (JuliaStats/StatsBase.jl#723)
404+
nm_weights = weights[nm_inds]
405+
else
406+
nm_x = x
407+
nm_weights = weights
408+
end
409+
sorted_x = sort(nm_x)
410+
min_x, max_x = first(sorted_x), last(sorted_x)
411+
if (min_x isa Number && isnan(min_x)) ||
412+
(max_x isa Number && isnan(max_x))
413+
throw(ArgumentError("NaN values are not allowed in input vector"))
414+
end
415+
qs = _wquantile(nm_x, nm_weights, (1:(ngroups-1))/ngroups)
387416
end
388-
qs = quantile!(sorted_x, (1:(ngroups-1))/ngroups, sorted=true)
389417
breaks = [min_x; find_breaks(sorted_x, qs); max_x]
390418
if !allowempty && !allunique(@view breaks[1:end-1])
391419
throw(ArgumentError("cannot compute $ngroups quantiles due to " *

test/15_extras.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
module TestExtras
22
using Test
33
using CategoricalArrays
4+
using StatsBase
5+
using Missings
46

57
const ≅ = isequal
68

@@ -423,4 +425,26 @@ end
423425

424426
end
425427

428+
# Checks `cut` with the `weights` keyword: rejection of plain weight vectors,
# agreement with breaks computed directly from weighted quantiles, handling of
# missing values, and rejection of NaN input.
@testset "cut with weighted quantiles" begin
    # A plain range is not an `AbstractWeights` vector and must be rejected
    @test_throws ArgumentError cut(1:3, 3, weights=1:3)

    x = collect(Float64, 1:100)
    w = fweights(repeat(1:10, inner=10))
    y = cut(x, 10, weights=w)
    @test levelcode.(y) == levelcode.(cut(x, quantile(x, w, (0:10)./10)))
    @test levels(y) == ["[1, 29)", "[29, 43)", "[43, 53)", "[53, 62)", "[62, 70)",
                        "[70, 77)", "[77, 83)", "[83, 89)", "[89, 95)", "[95, 100]"]

    # Missing observations must be dropped together with their weights
    mx = allowmissing(x)
    mx[2] = mx[10] = missing
    nm_inds = .!ismissing.(mx)
    y = cut(mx, 10, weights=w)
    # `isequal` rather than `==`: the level codes contain `missing`, so `==`
    # would not yield a `Bool` for `@test`
    @test isequal(levelcode.(y),
                  levelcode.(cut(mx, quantile(x[nm_inds], w[nm_inds], (0:10)./10))))
    @test levels(y) == ["[1, 30)", "[30, 43)", "[43, 53)", "[53, 62)", "[62, 70)",
                        "[70, 77)", "[77, 83)", "[83, 89)", "[89, 95)", "[95, 100]"]

    # NaN values are not allowed in the weighted code path either
    x[5] = NaN
    @test_throws ArgumentError cut(x, 3, weights=w)
end
449+
426450
end

0 commit comments

Comments
 (0)