Skip to content

Commit 95e9459

Browse files
authored
TST remove _required_parameters and improve instance generation (scikit-learn#29707)
1 parent eb29207 commit 95e9459

File tree

13 files changed

+205
-141
lines changed

13 files changed

+205
-141
lines changed

doc/developers/develop.rst

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -562,15 +562,6 @@ for your estimator's tags. For example::
562562
You can create a new subclass of :class:`~sklearn.utils.Tags` if you wish
563563
to add new tags to the existing set.
564564

565-
In addition to the tags, estimators also need to declare any non-optional
566-
parameters to ``__init__`` in the ``_required_parameters`` class attribute,
567-
which is a list or tuple. If ``_required_parameters`` is only
568-
``["estimator"]`` or ``["base_estimator"]``, then the estimator will be
569-
instantiated with an instance of ``LogisticRegression`` (or
570-
``RidgeRegression`` if the estimator is a regressor) in the tests. The choice
571-
of these two models is somewhat idiosyncratic but both should provide robust
572-
closed-form solutions.
573-
574565
.. _developer_api_set_output:
575566

576567
Developer API for `set_output`

sklearn/base.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1037,9 +1037,12 @@ def fit_predict(self, X, y=None, **kwargs):
10371037
class MetaEstimatorMixin:
10381038
"""Mixin class for all meta estimators in scikit-learn.
10391039
1040-
This mixin defines the following functionality:
1040+
This mixin is empty, and only exists to indicate that the estimator is a
1041+
meta-estimator.
10411042
1042-
- define `_required_parameters` that specify the mandatory `estimator` parameter.
1043+
.. versionchanged:: 1.6
1044+
The `_required_parameters` attribute has been removed; it is no longer needed
1045+
because the common tests were refactored and do not rely on it anymore.
10431046
10441047
Examples
10451048
--------
@@ -1061,8 +1064,6 @@ class MetaEstimatorMixin:
10611064
LogisticRegression()
10621065
"""
10631066

1064-
_required_parameters = ["estimator"]
1065-
10661067

10671068
class MultiOutputMixin:
10681069
"""Mixin to mark estimators that support multioutput."""

sklearn/compose/_column_transformer.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -287,8 +287,6 @@ class ColumnTransformer(TransformerMixin, _BaseComposition):
287287
:ref:`sphx_glr_auto_examples_compose_plot_column_transformer_mixed_types.py`.
288288
"""
289289

290-
_required_parameters = ["transformers"]
291-
292290
_parameter_constraints: dict = {
293291
"transformers": [list, Hidden(tuple)],
294292
"remainder": [
@@ -1322,6 +1320,21 @@ def get_metadata_routing(self):
13221320

13231321
return router
13241322

1323+
def __sklearn_tags__(self):
1324+
tags = super().__sklearn_tags__()
1325+
tags._xfail_checks = {
1326+
"check_estimators_empty_data_messages": "FIXME",
1327+
"check_estimators_nan_inf": "FIXME",
1328+
"check_estimator_sparse_array": "FIXME",
1329+
"check_estimator_sparse_matrix": "FIXME",
1330+
"check_transformer_data_not_an_array": "FIXME",
1331+
"check_fit1d": "FIXME",
1332+
"check_fit2d_predict1d": "FIXME",
1333+
"check_complex_data": "FIXME",
1334+
"check_fit2d_1feature": "FIXME",
1335+
}
1336+
return tags
1337+
13251338

13261339
def _check_X(X):
13271340
"""Use check_array only when necessary, e.g. on lists and other non-array-likes."""

sklearn/decomposition/_dict_learning.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1279,8 +1279,6 @@ class SparseCoder(_BaseSparseCoding, BaseEstimator):
12791279
[ 0., 1., 1., 0., 0.]])
12801280
"""
12811281

1282-
_required_parameters = ["dictionary"]
1283-
12841282
def __init__(
12851283
self,
12861284
dictionary,

sklearn/ensemble/_base.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
# SPDX-License-Identifier: BSD-3-Clause
55

66
from abc import ABCMeta, abstractmethod
7-
from typing import List
87

98
import numpy as np
109
from joblib import effective_n_jobs
@@ -106,9 +105,6 @@ class BaseEnsemble(MetaEstimatorMixin, BaseEstimator, metaclass=ABCMeta):
106105
The collection of fitted base estimators.
107106
"""
108107

109-
# overwrite _required_parameters from MetaEstimatorMixin
110-
_required_parameters: List[str] = []
111-
112108
@abstractmethod
113109
def __init__(
114110
self,
@@ -200,8 +196,6 @@ class _BaseHeterogeneousEnsemble(
200196
appear in `estimators_`.
201197
"""
202198

203-
_required_parameters = ["estimators"]
204-
205199
@property
206200
def named_estimators(self):
207201
"""Dictionary to access any fitted sub-estimators by name.

sklearn/model_selection/_classification_threshold.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,6 @@ class BaseThresholdClassifier(ClassifierMixin, MetaEstimatorMixin, BaseEstimator
8787
error.
8888
"""
8989

90-
_required_parameters = ["estimator"]
9190
_parameter_constraints: dict = {
9291
"estimator": [
9392
HasMethods(["fit", "predict_proba"]),

sklearn/model_selection/_search.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1532,8 +1532,6 @@ class GridSearchCV(BaseSearchCV):
15321532
'std_fit_time', 'std_score_time', 'std_test_score']
15331533
"""
15341534

1535-
_required_parameters = ["estimator", "param_grid"]
1536-
15371535
_parameter_constraints: dict = {
15381536
**BaseSearchCV._parameter_constraints,
15391537
"param_grid": [dict, list],
@@ -1913,8 +1911,6 @@ class RandomizedSearchCV(BaseSearchCV):
19131911
{'C': np.float64(2...), 'penalty': 'l1'}
19141912
"""
19151913

1916-
_required_parameters = ["estimator", "param_distributions"]
1917-
19181914
_parameter_constraints: dict = {
19191915
**BaseSearchCV._parameter_constraints,
19201916
"param_distributions": [dict, list],

sklearn/model_selection/_search_successive_halving.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,9 @@ def __sklearn_tags__(self):
378378
"Fail during parameter check since min/max resources requires"
379379
" more samples"
380380
),
381+
"check_estimators_nan_inf": "FIXME",
382+
"check_classifiers_one_label_sample_weights": "FIXME",
383+
"check_fit2d_1feature": "FIXME",
381384
}
382385
)
383386
return tags
@@ -668,8 +671,6 @@ class HalvingGridSearchCV(BaseSuccessiveHalving):
668671
{'max_depth': None, 'min_samples_split': 10, 'n_estimators': 9}
669672
"""
670673

671-
_required_parameters = ["estimator", "param_grid"]
672-
673674
_parameter_constraints: dict = {
674675
**BaseSuccessiveHalving._parameter_constraints,
675676
"param_grid": [dict, list],
@@ -1018,8 +1019,6 @@ class HalvingRandomSearchCV(BaseSuccessiveHalving):
10181019
{'max_depth': None, 'min_samples_split': 10, 'n_estimators': 9}
10191020
"""
10201021

1021-
_required_parameters = ["estimator", "param_distributions"]
1022-
10231022
_parameter_constraints: dict = {
10241023
**BaseSuccessiveHalving._parameter_constraints,
10251024
"param_distributions": [dict, list],

sklearn/pipeline.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -152,8 +152,6 @@ class Pipeline(_BaseComposition):
152152
"""
153153

154154
# BaseEstimator interface
155-
_required_parameters = ["steps"]
156-
157155
_parameter_constraints: dict = {
158156
"steps": [list, Hidden(tuple)],
159157
"memory": [None, str, HasMethods(["cache"])],
@@ -1427,8 +1425,6 @@ class FeatureUnion(TransformerMixin, _BaseComposition):
14271425
:ref:`sphx_glr_auto_examples_compose_plot_feature_union.py`.
14281426
"""
14291427

1430-
_required_parameters = ["transformer_list"]
1431-
14321428
def __init__(
14331429
self,
14341430
transformer_list,
@@ -1882,6 +1878,15 @@ def get_metadata_routing(self):
18821878

18831879
return router
18841880

1881+
def __sklearn_tags__(self):
1882+
tags = super().__sklearn_tags__()
1883+
tags._xfail_checks = {
1884+
"check_estimators_overwrite_params": "FIXME",
1885+
"check_estimators_nan_inf": "FIXME",
1886+
"check_dont_overwrite_parameters": "FIXME",
1887+
}
1888+
return tags
1889+
18851890

18861891
def make_union(*transformers, n_jobs=None, verbose=False):
18871892
"""Construct a :class:`FeatureUnion` from the given transformers.

sklearn/tests/test_common.py

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,15 @@
2626
MeanShift,
2727
SpectralClustering,
2828
)
29+
from sklearn.compose import ColumnTransformer
2930
from sklearn.datasets import make_blobs
3031
from sklearn.exceptions import ConvergenceWarning, FitFailedWarning
32+
33+
# make it possible to discover experimental estimators when calling `all_estimators`
3134
from sklearn.experimental import (
3235
enable_halving_search_cv, # noqa
3336
enable_iterative_imputer, # noqa
3437
)
35-
36-
# make it possible to discover experimental estimators when calling `all_estimators`
3738
from sklearn.linear_model import LogisticRegression
3839
from sklearn.manifold import TSNE, Isomap, LocallyLinearEmbedding
3940
from sklearn.neighbors import (
@@ -43,7 +44,7 @@
4344
RadiusNeighborsClassifier,
4445
RadiusNeighborsRegressor,
4546
)
46-
from sklearn.pipeline import make_pipeline
47+
from sklearn.pipeline import FeatureUnion, make_pipeline
4748
from sklearn.preprocessing import (
4849
FunctionTransformer,
4950
MinMaxScaler,
@@ -54,11 +55,9 @@
5455
from sklearn.utils import all_estimators
5556
from sklearn.utils._tags import get_tags
5657
from sklearn.utils._test_common.instance_generator import (
57-
_generate_column_transformer_instances,
5858
_generate_pipeline,
5959
_generate_search_cv_instances,
6060
_get_check_estimator_ids,
61-
_set_checking_parameters,
6261
_tested_estimators,
6362
)
6463
from sklearn.utils._testing import (
@@ -139,7 +138,6 @@ def test_estimators(estimator, check, request):
139138
with ignore_warnings(
140139
category=(FutureWarning, ConvergenceWarning, UserWarning, LinAlgWarning)
141140
):
142-
_set_checking_parameters(estimator)
143141
check(estimator)
144142

145143

@@ -285,7 +283,6 @@ def check_field_types(tags, defaults):
285283
"estimator", _tested_estimators(), ids=_get_check_estimator_ids
286284
)
287285
def test_check_n_features_in_after_fitting(estimator):
288-
_set_checking_parameters(estimator)
289286
check_n_features_in_after_fitting(estimator.__class__.__name__, estimator)
290287

291288

@@ -324,7 +321,8 @@ def _estimators_that_predict_in_fit():
324321
"estimator", column_name_estimators, ids=_get_check_estimator_ids
325322
)
326323
def test_pandas_column_name_consistency(estimator):
327-
_set_checking_parameters(estimator)
324+
if isinstance(estimator, ColumnTransformer):
325+
pytest.skip("ColumnTransformer is not tested here")
328326
with ignore_warnings(category=(FutureWarning)):
329327
with warnings.catch_warnings(record=True) as record:
330328
check_dataframe_column_names_consistency(
@@ -360,7 +358,6 @@ def _include_in_get_feature_names_out_check(transformer):
360358
"transformer", GET_FEATURES_OUT_ESTIMATORS, ids=_get_check_estimator_ids
361359
)
362360
def test_transformers_get_feature_names_out(transformer):
363-
_set_checking_parameters(transformer)
364361

365362
with ignore_warnings(category=(FutureWarning)):
366363
check_transformer_get_feature_names_out(
@@ -381,7 +378,6 @@ def test_transformers_get_feature_names_out(transformer):
381378
)
382379
def test_estimators_get_feature_names_out_error(estimator):
383380
estimator_name = estimator.__class__.__name__
384-
_set_checking_parameters(estimator)
385381
check_get_feature_names_out_error(estimator_name, estimator)
386382

387383

@@ -409,14 +405,14 @@ def test_estimators_do_not_raise_errors_in_init_or_set_params(Estimator):
409405
chain(
410406
_tested_estimators(),
411407
_generate_pipeline(),
412-
_generate_column_transformer_instances(),
413408
_generate_search_cv_instances(),
414409
),
415410
ids=_get_check_estimator_ids,
416411
)
417412
def test_check_param_validation(estimator):
413+
if isinstance(estimator, FeatureUnion):
414+
pytest.skip("FeatureUnion is not tested here")
418415
name = estimator.__class__.__name__
419-
_set_checking_parameters(estimator)
420416
check_param_validation(name, estimator)
421417

422418

@@ -481,7 +477,6 @@ def test_set_output_transform(estimator):
481477
f"Skipping check_set_output_transform for {name}: Does not support"
482478
" set_output API"
483479
)
484-
_set_checking_parameters(estimator)
485480
with ignore_warnings(category=(FutureWarning)):
486481
check_set_output_transform(estimator.__class__.__name__, estimator)
487482

@@ -505,7 +500,6 @@ def test_set_output_transform_configured(estimator, check_func):
505500
f"Skipping {check_func.__name__} for {name}: Does not support"
506501
" set_output API yet"
507502
)
508-
_set_checking_parameters(estimator)
509503
with ignore_warnings(category=(FutureWarning)):
510504
check_func(estimator.__class__.__name__, estimator)
511505

@@ -523,8 +517,6 @@ def test_check_inplace_ensure_writeable(estimator):
523517
else:
524518
raise SkipTest(f"{name} doesn't require writeable input.")
525519

526-
_set_checking_parameters(estimator)
527-
528520
# The following estimators can work inplace only with certain settings
529521
if name == "HDBSCAN":
530522
estimator.set_params(metric="precomputed", algorithm="brute")

0 commit comments

Comments
 (0)