Skip to content

Commit 3c9495c

Browse files
glemaitrethomasjpfanjeremiedbbStefanieSengerVladimirFokow
authored
FIX accept multilabel-indicator in _get_response_values (scikit-learn#27002)
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com> Co-authored-by: Jérémie du Boisberranger <34657725+jeremiedbb@users.noreply.github.com> Co-authored-by: Stefanie Senger <91849487+StefanieSenger@users.noreply.github.com> Co-authored-by: Vladimir Fokow <57260995+VladimirFokow@users.noreply.github.com> Co-authored-by: jeremie du boisberranger <jeremiedbb@yahoo.fr> Co-authored-by: Xuefeng Xu <xuxf100@qq.com> Co-authored-by: Tim Head <betatim@gmail.com> Co-authored-by: Raphael <raphael1peer@gmail.com> Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org> Co-authored-by: Sayed Qaiser Ali <66676360+sqali@users.noreply.github.com> Co-authored-by: Loïc Estève <loic.esteve@ymail.com> Co-authored-by: Xiao Yuan <yuanx749@gmail.com>
1 parent a337410 commit 3c9495c

File tree

4 files changed

+208
-34
lines changed

4 files changed

+208
-34
lines changed

doc/whats_new/v1.3.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,13 @@ Changelog
4343
``transform`` when ``add_indicator`` is set to ``True`` and missing values are observed
4444
during ``fit``. :pr:`26600` by :user:`Shreesha Kumar Bhat <Shreesha3112>`.
4545

46+
:mod:`sklearn.metrics`
47+
.......................
48+
49+
- |Fix| Scorers used with :func:`metrics.get_scorer` properly handle
50+
multilabel-indicator matrices.
51+
:pr:`27002` by :user:`Guillaume Lemaitre <glemaitre>`.
52+
4653
:mod:`sklearn.neighbors`
4754
........................
4855

sklearn/metrics/tests/test_score_objects.py

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -542,12 +542,15 @@ def test_thresholded_scorers_multilabel_indicator_data():
542542

543543
# Multi-output multi-class decision_function
544544
# TODO Is there any yet?
545-
clf = DecisionTreeClassifier()
546-
clf.fit(X_train, y_train)
547-
clf._predict_proba = clf.predict_proba
548-
clf.predict_proba = None
549-
clf.decision_function = lambda X: [p[:, 1] for p in clf._predict_proba(X)]
545+
class TreeWithDecisionFunction(DecisionTreeClassifier):
546+
# disable predict_proba
547+
predict_proba = None
548+
549+
def decision_function(self, X):
550+
return [p[:, 1] for p in DecisionTreeClassifier.predict_proba(self, X)]
550551

552+
clf = TreeWithDecisionFunction()
553+
clf.fit(X_train, y_train)
551554
y_proba = clf.decision_function(X_test)
552555
score1 = get_scorer("roc_auc")(clf, X_test, y_test)
553556
score2 = roc_auc_score(y_test, np.vstack([p for p in y_proba]).T)
@@ -800,7 +803,19 @@ def test_multimetric_scorer_calls_method_once(
800803
assert decision_function_func.call_count == expected_decision_func_count
801804

802805

803-
def test_multimetric_scorer_calls_method_once_classifier_no_decision():
806+
@pytest.mark.parametrize(
807+
"scorers",
808+
[
809+
(["roc_auc", "neg_log_loss"]),
810+
(
811+
{
812+
"roc_auc": make_scorer(roc_auc_score, needs_threshold=True),
813+
"neg_log_loss": make_scorer(log_loss, needs_proba=True),
814+
}
815+
),
816+
],
817+
)
818+
def test_multimetric_scorer_calls_method_once_classifier_no_decision(scorers):
804819
predict_proba_call_cnt = 0
805820

806821
class MockKNeighborsClassifier(KNeighborsClassifier):
@@ -815,7 +830,6 @@ def predict_proba(self, X):
815830
clf = MockKNeighborsClassifier(n_neighbors=1)
816831
clf.fit(X, y)
817832

818-
scorers = ["roc_auc", "neg_log_loss"]
819833
scorer_dict = _check_multimetric_scoring(clf, scorers)
820834
scorer = _MultimetricScorer(scorers=scorer_dict)
821835
scorer(clf, X, y)
@@ -838,7 +852,7 @@ def predict(self, X):
838852
clf = MockDecisionTreeRegressor()
839853
clf.fit(X, y)
840854

841-
scorers = {"neg_mse": "neg_mean_squared_error", "r2": "roc_auc"}
855+
scorers = {"neg_mse": "neg_mean_squared_error", "r2": "r2"}
842856
scorer_dict = _check_multimetric_scoring(clf, scorers)
843857
scorer = _MultimetricScorer(scorers=scorer_dict)
844858
scorer(clf, X, y)
@@ -1357,3 +1371,18 @@ def score(y_true, y_pred, param=None):
13571371
ValueError, match="is only supported if enable_metadata_routing=True"
13581372
):
13591373
scorer(clf, X, y, param="blah")
1374+
1375+
1376+
def test_get_scorer_multilabel_indicator():
1377+
"""Check that our scorers deal with multi-label indicator matrices.
1378+
1379+
Non-regression test for:
1380+
https://github.com/scikit-learn/scikit-learn/issues/26817
1381+
"""
1382+
X, Y = make_multilabel_classification(n_samples=72, n_classes=3, random_state=0)
1383+
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, random_state=0)
1384+
1385+
estimator = KNeighborsClassifier().fit(X_train, Y_train)
1386+
1387+
score = get_scorer("average_precision")(estimator, X_test, Y_test)
1388+
assert score > 0.8

