5
5
import numpy as np
6
6
7
7
from ..base import is_classifier
8
+ from .multiclass import type_of_target
8
9
from .validation import _check_response_method , check_is_fitted
9
10
10
11
12
def _process_predict_proba(*, y_pred, target_type, classes, pos_label):
    """Get the response values when the response method is `predict_proba`.

    This function processes the `y_pred` array in the binary and multi-label
    cases. In the binary case, it selects the column corresponding to the
    positive class. In the multi-label case, it stacks the predictions if they
    are not in the "compressed" format `(n_samples, n_outputs)`. Multiclass
    predictions are returned unchanged.

    Parameters
    ----------
    y_pred : ndarray or list of ndarray
        Output of `estimator.predict_proba`. The shape depends on the target
        type:

        - for binary classification, it is a 2d array of shape
          `(n_samples, 2)`;
        - for multiclass classification, it is a 2d array of shape
          `(n_samples, n_classes)`;
        - for multilabel classification, it is either a list of 2d arrays of
          shape `(n_samples, 2)` (e.g. `RandomForestClassifier` or
          `KNeighborsClassifier`) or an array of shape
          `(n_samples, n_outputs)` (e.g. `MLPClassifier` or
          `RidgeClassifier`).

    target_type : {"binary", "multiclass", "multilabel-indicator"}
        Type of the target.

    classes : ndarray of shape (n_classes,) or list of such arrays
        Class labels as reported by `estimator.classes_`.

    pos_label : int, float, bool or str
        Label of the positive class. Only used with binary targets, where it
        must be one of `classes`.

    Returns
    -------
    y_pred : ndarray of shape (n_samples,), (n_samples, n_classes) or \
            (n_samples, n_outputs)
        Compressed predictions format as requested by the metrics.

    Raises
    ------
    ValueError
        If `target_type` is "binary" and `y_pred` has fewer than two columns.
    """
    if target_type == "binary" and y_pred.shape[1] < 2:
        # We don't handle classifiers trained on a single class.
        raise ValueError(
            f"Got predict_proba of shape {y_pred.shape}, but need "
            "classifier with two classes."
        )

    if target_type == "binary":
        # Keep only the probability column of the positive class. The lookup
        # raises IndexError when `pos_label` is not among `classes`.
        col_idx = np.flatnonzero(classes == pos_label)[0]
        return y_pred[:, col_idx]
    elif target_type == "multilabel-indicator":
        # Use a compressed format of shape `(n_samples, n_outputs)`.
        # Only `MLPClassifier` and `RidgeClassifier` return an array of shape
        # `(n_samples, n_outputs)`.
        if isinstance(y_pred, list):
            # List of arrays of shape `(n_samples, 2)`: take the last column
            # of each output and stack them side by side.
            return np.vstack([p[:, -1] for p in y_pred]).T
        else:
            # Array already of shape `(n_samples, n_outputs)`.
            return y_pred

    # Multiclass: nothing to compress.
    return y_pred
70
+
71
+
72
def _process_decision_function(*, y_pred, target_type, classes, pos_label):
    """Get the response values when the response method is `decision_function`.

    This function processes the `y_pred` array in the binary case: it inverts
    the sign of the score if the positive label is `classes[0]`, i.e. not
    `classes[1]`. For every other target type `y_pred` is returned unchanged.

    Parameters
    ----------
    y_pred : ndarray
        Output of `estimator.decision_function`. The shape depends on the
        target type:

        - for binary classification, it is a 1d array of shape `(n_samples,)`
          where the sign is assuming that `classes[1]` is the positive class;
        - for multiclass classification, it is a 2d array of shape
          `(n_samples, n_classes)`;
        - for multilabel classification, it is a 2d array of shape
          `(n_samples, n_outputs)`.

    target_type : {"binary", "multiclass", "multilabel-indicator"}
        Type of the target.

    classes : ndarray of shape (n_classes,) or list of such arrays
        Class labels as reported by `estimator.classes_`.

    pos_label : int, float, bool or str
        Label of the positive class. Only used with binary targets.

    Returns
    -------
    y_pred : ndarray of shape (n_samples,), (n_samples, n_classes) or \
            (n_samples, n_outputs)
        Compressed predictions format as requested by the metrics.
    """
    # `target_type` is tested first so that `classes[0]` is only compared to
    # `pos_label` for binary targets.
    if target_type == "binary" and pos_label == classes[0]:
        # Multiplying builds a new array; the caller's `y_pred` is not
        # modified in place.
        return -1 * y_pred
    return y_pred
