Skip to content

Commit 8f620fd

Browse files
nithish08StefanieSengeradrinjalali
authored
MAINT move _estimator_has function to utils (scikit-learn#29319)
Co-authored-by: Stefanie Senger <91849487+StefanieSenger@users.noreply.github.com> Co-authored-by: adrinjalali <adrin.jalali@gmail.com>
1 parent 102663d commit 8f620fd

File tree

9 files changed

+168
-125
lines changed

9 files changed

+168
-125
lines changed

sklearn/ensemble/_bagging.py

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
_check_method_params,
4242
_check_sample_weight,
4343
_deprecate_positional_args,
44+
_estimator_has,
4445
check_is_fitted,
4546
has_fit_parameter,
4647
validate_data,
@@ -269,22 +270,6 @@ def _parallel_predict_regression(estimators, estimators_features, X):
269270
)
270271

271272

272-
def _estimator_has(attr):
273-
"""Check if we can delegate a method to the underlying estimator.
274-
275-
First, we check the first fitted estimator if available, otherwise we
276-
check the estimator attribute.
277-
"""
278-
279-
def check(self):
280-
if hasattr(self, "estimators_"):
281-
return hasattr(self.estimators_[0], attr)
282-
else: # self.estimator is not None
283-
return hasattr(self.estimator, attr)
284-
285-
return check
286-
287-
288273
class BaseBagging(BaseEnsemble, metaclass=ABCMeta):
289274
"""Base class for Bagging meta-estimator.
290275
@@ -1033,7 +1018,9 @@ def predict_log_proba(self, X):
10331018

10341019
return log_proba
10351020

1036-
@available_if(_estimator_has("decision_function"))
1021+
@available_if(
1022+
_estimator_has("decision_function", delegates=("estimators_", "estimator"))
1023+
)
10371024
def decision_function(self, X):
10381025
"""Average of the decision functions of the base classifiers.
10391026

sklearn/ensemble/_stacking.py

Lines changed: 20 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -40,31 +40,13 @@
4040
_check_feature_names_in,
4141
_check_response_method,
4242
_deprecate_positional_args,
43+
_estimator_has,
4344
check_is_fitted,
4445
column_or_1d,
4546
)
4647
from ._base import _BaseHeterogeneousEnsemble, _fit_single_estimator
4748

4849

49-
def _estimator_has(attr):
50-
"""Check if we can delegate a method to the underlying estimator.
51-
52-
First, we check the fitted `final_estimator_` if available, otherwise we check the
53-
unfitted `final_estimator`. We raise the original `AttributeError` if `attr` does
54-
not exist. This function is used together with `available_if`.
55-
"""
56-
57-
def check(self):
58-
if hasattr(self, "final_estimator_"):
59-
getattr(self.final_estimator_, attr)
60-
else:
61-
getattr(self.final_estimator, attr)
62-
63-
return True
64-
65-
return check
66-
67-
6850
class _BaseStacking(TransformerMixin, _BaseHeterogeneousEnsemble, metaclass=ABCMeta):
6951
"""Base class for stacking method."""
7052

@@ -364,7 +346,9 @@ def get_feature_names_out(self, input_features=None):
364346

365347
return np.asarray(meta_names, dtype=object)
366348

367-
@available_if(_estimator_has("predict"))
349+
@available_if(
350+
_estimator_has("predict", delegates=("final_estimator_", "final_estimator"))
351+
)
368352
def predict(self, X, **predict_params):
369353
"""Predict target for X.
370354
@@ -732,7 +716,9 @@ def fit(self, X, y, *, sample_weight=None, **fit_params):
732716
fit_params["sample_weight"] = sample_weight
733717
return super().fit(X, y_encoded, **fit_params)
734718

735-
@available_if(_estimator_has("predict"))
719+
@available_if(
720+
_estimator_has("predict", delegates=("final_estimator_", "final_estimator"))
721+
)
736722
def predict(self, X, **predict_params):
737723
"""Predict target for X.
738724
@@ -785,7 +771,11 @@ def predict(self, X, **predict_params):
785771
y_pred = self._label_encoder.inverse_transform(y_pred)
786772
return y_pred
787773

