Skip to content

Commit fa8c15f

Browse files
authored
FIX unintentional sample_weight upcast in CalibratedClassifierCV (scikit-learn#30873)
1 parent 243d61a commit fa8c15f

File tree

3 files changed

+69
-12
lines changed

3 files changed

+69
-12
lines changed
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
- :class:`~calibration.CalibratedClassifierCV` now raises `FutureWarning`
2+
instead of `UserWarning` when passing `cv="prefit"`. By
3+
:user:`Olivier Grisel <ogrisel>`
4+
- :class:`~calibration.CalibratedClassifierCV` with `method="sigmoid"` no
5+
longer crashes when passing `float64`-dtyped `sample_weight` along with a
6+
base estimator that outputs `float32`-dtyped predictions. By :user:`Olivier
7+
Grisel <ogrisel>`

sklearn/calibration.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -318,9 +318,6 @@ def fit(self, X, y, sample_weight=None, **fit_params):
318318
"""
319319
check_classification_targets(y)
320320
X, y = indexable(X, y)
321-
if sample_weight is not None:
322-
sample_weight = _check_sample_weight(sample_weight, X)
323-
324321
estimator = self._get_estimator()
325322

326323
_ensemble = self.ensemble
@@ -333,7 +330,8 @@ def fit(self, X, y, sample_weight=None, **fit_params):
333330
warnings.warn(
334331
"The `cv='prefit'` option is deprecated in 1.6 and will be removed in"
335332
" 1.8. You can use CalibratedClassifierCV(FrozenEstimator(estimator))"
336-
" instead."
333+
" instead.",
334+
category=FutureWarning,
337335
)
338336
# `classes_` should be consistent with that of estimator
339337
check_is_fitted(self.estimator, attributes=["classes_"])
@@ -348,6 +346,13 @@ def fit(self, X, y, sample_weight=None, **fit_params):
348346
# Reshape binary output from `(n_samples,)` to `(n_samples, 1)`
349347
predictions = predictions.reshape(-1, 1)
350348

349+
if sample_weight is not None:
350+
# Check that the sample_weight dtype is consistent with the predictions
351+
# to avoid unintentional upcasts.
352+
sample_weight = _check_sample_weight(
353+
sample_weight, predictions, dtype=predictions.dtype
354+
)
355+
351356
calibrated_classifier = _fit_calibrator(
352357
estimator,
353358
predictions,
@@ -457,6 +462,13 @@ def fit(self, X, y, sample_weight=None, **fit_params):
457462
)
458463
predictions = predictions.reshape(-1, 1)
459464

465+
if sample_weight is not None:
466+
# Check that the sample_weight dtype is consistent with the
467+
# predictions to avoid unintentional upcasts.
468+
sample_weight = _check_sample_weight(
469+
sample_weight, predictions, dtype=predictions.dtype
470+
)
471+
460472
this_estimator.fit(X, y, **routed_params.estimator.fit)
461473
# Note: Here we don't pass on fit_params because the supported
462474
# calibrators don't support fit_params anyway
@@ -622,7 +634,13 @@ def _fit_classifier_calibrator_pair(
622634
# Reshape binary output from `(n_samples,)` to `(n_samples, 1)`
623635
predictions = predictions.reshape(-1, 1)
624636

625-
sw_test = None if sample_weight is None else _safe_indexing(sample_weight, test)
637+
if sample_weight is not None:
638+
# Check that the sample_weight dtype is consistent with the predictions
639+
# to avoid unintentional upcasts.
640+
sample_weight = _check_sample_weight(sample_weight, X, dtype=predictions.dtype)
641+
sw_test = _safe_indexing(sample_weight, test)
642+
else:
643+
sw_test = None
626644
calibrated_classifier = _fit_calibrator(
627645
estimator, predictions, y_test, classes, method, sample_weight=sw_test
628646
)

sklearn/tests/test_calibration.py

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -579,8 +579,12 @@ def test_calibration_attributes(clf, cv):
579579
X, y = make_classification(n_samples=10, n_features=5, n_classes=2, random_state=7)
580580
if cv == "prefit":
581581
clf = clf.fit(X, y)
582-
calib_clf = CalibratedClassifierCV(clf, cv=cv)
583-
calib_clf.fit(X, y)
582+
calib_clf = CalibratedClassifierCV(clf, cv=cv)
583+
with pytest.warns(FutureWarning):
584+
calib_clf.fit(X, y)
585+
else:
586+
calib_clf = CalibratedClassifierCV(clf, cv=cv)
587+
calib_clf.fit(X, y)
584588

585589
if cv == "prefit":
586590
assert_array_equal(calib_clf.classes_, clf.classes_)
@@ -1077,20 +1081,48 @@ def test_sigmoid_calibration_max_abs_prediction_threshold(global_random_seed):
10771081
assert_allclose(b2, b3, atol=atol)
10781082

10791083

1080-
def test_float32_predict_proba(data):
1084+
@pytest.mark.parametrize("use_sample_weight", [True, False])
1085+
@pytest.mark.parametrize("method", ["sigmoid", "isotonic"])
1086+
def test_float32_predict_proba(data, use_sample_weight, method):
10811087
"""Check that CalibratedClassifierCV works with float32 predict proba.
10821088
1083-
Non-regression test for gh-28245.
1089+
Non-regression test for gh-28245 and gh-28247.
10841090
"""
1091+
if use_sample_weight:
1092+
# Use dtype=np.float64 to check that this does not trigger an
1093+
# unintentional upcasting: the dtype of the base estimator should
1094+
# control the dtype of the final model. In particular, the
1095+
# sigmoid calibrator relies on inputs (predictions and sample weights)
1096+
# with consistent dtypes because it is partially written in Cython.
1097+
# As this test forces the predictions to be `float32`, we want to check
1098+
# that `CalibratedClassifierCV` internally converts `sample_weight` to
1099+
# the same dtype to avoid crashing the Cython call.
1100+
sample_weight = np.ones_like(data[1], dtype=np.float64)
1101+
else:
1102+
sample_weight = None
10851103

10861104
class DummyClassifer32(DummyClassifier):
10871105
def predict_proba(self, X):
10881106
return super().predict_proba(X).astype(np.float32)
10891107

10901108
model = DummyClassifer32()
1091-
calibrator = CalibratedClassifierCV(model)
1092-
# Does not raise an error
1093-
calibrator.fit(*data)
1109+
calibrator = CalibratedClassifierCV(model, method=method)
1110+
# Does not raise an error.
1111+
calibrator.fit(*data, sample_weight=sample_weight)
1112+
1113+
# Check with frozen prefit model
1114+
model = DummyClassifer32().fit(*data, sample_weight=sample_weight)
1115+
calibrator = CalibratedClassifierCV(FrozenEstimator(model), method=method)
1116+
# Does not raise an error.
1117+
calibrator.fit(*data, sample_weight=sample_weight)
1118+
1119+
# TODO(1.8): remove me once the deprecation period is over.
1120+
# Check with prefit model using the deprecated cv="prefit" argument:
1121+
model = DummyClassifer32().fit(*data, sample_weight=sample_weight)
1122+
calibrator = CalibratedClassifierCV(model, method=method, cv="prefit")
1123+
# Does not raise an error.
1124+
with pytest.warns(FutureWarning):
1125+
calibrator.fit(*data, sample_weight=sample_weight)
10941126

10951127

10961128
def test_error_less_class_samples_than_folds():

0 commit comments

Comments
 (0)