Skip to content

Commit 18dc863

Browse files
TST check that binary only classifiers fail on multiclass data (scikit-learn#29874)
Co-authored-by: Guillaume Lemaitre <guillaume@probabl.ai>
1 parent 89719ab commit 18dc863

File tree

4 files changed

+115
-12
lines changed

4 files changed

+115
-12
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
- :func:`~sklearn.utils.estimator_checks.check_estimator` and
2+
:func:`~sklearn.utils.estimator_checks.parametrize_with_checks` now check and fail if
3+
the classifier has the `tags.classifier_tags.multi_class = False` tag but does not
4+
fail on multi-class data.
5+
By `Adrin Jalali`_.

sklearn/utils/estimator_checks.py

Lines changed: 58 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,9 @@ def _yield_classifier_checks(classifier):
196196
):
197197
yield check_class_weight_balanced_linear_classifier
198198

199+
if not tags.classifier_tags.multi_class:
200+
yield check_classifier_not_supporting_multiclass
201+
199202

200203
@ignore_warnings(category=FutureWarning)
201204
def check_supervised_y_no_nan(name, estimator_orig):
@@ -1206,7 +1209,13 @@ def check_dtype_object(name, estimator_orig):
12061209
if hasattr(estimator, "transform"):
12071210
estimator.transform(X)
12081211

1209-
with raises(Exception, match="Unknown label type", may_pass=True):
1212+
err_msg = (
1213+
"y with unknown label type is passed, but an error with no proper message "
1214+
"is raised. You can use `type_of_target(..., raise_unknown=True)` to check "
1215+
"and raise the right error, or include 'Unknown label type' in the error "
1216+
"message."
1217+
)
1218+
with raises(Exception, match="Unknown label type", may_pass=True, err_msg=err_msg):
12101219
estimator.fit(X, y.astype(object))
12111220

12121221
if not tags.input_tags.string:
@@ -3634,9 +3643,15 @@ def check_classifiers_regression_target(name, estimator_orig):
36343643

36353644
X = _enforce_estimator_tags_X(estimator_orig, X)
36363645
e = clone(estimator_orig)
3637-
msg = "Unknown label type: "
3646+
err_msg = (
3647+
"When a classifier is passed a continuous target, it should raise a ValueError"
3648+
" with a message containing 'Unknown label type: ' or a message indicating that"
3649+
" a continuous target is passed and the message should include the word"
3650+
" 'continuous'"
3651+
)
3652+
msg = "Unknown label type: |continuous"
36383653
if not get_tags(e).no_validation:
3639-
with raises(ValueError, match=msg):
3654+
with raises(ValueError, match=msg, err_msg=err_msg):
36403655
e.fit(X, y)
36413656

36423657

@@ -4737,3 +4752,43 @@ def check_do_not_raise_errors_in_init_or_set_params(name, estimator_orig):
47374752

47384753
# Also do does not raise
47394754
est.set_params(**new_params)
4755+
4756+
4757+
def check_classifier_not_supporting_multiclass(name, estimator_orig):
4758+
"""Check that if the classifier has tags.classifier_tags.multi_class=False,
4759+
then it should raise a ValueError when calling fit with a multiclass dataset.
4760+
4761+
This test is not yielded if the tag is not False.
4762+
"""
4763+
estimator = clone(estimator_orig)
4764+
set_random_state(estimator)
4765+
4766+
X, y = make_classification(
4767+
n_samples=100,
4768+
n_classes=3,
4769+
n_informative=3,
4770+
n_clusters_per_class=1,
4771+
random_state=0,
4772+
)
4773+
err_msg = """\
4774+
The estimator tag `tags.classifier_tags.multi_class` is False for {name}
4775+
which means it does not support multiclass classification. However, it does
4776+
not raise the right `ValueError` when calling fit with a multiclass dataset,
4777+
including the error message 'Only binary classification is supported.' This
4778+
can be achieved by the following pattern:
4779+
4780+
y_type = type_of_target(y, input_name='y', raise_unknown=True)
4781+
if y_type != 'binary':
4782+
raise ValueError(
4783+
'Only binary classification is supported. The type of the target '
4784+
f'is {{y_type}}.'
4785+
)
4786+
""".format(
4787+
name=name
4788+
)
4789+
err_msg = textwrap.dedent(err_msg)
4790+
4791+
with raises(
4792+
ValueError, match="Only binary classification is supported.", err_msg=err_msg
4793+
):
4794+
estimator.fit(X, y)

sklearn/utils/multiclass.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ def check_classification_targets(y):
226226
)
227227

228228

