Skip to content

Commit 9c266cf

Browse files
authored
MAINT Param validation: decorate all estimators with _fit_context (scikit-learn#26473)
1 parent 96878ba commit 9c266cf

File tree

109 files changed

+496
-336
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

109 files changed

+496
-336
lines changed

sklearn/base.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from .utils.validation import _num_features
2828
from .utils.validation import _check_feature_names_in
2929
from .utils.validation import _generate_get_feature_names_out
30-
from .utils.validation import check_is_fitted
30+
from .utils.validation import _is_fitted, check_is_fitted
3131
from .utils._metadata_requests import _MetadataRequester
3232
from .utils.validation import _get_feature_names
3333
from .utils._estimator_html_repr import estimator_html_repr
@@ -1131,7 +1131,13 @@ def decorator(fit_method):
11311131
@functools.wraps(fit_method)
11321132
def wrapper(estimator, *args, **kwargs):
11331133
global_skip_validation = get_config()["skip_parameter_validation"]
1134-
if not global_skip_validation:
1134+
1135+
# we don't want to validate again for each call to partial_fit
1136+
partial_fit_and_fitted = (
1137+
fit_method.__name__ == "partial_fit" and _is_fitted(estimator)
1138+
)
1139+
1140+
if not global_skip_validation and not partial_fit_and_fitted:
11351141
estimator._validate_params()
11361142

11371143
with config_context(

sklearn/calibration.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
RegressorMixin,
2626
clone,
2727
MetaEstimatorMixin,
28+
_fit_context,
2829
)
2930
from .preprocessing import label_binarize, LabelEncoder
3031
from .utils import (
@@ -318,6 +319,10 @@ def _get_estimator(self):
318319

319320
return estimator
320321

322+
@_fit_context(
323+
# CalibratedClassifierCV.estimator is not validated yet
324+
prefer_skip_nested_validation=False
325+
)
321326
def fit(self, X, y, sample_weight=None, **fit_params):
322327
"""Fit the calibrated model.
323328
@@ -341,8 +346,6 @@ def fit(self, X, y, sample_weight=None, **fit_params):
341346
self : object
342347
Returns an instance of self.
343348
"""
344-
self._validate_params()
345-
346349
check_classification_targets(y)
347350
X, y = indexable(X, y)
348351
if sample_weight is not None:

sklearn/cluster/_affinity_propagation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from ..exceptions import ConvergenceWarning
1414
from ..base import BaseEstimator, ClusterMixin
15+
from ..base import _fit_context
1516
from ..utils import check_random_state
1617
from ..utils._param_validation import Interval, StrOptions, validate_params
1718
from ..utils.validation import check_is_fitted
@@ -469,6 +470,7 @@ def __init__(
469470
def _more_tags(self):
470471
return {"pairwise": self.affinity == "precomputed"}
471472

473+
@_fit_context(prefer_skip_nested_validation=True)
472474
def fit(self, X, y=None):
473475
"""Fit the clustering from features, or affinity matrix.
474476
@@ -488,8 +490,6 @@ def fit(self, X, y=None):
488490
self
489491
Returns the instance itself.
490492
"""
491-
self._validate_params()
492-
493493
if self.affinity == "precomputed":
494494
accept_sparse = False
495495
else:

sklearn/cluster/_agglomerative.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from scipy.sparse.csgraph import connected_components
1717

1818
from ..base import BaseEstimator, ClusterMixin, ClassNamePrefixFeaturesOutMixin
19+
from ..base import _fit_context
1920
from ..metrics.pairwise import paired_distances
2021
from ..metrics.pairwise import _VALID_METRICS
2122
from ..metrics import DistanceMetric
@@ -950,6 +951,7 @@ def __init__(
950951
self.metric = metric
951952
self.compute_distances = compute_distances
952953

954+
@_fit_context(prefer_skip_nested_validation=True)
953955
def fit(self, X, y=None):
954956
"""Fit the hierarchical clustering from features, or distance matrix.
955957
@@ -968,7 +970,6 @@ def fit(self, X, y=None):
968970
self : object
969971
Returns the fitted instance.
970972
"""
971-
self._validate_params()
972973
X = self._validate_data(X, ensure_min_samples=2)
973974
return self._fit(X)
974975

@@ -1324,6 +1325,7 @@ def __init__(
13241325
)
13251326
self.pooling_func = pooling_func
13261327

1328+
@_fit_context(prefer_skip_nested_validation=True)
13271329
def fit(self, X, y=None):
13281330
"""Fit the hierarchical clustering on the data.
13291331
@@ -1340,7 +1342,6 @@ def fit(self, X, y=None):
13401342
self : object
13411343
Returns the transformer.
13421344
"""
1343-
self._validate_params()
13441345
X = self._validate_data(X, ensure_min_features=2)
13451346
super()._fit(X.T)
13461347
self._n_features_out = self.n_clusters_

sklearn/cluster/_bicluster.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from . import KMeans, MiniBatchKMeans
1515
from ..base import BaseEstimator, BiclusterMixin
16+
from ..base import _fit_context
1617
from ..utils import check_random_state
1718
from ..utils import check_scalar
1819

@@ -118,6 +119,7 @@ def __init__(
118119
def _check_parameters(self, n_samples):
119120
"""Validate parameters depending on the input data."""
120121

122+
@_fit_context(prefer_skip_nested_validation=True)
121123
def fit(self, X, y=None):
122124
"""Create a biclustering for X.
123125
@@ -134,8 +136,6 @@ def fit(self, X, y=None):
134136
self : object
135137
SpectralBiclustering instance.
136138
"""
137-
self._validate_params()
138-
139139
X = self._validate_data(X, accept_sparse="csr", dtype=np.float64)
140140
self._check_parameters(X.shape[0])
141141
self._fit(X)

sklearn/cluster/_birch.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
ClusterMixin,
1717
BaseEstimator,
1818
ClassNamePrefixFeaturesOutMixin,
19+
_fit_context,
1920
)
2021
from ..utils.extmath import row_norms
2122
from ..utils._param_validation import Interval
@@ -501,6 +502,7 @@ def __init__(
501502
self.compute_labels = compute_labels
502503
self.copy = copy
503504

505+
@_fit_context(prefer_skip_nested_validation=True)
504506
def fit(self, X, y=None):
505507
"""
506508
Build a CF Tree for the input data.
@@ -518,9 +520,6 @@ def fit(self, X, y=None):
518520
self
519521
Fitted estimator.
520522
"""
521-
522-
self._validate_params()
523-
524523
return self._fit(X, partial=False)
525524

526525
def _fit(self, X, partial):
@@ -610,6 +609,7 @@ def _get_leaves(self):
610609
leaf_ptr = leaf_ptr.next_leaf_
611610
return leaves
612611

612+
@_fit_context(prefer_skip_nested_validation=True)
613613
def partial_fit(self, X=None, y=None):
614614
"""
615615
Online learning. Prevents rebuilding of CFTree from scratch.
@@ -629,8 +629,6 @@ def partial_fit(self, X=None, y=None):
629629
self
630630
Fitted estimator.
631631
"""
632-
self._validate_params()
633-
634632
if X is None:
635633
# Perform just the final global clustering step.
636634
self._global_clustering()

sklearn/cluster/_bisect_k_means.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import numpy as np
77
import scipy.sparse as sp
88

9+
from ..base import _fit_context
910
from ._kmeans import _BaseKMeans
1011
from ._kmeans import _kmeans_single_elkan
1112
from ._kmeans import _kmeans_single_lloyd
@@ -347,6 +348,7 @@ def _bisect(self, X, x_squared_norms, sample_weight, cluster_to_bisect):
347348

348349
cluster_to_bisect.split(best_labels, best_centers, scores)
349350

351+
@_fit_context(prefer_skip_nested_validation=True)
350352
def fit(self, X, y=None, sample_weight=None):
351353
"""Compute bisecting k-means clustering.
352354
@@ -373,8 +375,6 @@ def fit(self, X, y=None, sample_weight=None):
373375
self
374376
Fitted estimator.
375377
"""
376-
self._validate_params()
377-
378378
X = self._validate_data(
379379
X,
380380
accept_sparse="csr",

sklearn/cluster/_dbscan.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from ..metrics.pairwise import _VALID_METRICS
1818
from ..base import BaseEstimator, ClusterMixin
19+
from ..base import _fit_context
1920
from ..utils.validation import _check_sample_weight
2021
from ..utils._param_validation import Interval, StrOptions
2122
from ..neighbors import NearestNeighbors
@@ -338,6 +339,10 @@ def __init__(
338339
self.p = p
339340
self.n_jobs = n_jobs
340341

342+
@_fit_context(
343+
# DBSCAN.metric is not validated yet
344+
prefer_skip_nested_validation=False
345+
)
341346
def fit(self, X, y=None, sample_weight=None):
342347
"""Perform DBSCAN clustering from features, or distance matrix.
343348
@@ -363,8 +368,6 @@ def fit(self, X, y=None, sample_weight=None):
363368
self : object
364369
Returns a fitted instance of self.
365370
"""
366-
self._validate_params()
367-
368371
X = self._validate_data(X, accept_sparse="csr")
369372

370373
if sample_weight is not None:

sklearn/cluster/_kmeans.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
ClusterMixin,
2424
TransformerMixin,
2525
ClassNamePrefixFeaturesOutMixin,
26+
_fit_context,
2627
)
2728
from ..metrics.pairwise import euclidean_distances
2829
from ..metrics.pairwise import _euclidean_distances
@@ -1448,6 +1449,7 @@ def _warn_mkl_vcomp(self, n_active_threads):
14481449
f" variable OMP_NUM_THREADS={n_active_threads}."
14491450
)
14501451

1452+
@_fit_context(prefer_skip_nested_validation=True)
14511453
def fit(self, X, y=None, sample_weight=None):
14521454
"""Compute k-means clustering.
14531455
@@ -1475,8 +1477,6 @@ def fit(self, X, y=None, sample_weight=None):
14751477
self : object
14761478
Fitted estimator.
14771479
"""
1478-
self._validate_params()
1479-
14801480
X = self._validate_data(
14811481
X,
14821482
accept_sparse="csr",
@@ -2057,6 +2057,7 @@ def _random_reassign(self):
20572057
return True
20582058
return False
20592059

2060+
@_fit_context(prefer_skip_nested_validation=True)
20602061
def fit(self, X, y=None, sample_weight=None):
20612062
"""Compute the centroids on X by chunking it into mini-batches.
20622063
@@ -2084,8 +2085,6 @@ def fit(self, X, y=None, sample_weight=None):
20842085
self : object
20852086
Fitted estimator.
20862087
"""
2087-
self._validate_params()
2088-
20892088
X = self._validate_data(
20902089
X,
20912090
accept_sparse="csr",
@@ -2214,6 +2213,7 @@ def fit(self, X, y=None, sample_weight=None):
22142213

22152214
return self
22162215

2216+
@_fit_context(prefer_skip_nested_validation=True)
22172217
def partial_fit(self, X, y=None, sample_weight=None):
22182218
"""Update k means estimate on a single mini-batch X.
22192219
@@ -2241,9 +2241,6 @@ def partial_fit(self, X, y=None, sample_weight=None):
22412241
"""
22422242
has_centers = hasattr(self, "cluster_centers_")
22432243

2244-
if not has_centers:
2245-
self._validate_params()
2246-
22472244
X = self._validate_data(
22482245
X,
22492246
accept_sparse="csr",

sklearn/cluster/_mean_shift.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from ..utils.parallel import delayed, Parallel
2525
from ..utils import check_random_state, gen_batches, check_array
2626
from ..base import BaseEstimator, ClusterMixin
27+
from ..base import _fit_context
2728
from ..neighbors import NearestNeighbors
2829
from ..metrics.pairwise import pairwise_distances_argmin
2930
from .._config import config_context
@@ -435,6 +436,7 @@ def __init__(
435436
self.n_jobs = n_jobs
436437
self.max_iter = max_iter
437438

439+
@_fit_context(prefer_skip_nested_validation=True)
438440
def fit(self, X, y=None):
439441
"""Perform clustering.
440442
@@ -451,7 +453,6 @@ def fit(self, X, y=None):
451453
self : object
452454
Fitted instance.
453455
"""
454-
self._validate_params()
455456
X = self._validate_data(X)
456457
bandwidth = self.bandwidth
457458
if bandwidth is None:

0 commit comments

Comments
 (0)