Skip to content

Commit fbdc7b3

Browse files
brendanlu and glemaitre authored
MAINT Parameters validation for sklearn.utils.class_weight.compute_class_weight (scikit-learn#26512)
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
1 parent a5cc3ab commit fbdc7b3

File tree

3 files changed: 64 additions and 40 deletions.

sklearn/tests/test_public_functions.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -306,13 +306,13 @@ def _check_function_param_validation(
306306
"sklearn.tree.export_graphviz",
307307
"sklearn.tree.export_text",
308308
"sklearn.tree.plot_tree",
309-
"sklearn.utils.class_weight.compute_sample_weight",
310309
"sklearn.utils.gen_batches",
311310
"sklearn.utils.gen_even_slices",
312-
"sklearn.utils.graph.single_source_shortest_path_length",
313311
"sklearn.utils.resample",
314312
"sklearn.utils.safe_mask",
315313
"sklearn.utils.extmath.randomized_svd",
314+
"sklearn.utils.class_weight.compute_class_weight",
315+
"sklearn.utils.class_weight.compute_sample_weight",
316316
"sklearn.utils.graph.single_source_shortest_path_length",
317317
]
318318

sklearn/utils/class_weight.py

Lines changed: 16 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -8,29 +8,37 @@
88
from ._param_validation import StrOptions, validate_params
99

1010

11+
@validate_params(
12+
{
13+
"class_weight": [dict, StrOptions({"balanced"}), None],
14+
"classes": [np.ndarray],
15+
"y": ["array-like"],
16+
},
17+
prefer_skip_nested_validation=True,
18+
)
1119
def compute_class_weight(class_weight, *, classes, y):
1220
"""Estimate class weights for unbalanced datasets.
1321
1422
Parameters
1523
----------
16-
class_weight : dict, 'balanced' or None
17-
If 'balanced', class weights will be given by
18-
``n_samples / (n_classes * np.bincount(y))``.
19-
If a dictionary is given, keys are classes and values
20-
are corresponding class weights.
21-
If None is given, the class weights will be uniform.
24+
class_weight : dict, "balanced" or None
25+
If "balanced", class weights will be given by
26+
`n_samples / (n_classes * np.bincount(y))`.
27+
If a dictionary is given, keys are classes and values are corresponding class
28+
weights.
29+
If `None` is given, the class weights will be uniform.
2230
2331
classes : ndarray
2432
Array of the classes occurring in the data, as given by
25-
``np.unique(y_org)`` with ``y_org`` the original class labels.
33+
`np.unique(y_org)` with `y_org` the original class labels.
2634
2735
y : array-like of shape (n_samples,)
2836
Array of original class labels per sample.
2937
3038
Returns
3139
-------
3240
class_weight_vect : ndarray of shape (n_classes,)
33-
Array with class_weight_vect[i] the weight for i-th class.
41+
Array with `class_weight_vect[i]` the weight for i-th class.
3442
3543
References
3644
----------
@@ -57,10 +65,6 @@ def compute_class_weight(class_weight, *, classes, y):
5765
else:
5866
# user-defined dictionary
5967
weight = np.ones(classes.shape[0], dtype=np.float64, order="C")
60-
if not isinstance(class_weight, dict):
61-
raise ValueError(
62-
"class_weight must be dict, 'balanced', or None, got: %r" % class_weight
63-
)
6468
unweighted_classes = []
6569
for i, c in enumerate(classes):
6670
if c in class_weight:

sklearn/utils/tests/test_class_weight.py

Lines changed: 46 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -22,33 +22,53 @@ def test_compute_class_weight():
2222
assert cw[0] < cw[1] < cw[2]
2323

2424

25-
def test_compute_class_weight_not_present():
25+
@pytest.mark.parametrize(
26+
"y_type, class_weight, classes, err_msg",
27+
[
28+
(
29+
"numeric",
30+
"balanced",
31+
np.arange(4),
32+
"classes should have valid labels that are in y",
33+
),
34+
# Non-regression for https://github.com/scikit-learn/scikit-learn/issues/8312
35+
(
36+
"numeric",
37+
{"label_not_present": 1.0},
38+
np.arange(4),
39+
r"The classes, \[0, 1, 2, 3\], are not in class_weight",
40+
),
41+
(
42+
"numeric",
43+
"balanced",
44+
np.arange(2),
45+
"classes should include all valid labels",
46+
),
47+
(
48+
"numeric",
49+
{0: 1.0, 1: 2.0},
50+
np.arange(2),
51+
"classes should include all valid labels",
52+
),
53+
(
54+
"string",
55+
{"dogs": 3, "cat": 2},
56+
np.array(["dog", "cat"]),
57+
r"The classes, \['dog'\], are not in class_weight",
58+
),
59+
],
60+
)
61+
def test_compute_class_weight_not_present(y_type, class_weight, classes, err_msg):
2662
# Raise error when y does not contain all class labels
27-
classes = np.arange(4)
28-
y = np.asarray([0, 0, 0, 1, 1, 2])
29-
with pytest.raises(ValueError):
30-
compute_class_weight("balanced", classes=classes, y=y)
31-
# Fix exception in error message formatting when missing label is a string
32-
# https://github.com/scikit-learn/scikit-learn/issues/8312
33-
with pytest.raises(
34-
ValueError, match=r"The classes, \[0, 1, 2, 3\], are not in class_weight"
35-
):
36-
compute_class_weight({"label_not_present": 1.0}, classes=classes, y=y)
37-
# Raise error when y has items not in classes
38-
classes = np.arange(2)
39-
with pytest.raises(ValueError):
40-
compute_class_weight("balanced", classes=classes, y=y)
41-
with pytest.raises(ValueError):
42-
compute_class_weight({0: 1.0, 1: 2.0}, classes=classes, y=y)
43-
44-
# y contains a unweighted class that is not in class_weights
45-
classes = np.asarray(["cat", "dog"])
46-
y = np.asarray(["dog", "cat", "dog"])
47-
class_weights = {"dogs": 3, "cat": 2}
48-
msg = r"The classes, \['dog'\], are not in class_weight"
49-
50-
with pytest.raises(ValueError, match=msg):
51-
compute_class_weight(class_weights, classes=classes, y=y)
63+
y = (
64+
np.asarray([0, 0, 0, 1, 1, 2])
65+
if y_type == "numeric"
66+
else np.asarray(["dog", "cat", "dog"])
67+
)
68+
69+
print(y)
70+
with pytest.raises(ValueError, match=err_msg):
71+
compute_class_weight(class_weight, classes=classes, y=y)
5272

5373

5474
def test_compute_class_weight_dict():

0 commit comments

Comments
 (0)