229-
def type_of_target(y, input_name=""):
229+
def type_of_target(y, input_name="", raise_unknown=False):
230230
"""Determine the type of data indicated by the target.
231231
232232
Note that this type is the most specific type that can be inferred.
@@ -248,6 +248,12 @@ def type_of_target(y, input_name=""):
248248
249249
.. versionadded:: 1.1.0
250250
251+
raise_unknown : bool, default=False
252+
If `True`, raise an error when the type of target returned by
253+
:func:`~sklearn.utils.multiclass.type_of_target` is `"unknown"`.
254+
255+
.. versionadded:: 1.6
256+
251257
Returns
252258
-------
253259
target_type : str
@@ -298,6 +304,17 @@ def type_of_target(y, input_name=""):
298304
'multilabel-indicator'
299305
"""
300306
xp, is_array_api_compliant = get_namespace(y)
307+
308+
def _raise_or_return():
309+
"""Depending on the value of raise_unknown, either raise an error or return
310+
'unknown'.
311+
"""
312+
if raise_unknown:
313+
input = input_name if input_name else "data"
314+
raise ValueError(f"Unknown label type for {input}: {y!r}")
315+
else:
316+
return "unknown"
317+
301318
valid = (
302319
(isinstance(y, Sequence) or issparse(y) or hasattr(y, "__array__"))
303320
and not isinstance(y, str)
@@ -374,17 +391,17 @@ def type_of_target(y, input_name=""):
374391
# Invalid inputs
375392
if y.ndim not in (1, 2):
376393
# Number of dimension greater than 2: [[[1, 2]]]
377-
return "unknown"
394+
return _raise_or_return()
378395
if not min(y.shape):
379396
# Empty ndarray: []/[[]]
380397
if y.ndim == 1:
381398
# 1-D empty array: []
382399
return "binary" # []
383400
# 2-D empty array: [[]]
384-
return "unknown"
401+
return _raise_or_return()
385402
if not issparse(y) and y.dtype == object and not isinstance(y.flat[0], str):
386403
# [obj_1] and not ["label_1"]
387-
return "unknown"
404+
return _raise_or_return()
388405

389406
# Check if multioutput
390407
if y.ndim == 2 and y.shape[1] > 1:

sklearn/utils/tests/test_estimator_checks.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@
1515
from sklearn import config_context, get_config
1616
from sklearn.base import BaseEstimator, ClassifierMixin, OutlierMixin
1717
from sklearn.cluster import MiniBatchKMeans
18-
from sklearn.datasets import make_multilabel_classification
18+
from sklearn.datasets import (
19+
load_iris,
20+
make_multilabel_classification,
21+
)
1922
from sklearn.decomposition import PCA
20-
from sklearn.ensemble import ExtraTreesClassifier
2123
from sklearn.exceptions import ConvergenceWarning, SkipTestWarning
2224
from sklearn.linear_model import (
2325
LinearRegression,
@@ -46,6 +48,7 @@
4648
check_array_api_input,
4749
check_class_weight_balanced_linear_classifier,
4850
check_classifier_data_not_an_array,
51+
check_classifier_not_supporting_multiclass,
4952
check_classifiers_multilabel_output_format_decision_function,
5053
check_classifiers_multilabel_output_format_predict,
5154
check_classifiers_multilabel_output_format_predict_proba,
@@ -79,6 +82,7 @@
7982
)
8083
from sklearn.utils.fixes import CSR_CONTAINERS, SPARRAY_PRESENT
8184
from sklearn.utils.metaestimators import available_if
85+
from sklearn.utils.multiclass import type_of_target
8286
from sklearn.utils.validation import (
8387
check_array,
8488
check_is_fitted,
@@ -473,6 +477,15 @@ def partial_fit(self, X, y, classes=None, sample_weight=None):
473477

474478

475479
class TaggedBinaryClassifier(UntaggedBinaryClassifier):
480+
def fit(self, X, y):
481+
y_type = type_of_target(y, input_name="y", raise_unknown=True)
482+
if y_type != "binary":
483+
raise ValueError(
484+
"Only binary classification is supported. The type of the target "
485+
f"is {y_type}."
486+
)
487+
return super().fit(X, y)
488+
476489
# Toy classifier that only supports binary classification.
477490
def __sklearn_tags__(self):
478491
tags = super().__sklearn_tags__()
@@ -800,7 +813,6 @@ def test_check_estimator_transformer_no_mixin():
800813

801814
def test_check_estimator_clones():
802815
# check that check_estimator doesn't modify the estimator it receives
803-
from sklearn.datasets import load_iris
804816

805817
iris = load_iris()
806818

@@ -809,7 +821,6 @@ def test_check_estimator_clones():
809821
LinearRegression,
810822
SGDClassifier,
811823
PCA,
812-
ExtraTreesClassifier,
813824
MiniBatchKMeans,
814825
]:
815826
# without fitting
@@ -824,7 +835,7 @@ def test_check_estimator_clones():
824835
with ignore_warnings(category=ConvergenceWarning):
825836
est = Estimator()
826837
set_random_state(est)
827-
est.fit(iris.data + 10, iris.target)
838+
est.fit(iris.data, iris.target)
828839
old_hash = joblib.hash(est)
829840
check_estimator(est)
830841
assert old_hash == joblib.hash(est)
@@ -1420,6 +1431,21 @@ def _more_tags(self):
14201431
check_estimator_tags_renamed("OkayEstimator", OkayEstimator())
14211432

14221433

1434+
def test_check_classifier_not_supporting_multiclass():
1435+
"""Check that when the estimator has the wrong tags.classifier_tags.multi_class
1436+
set, the test fails."""
1437+
1438+
class BadEstimator(BaseEstimator):
1439+
# we don't actually need to define the tag here since we're running the test
1440+
# manually, and BaseEstimator defaults to multi_output=False.
1441+
def fit(self, X, y):
1442+
return self
1443+
1444+
msg = "The estimator tag `tags.classifier_tags.multi_class` is False"
1445+
with raises(AssertionError, match=msg):
1446+
check_classifier_not_supporting_multiclass("BadEstimator", BadEstimator())
1447+
1448+
14231449
# Test that set_output doesn't make the tests to fail.
14241450
def test_estimator_with_set_output():
14251451
# Doing this since pytest is not available for this file.

0 commit comments

Comments
 (0)