Skip to content

Commit a5cc3ab

Browse files
rand0wn and glemaitre authored
MAINT Added parameter validation for sklearn.utils.class_weight.compute_sample_weight (scikit-learn#26564)
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
1 parent 87941ab commit a5cc3ab

File tree

3 files changed

+60
-48
lines changed

3 files changed

+60
-48
lines changed

sklearn/tests/test_public_functions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,7 @@ def _check_function_param_validation(
306306
"sklearn.tree.export_graphviz",
307307
"sklearn.tree.export_text",
308308
"sklearn.tree.plot_tree",
309+
"sklearn.utils.class_weight.compute_sample_weight",
309310
"sklearn.utils.gen_batches",
310311
"sklearn.utils.gen_even_slices",
311312
"sklearn.utils.graph.single_source_shortest_path_length",

sklearn/utils/class_weight.py

Lines changed: 30 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import numpy as np
66
from scipy import sparse
77

8+
from ._param_validation import StrOptions, validate_params
9+
810

911
def compute_class_weight(class_weight, *, classes, y):
1012
"""Estimate class weights for unbalanced datasets.
@@ -75,26 +77,34 @@ def compute_class_weight(class_weight, *, classes, y):
7577
return weight
7678

7779

80+
@validate_params(
81+
{
82+
"class_weight": [dict, list, StrOptions({"balanced"}), None],
83+
"y": ["array-like", "sparse matrix"],
84+
"indices": ["array-like", None],
85+
},
86+
prefer_skip_nested_validation=True,
87+
)
7888
def compute_sample_weight(class_weight, y, *, indices=None):
7989
"""Estimate sample weights by class for unbalanced datasets.
8090
8191
Parameters
8292
----------
8393
class_weight : dict, list of dicts, "balanced", or None
84-
Weights associated with classes in the form ``{class_label: weight}``.
94+
Weights associated with classes in the form `{class_label: weight}`.
8595
If not given, all classes are supposed to have weight one. For
8696
multi-output problems, a list of dicts can be provided in the same
8797
order as the columns of y.
8898
8999
Note that for multioutput (including multilabel) weights should be
90100
defined for each class of every column in its own dict. For example,
91101
for four-class multilabel classification weights should be
92-
[{0: 1, 1: 1}, {0: 1, 1: 5}, {0: 1, 1: 1}, {0: 1, 1: 1}] instead of
93-
[{1:1}, {2:5}, {3:1}, {4:1}].
102+
`[{0: 1, 1: 1}, {0: 1, 1: 5}, {0: 1, 1: 1}, {0: 1, 1: 1}]` instead of
103+
`[{1:1}, {2:5}, {3:1}, {4:1}]`.
94104
95-
The "balanced" mode uses the values of y to automatically adjust
105+
The `"balanced"` mode uses the values of y to automatically adjust
96106
weights inversely proportional to class frequencies in the input data:
97-
``n_samples / (n_classes * np.bincount(y))``.
107+
`n_samples / (n_classes * np.bincount(y))`.
98108
99109
For multi-output, the weights of each column of y will be multiplied.
100110
@@ -103,15 +113,15 @@ def compute_sample_weight(class_weight, y, *, indices=None):
103113
104114
indices : array-like of shape (n_subsample,), default=None
105115
Array of indices to be used in a subsample. Can be of length less than
106-
n_samples in the case of a subsample, or equal to n_samples in the
107-
case of a bootstrap subsample with repeated indices. If None, the
108-
sample weight will be calculated over the full sample. Only "balanced"
109-
is supported for class_weight if this is provided.
116+
`n_samples` in the case of a subsample, or equal to `n_samples` in the
117+
case of a bootstrap subsample with repeated indices. If `None`, the
118+
sample weight will be calculated over the full sample. Only `"balanced"`
119+
is supported for `class_weight` if this is provided.
110120
111121
Returns
112122
-------
113123
sample_weight_vect : ndarray of shape (n_samples,)
114-
Array with sample weights as applied to the original y.
124+
Array with sample weights as applied to the original `y`.
115125
"""
116126

117127
# Ensure y is 2D. Sparse matrices are already 2D.
@@ -121,27 +131,22 @@ def compute_sample_weight(class_weight, y, *, indices=None):
121131
y = np.reshape(y, (-1, 1))
122132
n_outputs = y.shape[1]
123133

124-
if isinstance(class_weight, str):
125-
if class_weight not in ["balanced"]:
126-
raise ValueError(
127-
'The only valid preset for class_weight is "balanced". Given "%s".'
128-
% class_weight
129-
)
130-
elif indices is not None and not isinstance(class_weight, str):
134+
if indices is not None and class_weight != "balanced":
131135
raise ValueError(
132-
'The only valid class_weight for subsampling is "balanced". Given "%s".'
133-
% class_weight
136+
"The only valid class_weight for subsampling is 'balanced'. "
137+
f"Given {class_weight}."
134138
)
135139
elif n_outputs > 1:
136-
if not hasattr(class_weight, "__iter__") or isinstance(class_weight, dict):
140+
if class_weight is None or isinstance(class_weight, dict):
137141
raise ValueError(
138-
"For multi-output, class_weight should be a "
139-
"list of dicts, or a valid string."
142+
"For multi-output, class_weight should be a list of dicts, or the "
143+
"string 'balanced'."
140144
)
141-
if len(class_weight) != n_outputs:
145+
elif isinstance(class_weight, list) and len(class_weight) != n_outputs:
142146
raise ValueError(
143-
"For multi-output, number of elements in "
144-
"class_weight should match number of outputs."
147+
"For multi-output, number of elements in class_weight should match "
148+
f"number of outputs. Got {len(class_weight)} element(s) while having "
149+
f"{n_outputs} outputs."
145150
)
146151

147152
expanded_class_weight = []

sklearn/utils/tests/test_class_weight.py

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -235,32 +235,38 @@ def test_compute_sample_weight_with_subsample():
235235
assert_array_almost_equal(sample_weight, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0])
236236

237237

238-
def test_compute_sample_weight_errors():
238+
@pytest.mark.parametrize(
239+
"y_type, class_weight, indices, err_msg",
240+
[
241+
(
242+
"single-output",
243+
{1: 2, 2: 1},
244+
range(4),
245+
"The only valid class_weight for subsampling is 'balanced'.",
246+
),
247+
(
248+
"multi-output",
249+
{1: 2, 2: 1},
250+
None,
251+
"For multi-output, class_weight should be a list of dicts, or the string",
252+
),
253+
(
254+
"multi-output",
255+
[{1: 2, 2: 1}],
256+
None,
257+
r"Got 1 element\(s\) while having 2 outputs",
258+
),
259+
],
260+
)
261+
def test_compute_sample_weight_errors(y_type, class_weight, indices, err_msg):
239262
# Test compute_sample_weight raises errors expected.
240263
# Each parametrized case pairs an invalid class_weight/indices input with its expected message.
241-
y = np.asarray([1, 1, 1, 2, 2, 2])
242-
y_ = np.asarray([[1, 0], [1, 0], [1, 0], [2, 1], [2, 1], [2, 1]])
243-
244-
with pytest.raises(ValueError):
245-
compute_sample_weight("ni", y)
246-
with pytest.raises(ValueError):
247-
compute_sample_weight("ni", y, indices=range(4))
248-
with pytest.raises(ValueError):
249-
compute_sample_weight("ni", y_)
250-
with pytest.raises(ValueError):
251-
compute_sample_weight("ni", y_, indices=range(4))
264+
y_single_output = np.asarray([1, 1, 1, 2, 2, 2])
265+
y_multi_output = np.asarray([[1, 0], [1, 0], [1, 0], [2, 1], [2, 1], [2, 1]])
252266

253-
# Not "balanced" for subsample
254-
with pytest.raises(ValueError):
255-
compute_sample_weight({1: 2, 2: 1}, y, indices=range(4))
256-
257-
# Not a list or preset for multi-output
258-
with pytest.raises(ValueError):
259-
compute_sample_weight({1: 2, 2: 1}, y_)
260-
261-
# Incorrect length list for multi-output
262-
with pytest.raises(ValueError):
263-
compute_sample_weight([{1: 2, 2: 1}], y_)
267+
y = y_single_output if y_type == "single-output" else y_multi_output
268+
with pytest.raises(ValueError, match=err_msg):
269+
compute_sample_weight(class_weight, y, indices=indices)
264270

265271

266272
def test_compute_sample_weight_more_than_32():

0 commit comments

Comments
 (0)