Skip to content

Commit a5d7f9e

Browse files
Fix OneVsRest predict_proba is all zeros when positive class is never predicted (scikit-learn#31228)
Co-authored-by: Jérémie du Boisberranger <jeremie@probabl.ai>
1 parent f44350d commit a5d7f9e

File tree

3 files changed

+35
-2
lines changed

3 files changed

+35
-2
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
- The `predict_proba` method of :class:`sklearn.multiclass.OneVsRestClassifier` now
2+
returns zero for all classes when none of the inner estimators ever predicts its positive
3+
class.
4+
By :user:`Luis M. B. Varona <Luis-Varona>`, :user:`Marc Bresson <MarcBresson>`, and
5+
:user:`Jérémie du Boisberranger <jeremiedbb>`.

sklearn/multiclass.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -553,8 +553,10 @@ def predict_proba(self, X):
553553
Y = np.concatenate(((1 - Y), Y), axis=1)
554554

555555
if not self.multilabel_:
556-
# Then, probabilities should be normalized to 1.
557-
Y /= np.sum(Y, axis=1)[:, np.newaxis]
556+
# Then, (nonzero) sample probability distributions should be normalized.
557+
row_sums = np.sum(Y, axis=1)[:, np.newaxis]
558+
np.divide(Y, row_sums, out=Y, where=row_sums != 0)
559+
558560
return Y
559561

560562
@available_if(_estimators_has("decision_function"))

sklearn/tests/test_multiclass.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from numpy.testing import assert_allclose
77

88
from sklearn import datasets, svm
9+
from sklearn.base import BaseEstimator, ClassifierMixin
910
from sklearn.datasets import load_breast_cancer
1011
from sklearn.exceptions import NotFittedError
1112
from sklearn.impute import SimpleImputer
@@ -429,6 +430,31 @@ def test_ovr_single_label_predict_proba():
429430
assert not (pred - Y_pred).any()
430431

431432

433+
def test_ovr_single_label_predict_proba_zero():
434+
"""Check that predict_proba returns all zeros when the base estimator
435+
never predicts the positive class.
436+
"""
437+
438+
class NaiveBinaryClassifier(BaseEstimator, ClassifierMixin):
439+
def fit(self, X, y):
440+
self.classes_ = np.unique(y)
441+
return self
442+
443+
def predict_proba(self, X):
444+
proba = np.ones((len(X), 2))
445+
# Probability of being the positive class is always 0
446+
proba[:, 1] = 0
447+
return proba
448+
449+
base_clf = NaiveBinaryClassifier()
450+
X, y = iris.data, iris.target # Three-class problem with 150 samples
451+
452+
clf = OneVsRestClassifier(base_clf).fit(X, y)
453+
y_proba = clf.predict_proba(X)
454+
455+
assert_allclose(y_proba, 0.0)
456+
457+
432458
def test_ovr_multilabel_decision_function():
433459
X, Y = datasets.make_multilabel_classification(
434460
n_samples=100,

0 commit comments

Comments
 (0)