5
5
import numpy as np
6
6
from scipy import sparse
7
7
8
+ from ._param_validation import StrOptions , validate_params
9
+
8
10
9
11
def compute_class_weight (class_weight , * , classes , y ):
10
12
"""Estimate class weights for unbalanced datasets.
@@ -75,26 +77,34 @@ def compute_class_weight(class_weight, *, classes, y):
75
77
return weight
76
78
77
79
80
+ @validate_params (
81
+ {
82
+ "class_weight" : [dict , list , StrOptions ({"balanced" }), None ],
83
+ "y" : ["array-like" , "sparse matrix" ],
84
+ "indices" : ["array-like" , None ],
85
+ },
86
+ prefer_skip_nested_validation = True ,
87
+ )
78
88
def compute_sample_weight (class_weight , y , * , indices = None ):
79
89
"""Estimate sample weights by class for unbalanced datasets.
80
90
81
91
Parameters
82
92
----------
83
93
class_weight : dict, list of dicts, "balanced", or None
84
- Weights associated with classes in the form `` {class_label: weight}` `.
94
+ Weights associated with classes in the form `{class_label: weight}`.
85
95
If not given, all classes are supposed to have weight one. For
86
96
multi-output problems, a list of dicts can be provided in the same
87
97
order as the columns of y.
88
98
89
99
Note that for multioutput (including multilabel) weights should be
90
100
defined for each class of every column in its own dict. For example,
91
101
for four-class multilabel classification weights should be
92
- [{0: 1, 1: 1}, {0: 1, 1: 5}, {0: 1, 1: 1}, {0: 1, 1: 1}] instead of
93
- [{1:1}, {2:5}, {3:1}, {4:1}].
102
+ `[{0: 1, 1: 1}, {0: 1, 1: 5}, {0: 1, 1: 1}, {0: 1, 1: 1}]` instead of
103
+ `[{1:1}, {2:5}, {3:1}, {4:1}]`.
94
104
95
- The "balanced" mode uses the values of y to automatically adjust
105
+ The `"balanced"` mode uses the values of y to automatically adjust
96
106
weights inversely proportional to class frequencies in the input data:
97
- `` n_samples / (n_classes * np.bincount(y))` `.
107
+ `n_samples / (n_classes * np.bincount(y))`.
98
108
99
109
For multi-output, the weights of each column of y will be multiplied.
100
110
@@ -103,15 +113,15 @@ def compute_sample_weight(class_weight, y, *, indices=None):
103
113
104
114
indices : array-like of shape (n_subsample,), default=None
105
115
Array of indices to be used in a subsample. Can be of length less than
106
- n_samples in the case of a subsample, or equal to n_samples in the
107
- case of a bootstrap subsample with repeated indices. If None, the
108
- sample weight will be calculated over the full sample. Only "balanced"
109
- is supported for class_weight if this is provided.
116
+ `n_samples` in the case of a subsample, or equal to `n_samples` in the
117
+ case of a bootstrap subsample with repeated indices. If `None`, the
118
+ sample weight will be calculated over the full sample. Only `"balanced"`
119
+ is supported for `class_weight` if this is provided.
110
120
111
121
Returns
112
122
-------
113
123
sample_weight_vect : ndarray of shape (n_samples,)
114
- Array with sample weights as applied to the original y .
124
+ Array with sample weights as applied to the original `y`.
115
125
"""
116
126
117
127
# Ensure y is 2D. Sparse matrices are already 2D.
@@ -121,27 +131,22 @@ def compute_sample_weight(class_weight, y, *, indices=None):
121
131
y = np .reshape (y , (- 1 , 1 ))
122
132
n_outputs = y .shape [1 ]
123
133
124
- if isinstance (class_weight , str ):
125
- if class_weight not in ["balanced" ]:
126
- raise ValueError (
127
- 'The only valid preset for class_weight is "balanced". Given "%s".'
128
- % class_weight
129
- )
130
- elif indices is not None and not isinstance (class_weight , str ):
134
+ if indices is not None and class_weight != "balanced" :
131
135
raise ValueError (
132
- ' The only valid class_weight for subsampling is " balanced". Given "%s".'
133
- % class_weight
136
+ " The only valid class_weight for subsampling is ' balanced'. "
137
+ f"Given { class_weight } ."
134
138
)
135
139
elif n_outputs > 1 :
136
- if not hasattr ( class_weight , "__iter__" ) or isinstance (class_weight , dict ):
140
+ if class_weight is None or isinstance (class_weight , dict ):
137
141
raise ValueError (
138
- "For multi-output, class_weight should be a "
139
- "list of dicts, or a valid string."
142
+ "For multi-output, class_weight should be a list of dicts, or the "
143
+ "string 'balanced' ."
140
144
)
141
- if len (class_weight ) != n_outputs :
145
+ elif isinstance ( class_weight , list ) and len (class_weight ) != n_outputs :
142
146
raise ValueError (
143
- "For multi-output, number of elements in "
144
- "class_weight should match number of outputs."
147
+ "For multi-output, number of elements in class_weight should match "
148
+ f"number of outputs. Got { len (class_weight )} element(s) while having "
149
+ f"{ n_outputs } outputs."
145
150
)
146
151
147
152
expanded_class_weight = []
0 commit comments