788-
@available_if(_estimator_has("predict_proba"))
774+
@available_if(
775+
_estimator_has(
776+
"predict_proba", delegates=("final_estimator_", "final_estimator")
777+
)
778+
)
789779
def predict_proba(self, X):
790780
"""Predict class probabilities for `X` using the final estimator.
791781
@@ -809,7 +799,11 @@ def predict_proba(self, X):
809799
y_pred = np.array([preds[:, 0] for preds in y_pred]).T
810800
return y_pred
811801

812-
@available_if(_estimator_has("decision_function"))
802+
@available_if(
803+
_estimator_has(
804+
"decision_function", delegates=("final_estimator_", "final_estimator")
805+
)
806+
)
813807
def decision_function(self, X):
814808
"""Decision function for samples in `X` using the final estimator.
815809
@@ -1125,7 +1119,9 @@ def fit_transform(self, X, y, *, sample_weight=None, **fit_params):
11251119
fit_params["sample_weight"] = sample_weight
11261120
return super().fit_transform(X, y, **fit_params)
11271121

1128-
@available_if(_estimator_has("predict"))
1122+
@available_if(
1123+
_estimator_has("predict", delegates=("final_estimator_", "final_estimator"))
1124+
)
11291125
def predict(self, X, **predict_params):
11301126
"""Predict target for X.
11311127

sklearn/feature_selection/_from_model.py

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from ..utils.metaestimators import available_if
2020
from ..utils.validation import (
2121
_check_feature_names,
22+
_estimator_has,
2223
_num_features,
2324
check_is_fitted,
2425
check_scalar,
@@ -76,25 +77,6 @@ def _calculate_threshold(estimator, importances, threshold):
7677
return threshold
7778

7879

79-
def _estimator_has(attr):
80-
"""Check if we can delegate a method to the underlying estimator.
81-
82-
First, we check the fitted `estimator_` if available, otherwise we check the
83-
unfitted `estimator`. We raise the original `AttributeError` if `attr` does
84-
not exist. This function is used together with `available_if`.
85-
"""
86-
87-
def check(self):
88-
if hasattr(self, "estimator_"):
89-
getattr(self.estimator_, attr)
90-
else:
91-
getattr(self.estimator, attr)
92-
93-
return True
94-
95-
return check
96-
97-
9880
class SelectFromModel(MetaEstimatorMixin, SelectorMixin, BaseEstimator):
9981
"""Meta-transformer for selecting features based on importance weights.
10082

sklearn/feature_selection/_rfe.py

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from ..utils.validation import (
3030
_check_method_params,
3131
_deprecate_positional_args,
32+
_estimator_has,
3233
check_is_fitted,
3334
validate_data,
3435
)
@@ -64,25 +65,6 @@ def _rfe_single_fit(rfe, estimator, X, y, train, test, scorer, routed_params):
6465
return rfe.step_scores_, rfe.step_n_features_
6566

6667

67-
def _estimator_has(attr):
68-
"""Check if we can delegate a method to the underlying estimator.
69-
70-
First, we check the fitted `estimator_` if available, otherwise we check the
71-
unfitted `estimator`. We raise the original `AttributeError` if `attr` does
72-
not exist. This function is used together with `available_if`.
73-
"""
74-
75-
def check(self):
76-
if hasattr(self, "estimator_"):
77-
getattr(self.estimator_, attr)
78-
else:
79-
getattr(self.estimator, attr)
80-
81-
return True
82-
83-
return check
84-
85-
8668
class RFE(SelectorMixin, MetaEstimatorMixin, BaseEstimator):
8769
"""Feature ranking with recursive feature elimination.
8870

sklearn/model_selection/_classification_threshold.py

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from ..utils.parallel import Parallel, delayed
3737
from ..utils.validation import (
3838
_check_method_params,
39+
_estimator_has,
3940
_num_samples,
4041
check_is_fitted,
4142
indexable,
@@ -50,23 +51,6 @@ def _check_is_fitted(estimator):
5051
check_is_fitted(estimator, "estimator_")
5152

5253

53-
def _estimator_has(attr):
54-
"""Check if we can delegate a method to the underlying estimator.
55-
56-
First, we check the fitted estimator if available, otherwise we
57-
check the unfitted estimator.
58-
"""
59-
60-
def check(self):
61-
if hasattr(self, "estimator_"):
62-
getattr(self.estimator_, attr)
63-
else:
64-
getattr(self.estimator, attr)
65-
return True
66-
67-
return check
68-
69-
7054
class BaseThresholdClassifier(ClassifierMixin, MetaEstimatorMixin, BaseEstimator):
7155
"""Base class for binary classifiers that set a non-default decision threshold.
7256

sklearn/model_selection/_search.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,7 @@ def _check_refit(search_cv, attr):
356356
)
357357

358358