110
+
111
+
11
112
def _get_response_values (
12
113
estimator ,
13
114
X ,
@@ -16,12 +117,18 @@ def _get_response_values(
16
117
):
17
118
"""Compute the response values of a classifier or a regressor.
18
119
19
- The response values are predictions, one scalar value for each sample in X
20
- that depends on the specific choice of `response_method`.
120
+ The response values are predictions such that it follows the following shape:
121
+
122
+ - for binary classification, it is a 1d array of shape `(n_samples,)`;
123
+ - for multiclass classification, it is a 2d array of shape `(n_samples, n_classes)`;
124
+ - for multilabel classification, it is a 2d array of shape `(n_samples, n_outputs)`;
125
+ - for regression, it is a 1d array of shape `(n_samples,)`.
21
126
22
127
If `estimator` is a binary classifier, also return the label for the
23
128
effective positive class.
24
129
130
+ This utility is used primarily in the displays and the scikit-learn scorers.
131
+
25
132
.. versionadded:: 1.3
26
133
27
134
Parameters
@@ -51,8 +158,9 @@ def _get_response_values(
51
158
52
159
Returns
53
160
-------
54
- y_pred : ndarray of shape (n_samples,)
55
- Target scores calculated from the provided response_method
161
+ y_pred : ndarray of shape (n_samples,), (n_samples, n_classes) or \
162
+ (n_samples, n_outputs)
163
+ Target scores calculated from the provided `response_method`
56
164
and `pos_label`.
57
165
58
166
pos_label : int, float, bool, str or None
@@ -72,32 +180,33 @@ def _get_response_values(
72
180
if is_classifier (estimator ):
73
181
prediction_method = _check_response_method (estimator , response_method )
74
182
classes = estimator .classes_
75
- target_type = "binary" if len (classes ) <= 2 else "multiclass"
183
+ target_type = type_of_target (classes )
76
184
77
- if pos_label is not None and pos_label not in classes .tolist ():
78
- raise ValueError (
79
- f"pos_label={ pos_label } is not a valid label: It should be "
80
- f"one of { classes } "
81
- )
82
- elif pos_label is None and target_type == "binary" :
83
- pos_label = pos_label if pos_label is not None else classes [- 1 ]
185
+ if target_type in ("binary" , "multiclass" ):
186
+ if pos_label is not None and pos_label not in classes .tolist ():
187
+ raise ValueError (
188
+ f"pos_label={ pos_label } is not a valid label: It should be "
189
+ f"one of { classes } "
190
+ )
191
+ elif pos_label is None and target_type == "binary" :
192
+ pos_label = classes [- 1 ]
84
193
85
194
y_pred = prediction_method (X )
195
+
86
196
if prediction_method .__name__ == "predict_proba" :
87
- if target_type == "binary" and y_pred .shape [1 ] <= 2 :
88
- if y_pred .shape [1 ] == 2 :
89
- col_idx = np .flatnonzero (classes == pos_label )[0 ]
90
- y_pred = y_pred [:, col_idx ]
91
- else :
92
- err_msg = (
93
- f"Got predict_proba of shape { y_pred .shape } , but need "
94
- "classifier with two classes."
95
- )
96
- raise ValueError (err_msg )
197
+ y_pred = _process_predict_proba (
198
+ y_pred = y_pred ,
199
+ target_type = target_type ,
200
+ classes = classes ,
201
+ pos_label = pos_label ,
202
+ )
97
203
elif prediction_method .__name__ == "decision_function" :
98
- if target_type == "binary" :
99
- if pos_label == classes [0 ]:
100
- y_pred *= - 1
204
+ y_pred = _process_decision_function (
205
+ y_pred = y_pred ,
206
+ target_type = target_type ,
207
+ classes = classes ,
208
+ pos_label = pos_label ,
209
+ )
101
210
else : # estimator is a regressor
102
211
if response_method != "predict" :
103
212
raise ValueError (
0 commit comments