Commit 8e9cd7d

OmarManzoor and ogrisel authored
FIX CalibratedClassifierCV with sigmoid and large confidence scores (scikit-learn#26913)
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
1 parent 6c18131 commit 8e9cd7d

File tree

3 files changed: +125 -3 lines changed

doc/whats_new/v1.4.rst

Lines changed: 7 additions & 0 deletions

@@ -78,6 +78,13 @@ Changelog
   and all metadata are passed as keyword arguments. :pr:`26909` by `Adrin
   Jalali`_.
 
+:mod:`sklearn.calibration`
+..........................
+
+- |Fix| :class:`calibration.CalibratedClassifierCV` can now handle models that
+  produce large prediction scores. Before it was numerically unstable.
+  :pr:`26913` by :user:`Omar Salman <OmarManzoor>`.
+
 :mod:`sklearn.cluster`
 ......................
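A minimal sketch (not part of the commit) of the instability this changelog entry refers to: when raw decision scores are on the order of 1e5, Platt's sigmoid saturates to exactly 0.0 or 1.0 in float64, so the log-likelihood terms of the calibration objective stop being finite or informative. The score values and parameter choices below are illustrative only.

# Illustrative only: float64 saturation of the unscaled Platt sigmoid.
import numpy as np
from scipy.special import expit

F = np.array([-1e5, -30.0, 0.0, 30.0, 1e5])  # hypothetical decision scores
A, B = 1.0, 0.0                              # arbitrary Platt parameters
p = expit(-(A * F + B))                      # Platt's P(y=1 | F)
print(p)              # the extreme scores map to exactly 1.0 and 0.0
print(np.log(p[-1]))  # -inf: this log-loss term can no longer guide the solver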
sklearn/calibration.py

Lines changed: 24 additions & 2 deletions

@@ -823,7 +823,11 @@ def predict_proba(self, X):
         return proba
 
 
-def _sigmoid_calibration(predictions, y, sample_weight=None):
+# The max_abs_prediction_threshold was approximated using
+# logit(np.finfo(np.float64).eps) which is about -36
+def _sigmoid_calibration(
+    predictions, y, sample_weight=None, max_abs_prediction_threshold=30
+):
     """Probability Calibration with sigmoid method (Platt 2000)
 
     Parameters
@@ -854,6 +858,20 @@ def _sigmoid_calibration(predictions, y, sample_weight=None):
 
     F = predictions  # F follows Platt's notations
 
+    scale_constant = 1.0
+    max_prediction = np.max(np.abs(F))
+
+    # If the predictions have large values we scale them in order to bring
+    # them within a suitable range. This has no effect on the final
+    # (prediction) result because linear models like Logistic Regression
+    # without a penalty are invariant to multiplying the features by a
+    # constant.
+    if max_prediction >= max_abs_prediction_threshold:
+        scale_constant = max_prediction
+        # We rescale the features in a copy: inplace rescaling could confuse
+        # the caller and make the code harder to reason about.
+        F = F / scale_constant
+
     # Bayesian priors (see Platt end of section 2.2):
     # It corresponds to the number of samples, taking into account the
     # `sample_weight`.
@@ -890,7 +908,11 @@ def grad(AB):
 
     AB0 = np.array([0.0, log((prior0 + 1.0) / (prior1 + 1.0))])
     AB_ = fmin_bfgs(objective, AB0, fprime=grad, disp=False)
-    return AB_[0], AB_[1]
+
+    # The tuned multiplicative parameter is converted back to the original
+    # input feature scale. The offset parameter does not need rescaling since
+    # we did not rescale the outcome variable.
+    return AB_[0] / scale_constant, AB_[1]
 
 
 class _SigmoidCalibration(RegressorMixin, BaseEstimator):
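The rescaling above works because fitting Platt's sigmoid on F / scale_constant yields a slope that is scale_constant times larger, so dividing the fitted slope back (the `AB_[0] / scale_constant` line) recovers parameters on the original score scale; the default threshold of 30 is a conservative bound derived from logit(np.finfo(np.float64).eps), roughly -36, beyond which the sigmoid output drops below float64 resolution, as the in-code comment notes. Below is a small sketch of that invariance. It calls the private helper sklearn.calibration._sigmoid_calibration with the signature introduced in this commit; being private, it may change in later releases.

# Sketch: forcing the internal rescaling path on well-behaved scores should
# return (a, b) giving the same calibrated probabilities as the unscaled fit,
# up to the tolerance of the BFGS solver.
import numpy as np
from scipy.special import expit
from sklearn.calibration import _sigmoid_calibration  # private helper

rng = np.random.default_rng(0)
y = rng.integers(0, 2, size=200)
F = rng.normal(size=200) + 3.0 * (y - 0.5)  # scores well below the threshold

a_plain, b_plain = _sigmoid_calibration(F, y)  # max |F| < 30: no rescaling
a_forced, b_forced = _sigmoid_calibration(
    F, y, max_abs_prediction_threshold=1e-3  # tiny threshold forces rescaling
)

p_plain = expit(-(a_plain * F + b_plain))  # Platt's P(y=1 | F)
p_forced = expit(-(a_forced * F + b_forced))
np.testing.assert_allclose(p_plain, p_forced, atol=1e-4)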

sklearn/tests/test_calibration.py

Lines changed: 94 additions & 1 deletion

@@ -25,12 +25,14 @@
 from sklearn.feature_extraction import DictVectorizer
 from sklearn.impute import SimpleImputer
 from sklearn.isotonic import IsotonicRegression
-from sklearn.linear_model import LogisticRegression
+from sklearn.linear_model import LogisticRegression, SGDClassifier
 from sklearn.metrics import brier_score_loss
 from sklearn.model_selection import (
     KFold,
     LeaveOneOut,
+    check_cv,
     cross_val_predict,
+    cross_val_score,
     train_test_split,
 )
 from sklearn.naive_bayes import MultinomialNB
@@ -996,3 +998,94 @@ def fit(self, X, y, sample_weight=None, fit_param=None):
     CalibratedClassifierCV(estimator=TestClassifier()).fit(
         *data, fit_param=np.ones(len(data[1]) + 1)
     )
+
+
+def test_calibrated_classifier_cv_works_with_large_confidence_scores(
+    global_random_seed,
+):
+    """Test that :class:`CalibratedClassifierCV` works with large confidence
+    scores when using the `sigmoid` method, particularly with the
+    :class:`SGDClassifier`.
+
+    Non-regression test for issue #26766.
+    """
+    prob = 0.67
+    n = 1000
+    random_noise = np.random.default_rng(global_random_seed).normal(size=n)
+
+    y = np.array([1] * int(n * prob) + [0] * (n - int(n * prob)))
+    X = 1e5 * y.reshape((-1, 1)) + random_noise
+
+    # Check that the decision function of SGDClassifier produces predicted
+    # values that are quite large, for the data under consideration.
+    cv = check_cv(cv=None, y=y, classifier=True)
+    indices = cv.split(X, y)
+    for train, test in indices:
+        X_train, y_train = X[train], y[train]
+        X_test = X[test]
+        sgd_clf = SGDClassifier(loss="squared_hinge", random_state=global_random_seed)
+        sgd_clf.fit(X_train, y_train)
+        predictions = sgd_clf.decision_function(X_test)
+        assert (predictions > 1e4).any()
+
+    # Compare the CalibratedClassifierCV using the sigmoid method with the
+    # CalibratedClassifierCV using the isotonic method. The isotonic method
+    # is used for comparison because it is numerically stable.
+    clf_sigmoid = CalibratedClassifierCV(
+        SGDClassifier(loss="squared_hinge", random_state=global_random_seed),
+        method="sigmoid",
+    )
+    score_sigmoid = cross_val_score(clf_sigmoid, X, y, scoring="roc_auc")
+
+    clf_isotonic = CalibratedClassifierCV(
+        SGDClassifier(loss="squared_hinge", random_state=global_random_seed),
+        method="isotonic",
+    )
+    score_isotonic = cross_val_score(clf_isotonic, X, y, scoring="roc_auc")
+
+    # The ROC AUC score should be the same because it is invariant under
+    # strictly monotonic transformations of the scores.
+    assert_allclose(score_sigmoid, score_isotonic)
+
+
+def test_sigmoid_calibration_max_abs_prediction_threshold(global_random_seed):
+    random_state = np.random.RandomState(seed=global_random_seed)
+    n = 100
+    y = random_state.randint(0, 2, size=n)
+
+    # Check that for small enough predictions ranging from -2 to 2, the
+    # threshold value has no impact on the outcome.
+    predictions_small = random_state.uniform(low=-2, high=2, size=100)
+
+    # Using a threshold lower than the maximum absolute value of the
+    # predictions enables internal re-scaling by max(abs(predictions_small)).
+    threshold_1 = 0.1
+    a1, b1 = _sigmoid_calibration(
+        predictions=predictions_small,
+        y=y,
+        max_abs_prediction_threshold=threshold_1,
+    )
+
+    # Using a larger threshold disables rescaling.
+    threshold_2 = 10
+    a2, b2 = _sigmoid_calibration(
+        predictions=predictions_small,
+        y=y,
+        max_abs_prediction_threshold=threshold_2,
+    )
+
+    # Using the default threshold of 30 also disables the scaling.
+    a3, b3 = _sigmoid_calibration(
+        predictions=predictions_small,
+        y=y,
+    )
+
+    # The tolerance depends on the underlying quasi-Newton solver, which is
+    # not too strict by default.
+    atol = 1e-6
+    assert_allclose(a1, a2, atol=atol)
+    assert_allclose(a2, a3, atol=atol)
+    assert_allclose(b1, b2, atol=atol)
+    assert_allclose(b2, b3, atol=atol)
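For completeness, a hedged public-API sketch of the scenario the first test exercises: an SGDClassifier whose decision_function produces scores around 1e5, wrapped in CalibratedClassifierCV with the sigmoid method. The synthetic data mirrors the test above; the exact probabilities depend on the seed and estimator settings.

# Usage sketch (illustrative, synthetic data): sigmoid calibration on top of
# very large raw decision scores, the case addressed by this commit.
import numpy as np
from sklearn.calibration import CalibratedClassifierCV
from sklearn.linear_model import SGDClassifier

rng = np.random.default_rng(0)
n = 1000
y = np.array([1] * 670 + [0] * (n - 670))
X = 1e5 * y.reshape(-1, 1) + rng.normal(size=n).reshape(-1, 1)

clf = CalibratedClassifierCV(
    SGDClassifier(loss="squared_hinge", random_state=0), method="sigmoid"
)
clf.fit(X, y)
print(clf.predict_proba(X)[:3])  # calibrated probabilities for three samples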