359-
def _estimator_has(attr):
359+
def _search_estimator_has(attr):
360360
"""Check if we can delegate a method to the underlying estimator.
361361
362362
Calling a prediction method will only be available if `refit=True`. In
@@ -555,7 +555,7 @@ def score(self, X, y=None, **params):
555555
score = score[self.refit]
556556
return score
557557

558-
@available_if(_estimator_has("score_samples"))
558+
@available_if(_search_estimator_has("score_samples"))
559559
def score_samples(self, X):
560560
"""Call score_samples on the estimator with the best found parameters.
561561
@@ -578,7 +578,7 @@ def score_samples(self, X):
578578
check_is_fitted(self)
579579
return self.best_estimator_.score_samples(X)
580580

581-
@available_if(_estimator_has("predict"))
581+
@available_if(_search_estimator_has("predict"))
582582
def predict(self, X):
583583
"""Call predict on the estimator with the best found parameters.
584584
@@ -600,7 +600,7 @@ def predict(self, X):
600600
check_is_fitted(self)
601601
return self.best_estimator_.predict(X)
602602

603-
@available_if(_estimator_has("predict_proba"))
603+
@available_if(_search_estimator_has("predict_proba"))
604604
def predict_proba(self, X):
605605
"""Call predict_proba on the estimator with the best found parameters.
606606
@@ -623,7 +623,7 @@ def predict_proba(self, X):
623623
check_is_fitted(self)
624624
return self.best_estimator_.predict_proba(X)
625625

626-
@available_if(_estimator_has("predict_log_proba"))
626+
@available_if(_search_estimator_has("predict_log_proba"))
627627
def predict_log_proba(self, X):
628628
"""Call predict_log_proba on the estimator with the best found parameters.
629629
@@ -646,7 +646,7 @@ def predict_log_proba(self, X):
646646
check_is_fitted(self)
647647
return self.best_estimator_.predict_log_proba(X)
648648

649-
@available_if(_estimator_has("decision_function"))
649+
@available_if(_search_estimator_has("decision_function"))
650650
def decision_function(self, X):
651651
"""Call decision_function on the estimator with the best found parameters.
652652
@@ -669,7 +669,7 @@ def decision_function(self, X):
669669
check_is_fitted(self)
670670
return self.best_estimator_.decision_function(X)
671671

672-
@available_if(_estimator_has("transform"))
672+
@available_if(_search_estimator_has("transform"))
673673
def transform(self, X):
674674
"""Call transform on the estimator with the best found parameters.
675675
@@ -691,7 +691,7 @@ def transform(self, X):
691691
check_is_fitted(self)
692692
return self.best_estimator_.transform(X)
693693

694-
@available_if(_estimator_has("inverse_transform"))
694+
@available_if(_search_estimator_has("inverse_transform"))
695695
def inverse_transform(self, X=None, Xt=None):
696696
"""Call inverse_transform on the estimator with the best found params.
697697
@@ -746,7 +746,7 @@ def classes_(self):
746746
747747
Only available when `refit=True` and the estimator is a classifier.
748748
"""
749-
_estimator_has("classes_")(self)
749+
_search_estimator_has("classes_")(self)
750750
return self.best_estimator_.classes_
751751

752752
def _run_search(self, evaluate_candidates):

sklearn/semi_supervised/_self_training.py

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,33 +17,14 @@
1717
process_routing,
1818
)
1919
from ..utils.metaestimators import available_if
20-
from ..utils.validation import check_is_fitted, validate_data
20+
from ..utils.validation import _estimator_has, check_is_fitted, validate_data
2121

2222
__all__ = ["SelfTrainingClassifier"]
2323

2424
# Authors: The scikit-learn developers
2525
# SPDX-License-Identifier: BSD-3-Clause
2626

2727

28-
def _estimator_has(attr):
29-
"""Check if we can delegate a method to the underlying estimator.
30-
31-
First, we check the fitted `estimator_` if available, otherwise we check
32-
the unfitted `estimator`. We raise the original `AttributeError` if
33-
`attr` does not exist. This function is used together with `available_if`.
34-
"""
35-
36-
def check(self):
37-
if hasattr(self, "estimator_"):
38-
getattr(self.estimator_, attr)
39-
else:
40-
getattr(self.estimator, attr)
41-
42-
return True
43-
44-
return check
45-
46-
4728
class SelfTrainingClassifier(ClassifierMixin, MetaEstimatorMixin, BaseEstimator):
4829
"""Self-training classifier.
4930

0 commit comments

Comments
 (0)