Skip to content

Commit 7d0bec5

Browse files
adrinjalali, glemaitre, and Charlie-XIAO
authored
API move BaseEstimator._validate_data to utils.validation.validate_data (scikit-learn#29696)
Co-authored-by: Guillaume Lemaitre <guillaume@probabl.ai> Co-authored-by: Yao Xiao <108576690+Charlie-XIAO@users.noreply.github.com>
1 parent 5b6622b commit 7d0bec5

File tree

103 files changed

+956
-674
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below for content that may be hidden.

103 files changed

+956
-674
lines changed

doc/api_reference.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1183,6 +1183,7 @@ def _get_submodule(module_name, submodule_name):
11831183
"validation.check_symmetric",
11841184
"validation.column_or_1d",
11851185
"validation.has_fit_parameter",
1186+
"validation.validate_data",
11861187
],
11871188
},
11881189
{

doc/whats_new/v1.6.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,12 @@ Version 1.6.0
2525
Changes impacting many modules
2626
------------------------------
2727

28+
- |API| :func:`utils.validation.validate_data` is introduced and replaces previously
29+
private `base.BaseEstimator._validate_data` method. This is intended for third party
30+
estimator developers, who should use this function in most cases instead of
31+
:func:`utils.validation.check_array` and :func:`utils.validation.check_X_y`.
32+
:pr:`29696` by `Adrin Jalali`_.
33+
2834
- |Enhancement| `__sklearn_tags__` was introduced for setting tags in estimators.
2935
More details in :ref:`estimator_tags`.
3036
:pr:`22606` by `Thomas Fan`_ and :pr:`29677` by `Adrin Jalali`_.

sklearn/base.py

Lines changed: 0 additions & 261 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,10 @@
2424
from .utils.fixes import _IS_32BIT
2525
from .utils.validation import (
2626
_check_feature_names_in,
27-
_check_y,
2827
_generate_get_feature_names_out,
29-
_get_feature_names,
3028
_is_fitted,
31-
_num_features,
3229
check_array,
3330
check_is_fitted,
34-
check_X_y,
3531
)
3632

3733

@@ -386,262 +382,6 @@ def __setstate__(self, state):
386382
def __sklearn_tags__(self):
387383
return default_tags(self)
388384

389-
def _check_n_features(self, X, reset):
390-
"""Set the `n_features_in_` attribute, or check against it.
391-
392-
Parameters
393-
----------
394-
X : {ndarray, sparse matrix} of shape (n_samples, n_features)
395-
The input samples.
396-
reset : bool
397-
If True, the `n_features_in_` attribute is set to `X.shape[1]`.
398-
If False and the attribute exists, then check that it is equal to
399-
`X.shape[1]`. If False and the attribute does *not* exist, then
400-
the check is skipped.
401-
.. note::
402-
It is recommended to call reset=True in `fit` and in the first
403-
call to `partial_fit`. All other methods that validate `X`
404-
should set `reset=False`.
405-
"""
406-
try:
407-
n_features = _num_features(X)
408-
except TypeError as e:
409-
if not reset and hasattr(self, "n_features_in_"):
410-
raise ValueError(
411-
"X does not contain any features, but "
412-
f"{self.__class__.__name__} is expecting "
413-
f"{self.n_features_in_} features"
414-
) from e
415-
# If the number of features is not defined and reset=True,
416-
# then we skip this check
417-
return
418-
419-
if reset:
420-
self.n_features_in_ = n_features
421-
return
422-
423-
if not hasattr(self, "n_features_in_"):
424-
# Skip this check if the expected number of expected input features
425-
# was not recorded by calling fit first. This is typically the case
426-
# for stateless transformers.
427-
return
428-
429-
if n_features != self.n_features_in_:
430-
raise ValueError(
431-
f"X has {n_features} features, but {self.__class__.__name__} "
432-
f"is expecting {self.n_features_in_} features as input."
433-
)
434-
435-
def _check_feature_names(self, X, *, reset):
436-
"""Set or check the `feature_names_in_` attribute.
437-
438-
.. versionadded:: 1.0
439-
440-
Parameters
441-
----------
442-
X : {ndarray, dataframe} of shape (n_samples, n_features)
443-
The input samples.
444-
445-
reset : bool
446-
Whether to reset the `feature_names_in_` attribute.
447-
If False, the input will be checked for consistency with
448-
feature names of data provided when reset was last True.
449-
.. note::
450-
It is recommended to call `reset=True` in `fit` and in the first
451-
call to `partial_fit`. All other methods that validate `X`
452-
should set `reset=False`.
453-
"""
454-
455-
if reset:
456-
feature_names_in = _get_feature_names(X)
457-
if feature_names_in is not None:
458-
self.feature_names_in_ = feature_names_in
459-
elif hasattr(self, "feature_names_in_"):
460-
# Delete the attribute when the estimator is fitted on a new dataset
461-
# that has no feature names.
462-
delattr(self, "feature_names_in_")
463-
return
464-
465-
fitted_feature_names = getattr(self, "feature_names_in_", None)
466-
X_feature_names = _get_feature_names(X)
467-
468-
if fitted_feature_names is None and X_feature_names is None:
469-
# no feature names seen in fit and in X
470-
return
471-
472-
if X_feature_names is not None and fitted_feature_names is None:
473-
warnings.warn(
474-
f"X has feature names, but {self.__class__.__name__} was fitted without"
475-
" feature names"
476-
)
477-
return
478-
479-
if X_feature_names is None and fitted_feature_names is not None:
480-
warnings.warn(
481-
"X does not have valid feature names, but"
482-
f" {self.__class__.__name__} was fitted with feature names"
483-
)
484-
return
485-
486-
# validate the feature names against the `feature_names_in_` attribute
487-
if len(fitted_feature_names) != len(X_feature_names) or np.any(
488-
fitted_feature_names != X_feature_names
489-
):
490-
message = (
491-
"The feature names should match those that were passed during fit.\n"
492-
)
493-
fitted_feature_names_set = set(fitted_feature_names)
494-
X_feature_names_set = set(X_feature_names)
495-
496-
unexpected_names = sorted(X_feature_names_set - fitted_feature_names_set)
497-
missing_names = sorted(fitted_feature_names_set - X_feature_names_set)
498-
499-
def add_names(names):
500-
output = ""
501-
max_n_names = 5
502-
for i, name in enumerate(names):
503-
if i >= max_n_names:
504-
output += "- ...\n"
505-
break
506-
output += f"- {name}\n"
507-
return output
508-
509-
if unexpected_names:
510-
message += "Feature names unseen at fit time:\n"
511-
message += add_names(unexpected_names)
512-
513-
if missing_names:
514-
message += "Feature names seen at fit time, yet now missing:\n"
515-
message += add_names(missing_names)
516-
517-
if not missing_names and not unexpected_names:
518-
message += (
519-
"Feature names must be in the same order as they were in fit.\n"
520-
)
521-
522-
raise ValueError(message)
523-
524-
def _validate_data(
525-
self,
526-
X="no_validation",
527-
y="no_validation",
528-
reset=True,
529-
validate_separately=False,
530-
cast_to_ndarray=True,
531-
**check_params,
532-
):
533-
"""Validate input data and set or check the `n_features_in_` attribute.
534-
535-
Parameters
536-
----------
537-
X : {array-like, sparse matrix, dataframe} of shape \
538-
(n_samples, n_features), default='no validation'
539-
The input samples.
540-
If `'no_validation'`, no validation is performed on `X`. This is
541-
useful for meta-estimator which can delegate input validation to
542-
their underlying estimator(s). In that case `y` must be passed and
543-
the only accepted `check_params` are `multi_output` and
544-
`y_numeric`.
545-
546-
y : array-like of shape (n_samples,), default='no_validation'
547-
The targets.
548-
549-
- If `None`, `check_array` is called on `X`. If the estimator's
550-
requires_y tag is True, then an error will be raised.
551-
- If `'no_validation'`, `check_array` is called on `X` and the
552-
estimator's requires_y tag is ignored. This is a default
553-
placeholder and is never meant to be explicitly set. In that case
554-
`X` must be passed.
555-
- Otherwise, only `y` with `_check_y` or both `X` and `y` are
556-
checked with either `check_array` or `check_X_y` depending on
557-
`validate_separately`.
558-
559-
reset : bool, default=True
560-
Whether to reset the `n_features_in_` attribute.
561-
If False, the input will be checked for consistency with data
562-
provided when reset was last True.
563-
.. note::
564-
It is recommended to call reset=True in `fit` and in the first
565-
call to `partial_fit`. All other methods that validate `X`
566-
should set `reset=False`.
567-
568-
validate_separately : False or tuple of dicts, default=False
569-
Only used if y is not None.
570-
If False, call validate_X_y(). Else, it must be a tuple of kwargs
571-
to be used for calling check_array() on X and y respectively.
572-
573-
`estimator=self` is automatically added to these dicts to generate
574-
more informative error message in case of invalid input data.
575-
576-
cast_to_ndarray : bool, default=True
577-
Cast `X` and `y` to ndarray with checks in `check_params`. If
578-
`False`, `X` and `y` are unchanged and only `feature_names_in_` and
579-
`n_features_in_` are checked.
580-
581-
**check_params : kwargs
582-
Parameters passed to :func:`sklearn.utils.check_array` or
583-
:func:`sklearn.utils.check_X_y`. Ignored if validate_separately
584-
is not False.
585-
586-
`estimator=self` is automatically added to these params to generate
587-
more informative error message in case of invalid input data.
588-
589-
Returns
590-
-------
591-
out : {ndarray, sparse matrix} or tuple of these
592-
The validated input. A tuple is returned if both `X` and `y` are
593-
validated.
594-
"""
595-
self._check_feature_names(X, reset=reset)
596-
597-
if y is None and self.__sklearn_tags__().target_tags.required:
598-
raise ValueError(
599-
f"This {self.__class__.__name__} estimator "
600-
"requires y to be passed, but the target y is None."
601-
)
602-
603-
no_val_X = isinstance(X, str) and X == "no_validation"
604-
no_val_y = y is None or isinstance(y, str) and y == "no_validation"
605-
606-
if no_val_X and no_val_y:
607-
raise ValueError("Validation should be done on X, y or both.")
608-
609-
default_check_params = {"estimator": self}
610-
check_params = {**default_check_params, **check_params}
611-
612-
if not cast_to_ndarray:
613-
if not no_val_X and no_val_y:
614-
out = X
615-
elif no_val_X and not no_val_y:
616-
out = y
617-
else:
618-
out = X, y
619-
elif not no_val_X and no_val_y:
620-
out = check_array(X, input_name="X", **check_params)
621-
elif no_val_X and not no_val_y:
622-
out = _check_y(y, **check_params)
623-
else:
624-
if validate_separately:
625-
# We need this because some estimators validate X and y
626-
# separately, and in general, separately calling check_array()
627-
# on X and y isn't equivalent to just calling check_X_y()
628-
# :(
629-
check_X_params, check_y_params = validate_separately
630-
if "estimator" not in check_X_params:
631-
check_X_params = {**default_check_params, **check_X_params}
632-
X = check_array(X, input_name="X", **check_X_params)
633-
if "estimator" not in check_y_params:
634-
check_y_params = {**default_check_params, **check_y_params}
635-
y = check_array(y, input_name="y", **check_y_params)
636-
else:
637-
X, y = check_X_y(X, y, **check_params)
638-
out = X, y
639-
640-
if not no_val_X and check_params.get("ensure_2d", True):
641-
self._check_n_features(X, reset=reset)
642-
643-
return out
644-
645385
def _validate_params(self):
646386
"""Validate types and values of constructor parameters
647387
@@ -984,7 +724,6 @@ def get_submatrix(self, i, data):
984724
Works with sparse matrices. Only works if ``rows_`` and
985725
``columns_`` attributes exist.
986726
"""
987-
from .utils.validation import check_array
988727

989728
data = check_array(data, accept_sparse="csr")
990729
row_ind, col_ind = self.get_indices(i)

sklearn/cluster/_affinity_propagation.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from ..metrics import euclidean_distances, pairwise_distances_argmin
1515
from ..utils import check_random_state
1616
from ..utils._param_validation import Interval, StrOptions, validate_params
17-
from ..utils.validation import check_is_fitted
17+
from ..utils.validation import check_is_fitted, validate_data
1818

1919

2020
def _equal_similarities_and_preferences(S, preference):
@@ -504,10 +504,10 @@ def fit(self, X, y=None):
504504
Returns the instance itself.
505505
"""
506506
if self.affinity == "precomputed":
507-
X = self._validate_data(X, copy=self.copy, force_writeable=True)
507+
X = validate_data(self, X, copy=self.copy, force_writeable=True)
508508
self.affinity_matrix_ = X
509509
else: # self.affinity == "euclidean"
510-
X = self._validate_data(X, accept_sparse="csr")
510+
X = validate_data(self, X, accept_sparse="csr")
511511
self.affinity_matrix_ = -euclidean_distances(X, squared=True)
512512

513513
if self.affinity_matrix_.shape[0] != self.affinity_matrix_.shape[1]:
@@ -559,7 +559,7 @@ def predict(self, X):
559559
Cluster labels.
560560
"""
561561
check_is_fitted(self)
562-
X = self._validate_data(X, reset=False, accept_sparse="csr")
562+
X = validate_data(self, X, reset=False, accept_sparse="csr")
563563
if not hasattr(self, "cluster_centers_"):
564564
raise ValueError(
565565
"Predict method is not supported when affinity='precomputed'."

sklearn/cluster/_agglomerative.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
validate_params,
3939
)
4040
from ..utils.graph import _fix_connected_components
41-
from ..utils.validation import check_memory
41+
from ..utils.validation import check_memory, validate_data
4242

4343
# mypy error: Module 'sklearn.cluster' has no attribute '_hierarchical_fast'
4444
from . import _hierarchical_fast as _hierarchical # type: ignore
@@ -989,7 +989,7 @@ def fit(self, X, y=None):
989989
self : object
990990
Returns the fitted instance.
991991
"""
992-
X = self._validate_data(X, ensure_min_samples=2)
992+
X = validate_data(self, X, ensure_min_samples=2)
993993
return self._fit(X)
994994

995995
def _fit(self, X):
@@ -1338,7 +1338,7 @@ def fit(self, X, y=None):
13381338
self : object
13391339
Returns the transformer.
13401340
"""
1341-
X = self._validate_data(X, ensure_min_features=2)
1341+
X = validate_data(self, X, ensure_min_features=2)
13421342
super()._fit(X.T)
13431343
self._n_features_out = self.n_clusters_
13441344
return self

sklearn/cluster/_bicluster.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from ..utils import check_random_state, check_scalar
1616
from ..utils._param_validation import Interval, StrOptions
1717
from ..utils.extmath import make_nonnegative, randomized_svd, safe_sparse_dot
18-
from ..utils.validation import assert_all_finite
18+
from ..utils.validation import assert_all_finite, validate_data
1919
from ._kmeans import KMeans, MiniBatchKMeans
2020

2121
__all__ = ["SpectralCoclustering", "SpectralBiclustering"]
@@ -131,7 +131,7 @@ def fit(self, X, y=None):
131131
self : object
132132
SpectralBiclustering instance.
133133
"""
134-
X = self._validate_data(X, accept_sparse="csr", dtype=np.float64)
134+
X = validate_data(self, X, accept_sparse="csr", dtype=np.float64)
135135
self._check_parameters(X.shape[0])
136136
self._fit(X)
137137
return self

0 commit comments

Comments
 (0)