sklearn/utils/_response.py

Lines changed: 134 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,110 @@
55
import numpy as np
66

77
from ..base import is_classifier
8+
from .multiclass import type_of_target
89
from .validation import _check_response_method, check_is_fitted
910

1011

12+
def _process_predict_proba(*, y_pred, target_type, classes, pos_label):
13+
"""Get the response values when the response method is `predict_proba`.
14+
15+
This function processes the `y_pred` array in the binary and multi-label cases.
16+
In the binary case, it selects the column corresponding to the positive
17+
class. In the multi-label case, it stacks the predictions if they are not
18+
in the "compressed" format `(n_samples, n_outputs)`.
19+
20+
Parameters
21+
----------
22+
y_pred : ndarray
23+
Output of `estimator.predict_proba`. The shape depends on the target type:
24+
25+
- for binary classification, it is a 2d array of shape `(n_samples, 2)`;
26+
- for multiclass classification, it is a 2d array of shape
27+
`(n_samples, n_classes)`;
28+
- for multilabel classification, it is either a list of 2d arrays of shape
29+
`(n_samples, 2)` (e.g. `RandomForestClassifier` or `KNeighborsClassifier`) or
30+
an array of shape `(n_samples, n_outputs)` (e.g. `MLPClassifier` or
31+
`RidgeClassifier`).
32+
33+
target_type : {"binary", "multiclass", "multilabel-indicator"}
34+
Type of the target.
35+
36+
classes : ndarray of shape (n_classes,) or list of such arrays
37+
Class labels as reported by `estimator.classes_`.
38+
39+
pos_label : int, float, bool or str
40+
Only used with binary and multiclass targets.
41+
42+
Returns
43+
-------
44+
y_pred : ndarray of shape (n_samples,), (n_samples, n_classes) or \
45+
(n_samples, n_outputs)
46+
Compressed predictions format as requested by the metrics.
47+
"""
48+
if target_type == "binary" and y_pred.shape[1] < 2:
49+
# We don't handle classifiers trained on a single class.
50+
raise ValueError(
51+
f"Got predict_proba of shape {y_pred.shape}, but need "
52+
"classifier with two classes."
53+
)
54+
55+
if target_type == "binary":
56+
col_idx = np.flatnonzero(classes == pos_label)[0]
57+
return y_pred[:, col_idx]
58+
elif target_type == "multilabel-indicator":
59+
# Use a compressed format of shape `(n_samples, n_outputs)`.
60+
# Only `MLPClassifier` and `RidgeClassifier` return an array of shape
61+
# `(n_samples, n_outputs)`.
62+
if isinstance(y_pred, list):
63+
# list of arrays of shape `(n_samples, 2)`
64+
return np.vstack([p[:, -1] for p in y_pred]).T
65+
else:
66+
# array of shape `(n_samples, n_outputs)`
67+
return y_pred
68+
69+
return y_pred
70+
71+
72+
def _process_decision_function(*, y_pred, target_type, classes, pos_label):
73+
"""Get the response values when the response method is `decision_function`.
74+
75+
This function processes the `y_pred` array in the binary and multi-label cases.
76+
In the binary case, it inverts the sign of the score if the positive label
77+
is not `classes[1]`. In the multiclass and multi-label cases, the predictions
78+
are returned unchanged.
79+
80+
Parameters
81+
----------
82+
y_pred : ndarray
83+
Output of `estimator.decision_function`. The shape depends on the target type:
84+
85+
- for binary classification, it is a 1d array of shape `(n_samples,)` where the
86+
sign assumes that `classes[1]` is the positive class;
87+
- for multiclass classification, it is a 2d array of shape
88+
`(n_samples, n_classes)`;
89+
- for multilabel classification, it is a 2d array of shape `(n_samples,
90+
n_outputs)`.
91+
92+
target_type : {"binary", "multiclass", "multilabel-indicator"}
93+
Type of the target.
94+
95+
classes : ndarray of shape (n_classes,) or list of such arrays
96+
Class labels as reported by `estimator.classes_`.
97+
98+
pos_label : int, float, bool or str
99+
Only used with binary and multiclass targets.
100+
101+
Returns
102+
-------
103+
y_pred : ndarray of shape (n_samples,), (n_samples, n_classes) or \
104+
(n_samples, n_outputs)
105+
Compressed predictions format as requested by the metrics.
106+
"""
107+
if target_type == "binary" and pos_label == classes[0]:
108+
return -1 * y_pred
109+
return y_pred
110+
111+
11112
def _get_response_values(
12113
estimator,
13114
X,
@@ -16,12 +117,18 @@ def _get_response_values(
16117
):
17118
"""Compute the response values of a classifier or a regressor.
18119
19-
The response values are predictions, one scalar value for each sample in X
20-
that depends on the specific choice of `response_method`.
120+
The response values are predictions with the following shapes:
121+
122+
- for binary classification, it is a 1d array of shape `(n_samples,)`;
123+
- for multiclass classification, it is a 2d array of shape `(n_samples, n_classes)`;
124+
- for multilabel classification, it is a 2d array of shape `(n_samples, n_outputs)`;
125+
- for regression, it is a 1d array of shape `(n_samples,)`.
21126
22127
If `estimator` is a binary classifier, also return the label for the
23128
effective positive class.
24129
130+
This utility is used primarily in the displays and the scikit-learn scorers.
131+
25132
.. versionadded:: 1.3
26133
27134
Parameters
@@ -51,8 +158,9 @@ def _get_response_values(
51158
52159
Returns
53160
-------
54-
y_pred : ndarray of shape (n_samples,)
55-
Target scores calculated from the provided response_method
161+
y_pred : ndarray of shape (n_samples,), (n_samples, n_classes) or \
162+
(n_samples, n_outputs)
163+
Target scores calculated from the provided `response_method`
56164
and `pos_label`.
57165
58166
pos_label : int, float, bool, str or None
@@ -72,32 +180,33 @@ def _get_response_values(
72180
if is_classifier(estimator):
73181
prediction_method = _check_response_method(estimator, response_method)
74182
classes = estimator.classes_
75-
target_type = "binary" if len(classes) <= 2 else "multiclass"
183+
target_type = type_of_target(classes)
76184

77-
if pos_label is not None and pos_label not in classes.tolist():
78-
raise ValueError(
79-
f"pos_label={pos_label} is not a valid label: It should be "
80-
f"one of {classes}"
81-
)
82-
elif pos_label is None and target_type == "binary":
83-
pos_label = pos_label if pos_label is not None else classes[-1]
185+
if target_type in ("binary", "multiclass"):
186+
if pos_label is not None and pos_label not in classes.tolist():
187+
raise ValueError(
188+
f"pos_label={pos_label} is not a valid label: It should be "
189+
f"one of {classes}"
190+
)
191+
elif pos_label is None and target_type == "binary":
192+
pos_label = classes[-1]
84193

85194
y_pred = prediction_method(X)
195+
86196
if prediction_method.__name__ == "predict_proba":
87-
if target_type == "binary" and y_pred.shape[1] <= 2:
88-
if y_pred.shape[1] == 2:
89-
col_idx = np.flatnonzero(classes == pos_label)[0]
90-
y_pred = y_pred[:, col_idx]
91-
else:
92-
err_msg = (
93-
f"Got predict_proba of shape {y_pred.shape}, but need "
94-
"classifier with two classes."
95-
)
96-
raise ValueError(err_msg)
197+
y_pred = _process_predict_proba(
198+
y_pred=y_pred,
199+
target_type=target_type,
200+
classes=classes,
201+
pos_label=pos_label,
202+
)
97203
elif prediction_method.__name__ == "decision_function":
98-
if target_type == "binary":
99-
if pos_label == classes[0]:
100-
y_pred *= -1
204+
y_pred = _process_decision_function(
205+
y_pred=y_pred,
206+
target_type=target_type,
207+
classes=classes,
208+
pos_label=pos_label,
209+
)
101210
else: # estimator is a regressor
102211
if response_method != "predict":
103212
raise ValueError(

sklearn/utils/tests/test_response.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,17 @@
11
import numpy as np
22
import pytest
33

4-
from sklearn.datasets import load_iris, make_classification, make_regression
4+
from sklearn.datasets import (
5+
load_iris,
6+
make_classification,
7+
make_multilabel_classification,
8+
make_regression,
9+
)
510
from sklearn.linear_model import (
611
LinearRegression,
712
LogisticRegression,
813
)
14+
from sklearn.multioutput import ClassifierChain
915
from sklearn.preprocessing import scale
1016
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
1117
from sklearn.utils._mocking import _MockEstimatorOnOffPrediction
@@ -230,3 +236,26 @@ def test_get_response_values_multiclass(estimator, response_method):
230236
assert predictions.shape == (X.shape[0], len(estimator.classes_))
231237
if response_method == "predict_proba":
232238
assert np.logical_and(predictions >= 0, predictions <= 1).all()
239+
240+
241+
@pytest.mark.parametrize(
242+
"response_method", ["predict_proba", "decision_function", "predict"]
243+
)
244+
def test_get_response_values_multilabel_indicator(response_method):
245+
X, Y = make_multilabel_classification(random_state=0)
246+
estimator = ClassifierChain(LogisticRegression()).fit(X, Y)
247+
248+
y_pred, pos_label = _get_response_values(
249+
estimator, X, response_method=response_method
250+
)
251+
assert pos_label is None
252+
assert y_pred.shape == Y.shape
253+
254+
if response_method == "predict_proba":
255+
assert np.logical_and(y_pred >= 0, y_pred <= 1).all()
256+
elif response_method == "decision_function":
257+
# values returned by `decision_function` are not bounded in [0, 1]
258+
assert (y_pred < 0).sum() > 0
259+
assert (y_pred > 1).sum() > 0
260+
else: # response_method == "predict"
261+
assert np.logical_or(y_pred == 0, y_pred == 1).all()

0 commit comments

Comments
 (0)