
Commit bb261bf

EmilyXinyi, ogrisel, and lucyleeow authored

Add array API support for _weighted_percentile (scikit-learn#29431)
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Co-authored-by: Lucy Liu <jliu176@gmail.com>
1 parent a6efcaf commit bb261bf

File tree: 2 files changed (+170 −58 lines)

  sklearn/utils/stats.py
  sklearn/utils/tests/test_stats.py
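In practice, this change lets `_weighted_percentile` run on any array API compatible namespace (e.g. `array_api_strict` or PyTorch) when dispatch is enabled. Below is a minimal sketch of the behavior the commit enables, assuming `array_api_strict` is installed; note that `_weighted_percentile` is a private helper, so this is an illustration rather than public API:

    import array_api_strict as xp

    from sklearn import config_context
    from sklearn.utils.stats import _weighted_percentile

    x = xp.asarray([1.0, 2.0, 3.0, 4.0])
    w = xp.asarray([1.0, 1.0, 2.0, 1.0])

    # With dispatch enabled, the computation stays in the input's namespace and
    # device; the result is an array_api_strict array, not a NumPy one.
    with config_context(array_api_dispatch=True):
        median = _weighted_percentile(x, w, percentile_rank=50)  # 3.0 here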

sklearn/utils/stats.py

Lines changed: 57 additions & 42 deletions
@@ -1,12 +1,13 @@
 # Authors: The scikit-learn developers
 # SPDX-License-Identifier: BSD-3-Clause

-import numpy as np
+from ..utils._array_api import (
+    _find_matching_floating_dtype,
+    get_namespace_and_device,
+)

-from .extmath import stable_cumsum

-
-def _weighted_percentile(array, sample_weight, percentile_rank=50):
+def _weighted_percentile(array, sample_weight, percentile_rank=50, xp=None):
     """Compute the weighted percentile with method 'inverted_cdf'.

     When the percentile lies between two data points of `array`, the function returns
@@ -37,72 +38,86 @@ def _weighted_percentile(array, sample_weight, percentile_rank=50):
         The probability level of the percentile to compute, in percent. Must be between
         0 and 100.

+    xp : array_namespace, default=None
+        The standard-compatible namespace for `array`. Default: infer.
+
     Returns
     -------
-    percentile : int if `array` 1D, ndarray if `array` 2D
+    percentile : scalar or 0D array if `array` 1D (or 0D), array if `array` 2D
         Weighted percentile at the requested probability level.
     """
+    xp, _, device = get_namespace_and_device(array)
+    # `sample_weight` should follow `array` for dtypes
+    floating_dtype = _find_matching_floating_dtype(array, xp=xp)
+    array = xp.asarray(array, dtype=floating_dtype, device=device)
+    sample_weight = xp.asarray(sample_weight, dtype=floating_dtype, device=device)
+
     n_dim = array.ndim
     if n_dim == 0:
-        return array[()]
+        return array
     if array.ndim == 1:
-        array = array.reshape((-1, 1))
+        array = xp.reshape(array, (-1, 1))
     # When sample_weight 1D, repeat for each array.shape[1]
     if array.shape != sample_weight.shape and array.shape[0] == sample_weight.shape[0]:
-        sample_weight = np.tile(sample_weight, (array.shape[1], 1)).T
-
+        sample_weight = xp.tile(sample_weight, (array.shape[1], 1)).T
     # Sort `array` and `sample_weight` along axis=0:
-    sorted_idx = np.argsort(array, axis=0)
-    sorted_weights = np.take_along_axis(sample_weight, sorted_idx, axis=0)
+    sorted_idx = xp.argsort(array, axis=0)
+    sorted_weights = xp.take_along_axis(sample_weight, sorted_idx, axis=0)

-    # Set NaN values in `sample_weight` to 0. We only perform this operation if NaN
-    # values are present at all to avoid temporary allocations of size `(n_samples,
-    # n_features)`. If NaN values were present, they would sort to the end (which we can
-    # observe from `sorted_idx`).
+    # Set NaN values in `sample_weight` to 0. Only perform this if NaN values are
+    # present, to avoid temporary allocations of size `(n_samples, n_features)`.
     n_features = array.shape[1]
-    largest_value_per_column = array[sorted_idx[-1, ...], np.arange(n_features)]
-    if np.isnan(largest_value_per_column).any():
-        sorted_nan_mask = np.take_along_axis(np.isnan(array), sorted_idx, axis=0)
+    largest_value_per_column = array[
+        sorted_idx[-1, ...], xp.arange(n_features, device=device)
+    ]
+    # NaN values get sorted to the end (as the largest values)
+    if xp.any(xp.isnan(largest_value_per_column)):
+        sorted_nan_mask = xp.take_along_axis(xp.isnan(array), sorted_idx, axis=0)
         sorted_weights[sorted_nan_mask] = 0

     # Compute the weighted cumulative distribution function (CDF) based on
-    # sample_weight and scale percentile_rank along it:
-    weight_cdf = stable_cumsum(sorted_weights, axis=0)
-    adjusted_percentile_rank = percentile_rank / 100 * weight_cdf[-1]
-
-    # For percentile_rank=0, ignore leading observations with sample_weight=0; see
-    # PR #20528:
+    # `sample_weight` and scale `percentile_rank` along it.
+    #
+    # Note: we call `xp.cumulative_sum` on the transposed `sorted_weights` to
+    # ensure that the result has shape `(n_features, n_samples)`, so that the
+    # `xp.searchsorted` calls take contiguous inputs (for performance reasons).
+    weight_cdf = xp.cumulative_sum(sorted_weights.T, axis=1)
+    adjusted_percentile_rank = percentile_rank / 100 * weight_cdf[..., -1]
+
+    # Ignore leading `sample_weight=0` observations when `percentile_rank=0` (#20528)
     mask = adjusted_percentile_rank == 0
-    adjusted_percentile_rank[mask] = np.nextafter(
+    adjusted_percentile_rank[mask] = xp.nextafter(
         adjusted_percentile_rank[mask], adjusted_percentile_rank[mask] + 1
     )
-
-    # Find index (i) of `adjusted_percentile` in `weight_cdf`,
-    # such that weight_cdf[i-1] < percentile <= weight_cdf[i]
-    percentile_idx = np.array(
+    # For each feature with index j, find the sample index i of the scalar value
+    # `adjusted_percentile_rank[j]` in the 1D array `weight_cdf[j]`, such that:
+    # weight_cdf[j, i-1] < adjusted_percentile_rank[j] <= weight_cdf[j, i].
+    percentile_indices = xp.asarray(
         [
-            np.searchsorted(weight_cdf[:, i], adjusted_percentile_rank[i])
-            for i in range(weight_cdf.shape[1])
-        ]
+            xp.searchsorted(
+                weight_cdf[feature_idx, ...], adjusted_percentile_rank[feature_idx]
+            )
+            for feature_idx in range(weight_cdf.shape[0])
+        ],
+        device=device,
     )
-
-    # In rare cases, percentile_idx equals to sorted_idx.shape[0]:
+    # In rare cases, `percentile_indices` can equal `sorted_idx.shape[0]`
    max_idx = sorted_idx.shape[0] - 1
-    percentile_idx = np.apply_along_axis(
-        lambda x: np.clip(x, 0, max_idx), axis=0, arr=percentile_idx
-    )
+    percentile_indices = xp.clip(percentile_indices, 0, max_idx)
+
+    col_indices = xp.arange(array.shape[1], device=device)
+    percentile_in_sorted = sorted_idx[percentile_indices, col_indices]

-    col_indices = np.arange(array.shape[1])
-    percentile_in_sorted = sorted_idx[percentile_idx, col_indices]
     result = array[percentile_in_sorted, col_indices]

     return result[0] if n_dim == 1 else result


 # TODO: refactor to do the symmetrisation inside _weighted_percentile to avoid
 # sorting the input array twice.
-def _averaged_weighted_percentile(array, sample_weight, percentile_rank=50):
+def _averaged_weighted_percentile(array, sample_weight, percentile_rank=50, xp=None):
     return (
-        _weighted_percentile(array, sample_weight, percentile_rank)
-        - _weighted_percentile(-array, sample_weight, 100 - percentile_rank)
+        _weighted_percentile(array, sample_weight, percentile_rank, xp=xp)
+        - _weighted_percentile(-array, sample_weight, 100 - percentile_rank, xp=xp)
     ) / 2
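For readers unfamiliar with the 'inverted_cdf' method, the core of the function above boils down to the following 1D NumPy sketch. `weighted_percentile_1d` is an illustrative name, not scikit-learn code, and it omits the `percentile_rank=0` and NaN edge cases handled in the diff:

    import numpy as np

    def weighted_percentile_1d(values, weights, percentile_rank=50):
        # Sort the values and carry their weights along.
        order = np.argsort(values)
        sorted_values = values[order]
        cdf = np.cumsum(weights[order])            # weighted CDF
        target = percentile_rank / 100 * cdf[-1]   # scale the rank onto the CDF
        # Smallest index i such that cdf[i-1] < target <= cdf[i].
        idx = np.searchsorted(cdf, target)
        return sorted_values[min(idx, len(values) - 1)]

    x = np.array([1.0, 2.0, 3.0, 10.0])
    w = np.array([1.0, 1.0, 1.0, 0.0])
    print(weighted_percentile_1d(x, w, 50))  # 2.0: the zero-weight outlier is inert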

sklearn/utils/tests/test_stats.py

Lines changed: 113 additions & 16 deletions
@@ -3,6 +3,14 @@
 from numpy.testing import assert_allclose, assert_array_equal
 from pytest import approx

+from sklearn._config import config_context
+from sklearn.utils._array_api import (
+    _convert_to_numpy,
+    get_namespace,
+    yield_namespace_device_dtype_combinations,
+)
+from sklearn.utils._array_api import device as array_device
+from sklearn.utils.estimator_checks import _array_api_for_tests
 from sklearn.utils.fixes import np_version, parse_version
 from sklearn.utils.stats import _averaged_weighted_percentile, _weighted_percentile

@@ -39,6 +47,7 @@ def test_averaged_and_weighted_percentile():


 def test_weighted_percentile():
+    """Check `weighted_percentile` on artificial data with an obvious median."""
     y = np.empty(102, dtype=np.float64)
     y[:50] = 0
     y[-51:] = 2
@@ -51,15 +60,16 @@ def test_weighted_percentile():


 def test_weighted_percentile_equal():
+    """Check `weighted_percentile` with all weights equal to 1."""
     y = np.empty(102, dtype=np.float64)
     y.fill(0.0)
     sw = np.ones(102, dtype=np.float64)
-    sw[-1] = 0.0
-    value = _weighted_percentile(y, sw, 50)
-    assert value == 0
+    score = _weighted_percentile(y, sw, 50)
+    assert approx(score) == 0


 def test_weighted_percentile_zero_weight():
+    """Check `weighted_percentile` with all weights equal to 0."""
     y = np.empty(102, dtype=np.float64)
     y.fill(1.0)
     sw = np.ones(102, dtype=np.float64)
@@ -69,6 +79,11 @@ def test_weighted_percentile_zero_weight():


 def test_weighted_percentile_zero_weight_zero_percentile():
+    """Check `weighted_percentile(percentile_rank=0)` behaves correctly.
+
+    Ensures that (leading) zero-weight observations are ignored when
+    `percentile_rank=0`. See #20528 for details.
+    """
     y = np.array([0, 1, 2, 3, 4, 5])
     sw = np.array([0, 0, 1, 1, 1, 0])
     value = _weighted_percentile(y, sw, 0)
@@ -82,18 +97,18 @@ def test_weighted_percentile_zero_weight_zero_percentile():


 def test_weighted_median_equal_weights():
-    # Checks that `_weighted_percentile` and `np.median` (both at probability level=0.5
-    # and with `sample_weights` being all 1s) return the same percentiles if the number
-    # of the samples in the data is odd. In this special case, `_weighted_percentile`
-    # always falls on a precise value (not on the next lower value) and is thus equal to
-    # `np.median`.
-    # As discussed in #17370, a similar check with an even number of samples does not
-    # consistently hold, since then the lower of two percentiles might be selected,
-    # while the median might lie in between.
+    """Check `_weighted_percentile(percentile_rank=50)` is the same as `np.median`.
+
+    `sample_weights` are all 1s and the number of samples is odd.
+    When the number of samples is odd, `_weighted_percentile` always falls on a
+    single observation (not between 2 values, in which case the lower value would
+    be taken) and is thus equal to `np.median`.
+    For an even number of samples, this check will not always hold (note that for
+    some other percentile methods it will always hold). See #17370 for details.
+    """
     rng = np.random.RandomState(0)
     x = rng.randint(10, size=11)
     weights = np.ones(x.shape)
-
     median = np.median(x)
     w_median = _weighted_percentile(x, weights)
     assert median == approx(w_median)
@@ -106,10 +121,8 @@ def test_weighted_median_integer_weights():
     x = rng.randint(20, size=10)
     weights = rng.choice(5, size=10)
     x_manual = np.repeat(x, weights)
-
     median = np.median(x_manual)
     w_median = _weighted_percentile(x, weights)
-
     assert median == approx(w_median)


@@ -125,8 +138,7 @@ def test_weighted_percentile_2d():
     w_median = _weighted_percentile(x_2d, w1)
     p_axis_0 = [_weighted_percentile(x_2d[:, i], w1) for i in range(x_2d.shape[1])]
     assert_allclose(w_median, p_axis_0)
-
-    # Check when array and sample_weight boht 2D
+    # Check when array and sample_weight both 2D
     w2 = rng.choice(5, size=10)
     w_2d = np.vstack((w1, w2)).T

@@ -137,6 +149,91 @@ def test_weighted_percentile_2d():
     assert_allclose(w_median, p_axis_0)


+@pytest.mark.parametrize(
+    "array_namespace, device, dtype_name", yield_namespace_device_dtype_combinations()
+)
+@pytest.mark.parametrize(
+    "data, weights, percentile",
+    [
+        # NumPy scalar inputs (handled as 0D arrays on array API)
+        (np.float32(42), np.int32(1), 50),
+        # Random 1D array, constant weights
+        (lambda rng: rng.rand(50), np.ones(50).astype(np.int32), 50),
+        # Random 2D array and random 1D weights
+        (lambda rng: rng.rand(50, 3), lambda rng: rng.rand(50).astype(np.float32), 75),
+        # Random 2D array and random 2D weights
+        (
+            lambda rng: rng.rand(20, 3),
+            lambda rng: rng.rand(20, 3).astype(np.float32),
+            25,
+        ),
+        # zero weights and `percentile_rank=0` (#20528) (`sample_weight` dtype: int64)
+        (np.array([0, 1, 2, 3, 4, 5]), np.array([0, 0, 1, 1, 1, 0]), 0),
+        # np.nan's in data and some zero weights (`sample_weight` dtype: int64)
+        (np.array([np.nan, np.nan, 0, 3, 4, 5]), np.array([0, 1, 1, 1, 1, 0]), 0),
+        # `sample_weight` dtype: int32
+        (
+            np.array([0, 1, 2, 3, 4, 5]),
+            np.array([0, 1, 1, 1, 1, 0], dtype=np.int32),
+            25,
+        ),
+    ],
+)
+def test_weighted_percentile_array_api_consistency(
+    global_random_seed, array_namespace, device, dtype_name, data, weights, percentile
+):
+    """Check `_weighted_percentile` gives consistent results with the array API."""
+    if array_namespace == "array_api_strict":
+        try:
+            import array_api_strict
+        except ImportError:
+            pass
+        else:
+            if device == array_api_strict.Device("device1"):
+                # See https://github.com/data-apis/array-api-strict/issues/134
+                pytest.xfail(
+                    "array_api_strict has a bug when indexing with a tuple of "
+                    "arrays on non-'CPU_DEVICE' devices."
+                )
+
+    xp = _array_api_for_tests(array_namespace, device)
+
+    # Skip test for the percentile=0 edge case (#20528) on namespace/device
+    # combinations where xp.nextafter is broken. This is the case for torch
+    # with the MPS device: https://github.com/pytorch/pytorch/issues/150027
+    zero = xp.zeros(1, device=device)
+    one = xp.ones(1, device=device)
+    if percentile == 0 and xp.all(xp.nextafter(zero, one) == zero):
+        pytest.xfail(f"xp.nextafter is broken on {device}")
+
+    rng = np.random.RandomState(global_random_seed)
+    X_np = data(rng) if callable(data) else data
+    weights_np = weights(rng) if callable(weights) else weights
+    # Ensure `data` has the correct dtype
+    X_np = X_np.astype(dtype_name)
+
+    result_np = _weighted_percentile(X_np, weights_np, percentile)
+    # Convert to array API arrays
+    X_xp = xp.asarray(X_np, device=device)
+    weights_xp = xp.asarray(weights_np, device=device)
+
+    with config_context(array_api_dispatch=True):
+        result_xp = _weighted_percentile(X_xp, weights_xp, percentile)
+        assert array_device(result_xp) == array_device(X_xp)
+        assert get_namespace(result_xp)[0] == get_namespace(X_xp)[0]
+        result_xp_np = _convert_to_numpy(result_xp, xp=xp)
+
+    assert result_xp_np.dtype == result_np.dtype
+    assert result_xp_np.shape == result_np.shape
+    assert_allclose(result_np, result_xp_np)
+
+    # Check the dtype is correct (`sample_weight` should follow `array`)
+    if dtype_name == "float32":
+        assert result_xp_np.dtype == result_np.dtype == np.float32
+    else:
+        assert result_xp_np.dtype == np.float64
+
+
 @pytest.mark.parametrize("sample_weight_ndim", [1, 2])
 def test_weighted_percentile_nan_filtered(sample_weight_ndim):
     """Test that calling _weighted_percentile on an array with nan values returns
