
Commit 31723a6

Merge branch 'scikit-learn:main' into fork
2 parents fb5b69a + a187758

File tree

7 files changed: +233 −36 lines changed

doc/modules/preprocessing.rst

Lines changed: 1 addition & 0 deletions
@@ -941,6 +941,7 @@ learned in :meth:`~TargetEncoder.fit_transform`.
 .. topic:: Examples:

   * :ref:`sphx_glr_auto_examples_preprocessing_plot_target_encoder.py`
+  * :ref:`sphx_glr_auto_examples_preprocessing_plot_target_encoder_cross_val.py`

 .. topic:: References

examples/preprocessing/plot_target_encoder_cross_val.py

Lines changed: 160 additions & 0 deletions
@@ -0,0 +1,160 @@
+"""
+==========================================
+Target Encoder's Internal Cross Validation
+==========================================
+
+.. currentmodule:: sklearn.preprocessing
+
+The :class:`TargetEncoder` replaces each category of a categorical feature with
+the mean of the target variable for that category. This method is useful
+in cases where there is a strong relationship between the categorical feature
+and the target. To prevent overfitting, :meth:`TargetEncoder.fit_transform` uses
+internal cross validation to encode the training data to be used by a downstream
+model. In this example, we demonstrate the importance of the cross validation
+procedure to prevent overfitting.
+"""
+
+# %%
+# Create Synthetic Dataset
+# ========================
+# For this example, we build a dataset with three categorical features: an informative
+# feature with medium cardinality, an uninformative feature with medium cardinality,
+# and an uninformative feature with high cardinality. First, we generate the
+# informative feature:
+from sklearn.preprocessing import KBinsDiscretizer
+import numpy as np
+
+n_samples = 50_000
+
+rng = np.random.RandomState(42)
+y = rng.randn(n_samples)
+noise = 0.5 * rng.randn(n_samples)
+n_categories = 100
+
+kbins = KBinsDiscretizer(
+    n_bins=n_categories, encode="ordinal", strategy="uniform", random_state=rng
+)
+X_informative = kbins.fit_transform((y + noise).reshape(-1, 1))
+
+# Remove the linear relationship between y and the bin index by permuting the values of
+# X_informative
+permuted_categories = rng.permutation(n_categories)
+X_informative = permuted_categories[X_informative.astype(np.int32)]
+
+# %%
+# The uninformative feature with medium cardinality is generated by permuting the
+# informative feature and removing the relationship with the target:
+X_shuffled = rng.permutation(X_informative)
+
+# %%
+# The uninformative feature with high cardinality is generated so that it is
+# independent of the target variable. We will show that target encoding without cross
+# validation will cause catastrophic overfitting for the downstream regressor. These
+# high cardinality features are basically unique identifiers for samples, which should
+# generally be removed from machine learning datasets. In this example, we generate
+# them to show how :class:`TargetEncoder`'s default cross validation behavior
+# mitigates the overfitting issue automatically.
+X_near_unique_categories = rng.choice(
+    int(0.9 * n_samples), size=n_samples, replace=True
+).reshape(-1, 1)
+
+# %%
+# Finally, we assemble the dataset and perform a train-test split:
+from sklearn.model_selection import train_test_split
+import pandas as pd
+
+X = pd.DataFrame(
+    np.concatenate(
+        [X_informative, X_shuffled, X_near_unique_categories],
+        axis=1,
+    ),
+    columns=["informative", "shuffled", "near_unique"],
+)
+X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
+
+# %%
+# Training a Ridge Regressor
+# ==========================
+# In this section, we train a ridge regressor on the dataset with and without
+# encoding and explore the influence of the target encoder with and without the
+# internal cross validation. First, we see that the Ridge model trained on the
+# raw features will have low performance, because the order of the informative
+# feature is not informative:
+from sklearn.linear_model import Ridge
+import sklearn
+
+# Configure transformers to always output DataFrames
+sklearn.set_config(transform_output="pandas")
+
+ridge = Ridge(alpha=1e-6, solver="lsqr", fit_intercept=False)
+
+raw_model = ridge.fit(X_train, y_train)
+print("Raw Model score on training set: ", raw_model.score(X_train, y_train))
+print("Raw Model score on test set: ", raw_model.score(X_test, y_test))
+
+# %%
+# Next, we create a pipeline with the target encoder and ridge model. The pipeline
+# uses :meth:`TargetEncoder.fit_transform`, which uses cross validation. We see that
+# the model fits the data well and generalizes to the test set:
+from sklearn.pipeline import make_pipeline
+from sklearn.preprocessing import TargetEncoder
+
+model_with_cv = make_pipeline(TargetEncoder(random_state=0), ridge)
+model_with_cv.fit(X_train, y_train)
+print("Model with CV on training set: ", model_with_cv.score(X_train, y_train))
+print("Model with CV on test set: ", model_with_cv.score(X_test, y_test))
+
+# %%
+# The coefficients of the linear model show that most of the weight is on the
+# feature at column index 0, which is the informative feature:
+import pandas as pd
+import matplotlib.pyplot as plt
+
+plt.rcParams["figure.constrained_layout.use"] = True
+
+coefs_cv = pd.Series(
+    model_with_cv[-1].coef_, index=model_with_cv[-1].feature_names_in_
+).sort_values()
+_ = coefs_cv.plot(kind="barh")
+
+# %%
+# While :meth:`TargetEncoder.fit_transform` uses internal cross validation,
+# :meth:`TargetEncoder.transform` itself does not perform any cross validation.
+# It uses the aggregation of the complete training set to transform the categorical
+# features. Thus, we can use :meth:`TargetEncoder.fit` followed by
+# :meth:`TargetEncoder.transform` to disable the cross validation. This encoding
+# is then passed to the ridge model.
+target_encoder = TargetEncoder(random_state=0)
+target_encoder.fit(X_train, y_train)
+X_train_no_cv_encoding = target_encoder.transform(X_train)
+X_test_no_cv_encoding = target_encoder.transform(X_test)
+
+model_no_cv = ridge.fit(X_train_no_cv_encoding, y_train)
+
+# %%
+# We evaluate the model on the non-cross validated encoding and see that it overfits:
+print(
+    "Model without CV on training set: ",
+    model_no_cv.score(X_train_no_cv_encoding, y_train),
+)
+print(
+    "Model without CV on test set: ", model_no_cv.score(X_test_no_cv_encoding, y_test)
+)
+
+# %%
+# The ridge model overfits because it assigns more weight to the extremely high
+# cardinality feature relative to the informative feature.
+coefs_no_cv = pd.Series(
+    model_no_cv.coef_, index=model_no_cv.feature_names_in_
+).sort_values()
+_ = coefs_no_cv.plot(kind="barh")
+
+# %%
+# Conclusion
+# ==========
+# This example demonstrates the importance of :class:`TargetEncoder`'s internal cross
+# validation. It is important to use :meth:`TargetEncoder.fit_transform` to encode
+# training data before passing it to a machine learning model. When a
+# :class:`TargetEncoder` is a part of a :class:`~sklearn.pipeline.Pipeline` and the
+# pipeline is fitted, the pipeline will correctly call
+# :meth:`TargetEncoder.fit_transform` and pass the encoding along.

examples/text/plot_document_clustering.py

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@
 documents by topics using a `Bag of Words approach
 <https://en.wikipedia.org/wiki/Bag-of-words_model>`_.

-Two algorithms are demoed: :class:`~sklearn.cluster.KMeans` and its more
+Two algorithms are demonstrated, namely :class:`~sklearn.cluster.KMeans` and its more
 scalable variant, :class:`~sklearn.cluster.MiniBatchKMeans`. Additionally,
 latent semantic analysis is used to reduce dimensionality and discover latent
 patterns in the data.

sklearn/datasets/_openml.py

Lines changed: 15 additions & 13 deletions
@@ -519,19 +519,21 @@ def _open_url_and_load_gzip_file(url, data_home, n_retries, delay, arff_params):
             url, data_home, n_retries, delay, arff_params
         )
     except Exception as exc:
-        if parser == "pandas":
-            from pandas.errors import ParserError
-
-            if isinstance(exc, ParserError):
-                # A parsing error could come from providing the wrong quotechar
-                # to pandas. By default, we use a double quote. Thus, we retry
-                # with a single quote before to raise the error.
-                arff_params["read_csv_kwargs"] = {"quotechar": "'"}
-                X, y, frame, categories = _open_url_and_load_gzip_file(
-                    url, data_home, n_retries, delay, arff_params
-                )
-            else:
-                raise
+        if parser != "pandas":
+            raise
+
+        from pandas.errors import ParserError
+
+        if not isinstance(exc, ParserError):
+            raise
+
+        # A parsing error could come from providing the wrong quotechar
+        # to pandas. By default, we use a double quote. Thus, we retry
+        # with a single quote before raising the error.
+        arff_params["read_csv_kwargs"] = {"quotechar": "'"}
+        X, y, frame, categories = _open_url_and_load_gzip_file(
+            url, data_home, n_retries, delay, arff_params
+        )

     return X, y, frame, categories
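The rewrite above flattens a nested conditional into early-exit guard clauses, which keeps the recovery path at a single indentation level. A minimal standalone sketch of the same retry-on-ParserError pattern — the load_csv helper and its arguments are hypothetical stand-ins for _open_url_and_load_gzip_file:

import io

import pandas as pd
from pandas.errors import ParserError


def load_csv(text, quotechar='"'):
    # Hypothetical stand-in for _open_url_and_load_gzip_file: parse CSV text.
    return pd.read_csv(io.StringIO(text), quotechar=quotechar)


def load_with_quotechar_fallback(text):
    try:
        return load_csv(text)
    except Exception as exc:
        # Guard clause: anything that is not a pandas ParserError is
        # re-raised immediately instead of being handled in a nested else.
        if not isinstance(exc, ParserError):
            raise
        # Retry once with a single-quote quotechar, mirroring the
        # wrong-quotechar recovery in the scikit-learn code above.
        return load_csv(text, quotechar="'")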

sklearn/model_selection/_validation.py

Lines changed: 34 additions & 6 deletions
@@ -15,6 +15,7 @@
 import numbers
 import time
 from functools import partial
+from numbers import Real
 from traceback import format_exc
 from contextlib import suppress
 from collections import Counter
@@ -29,7 +30,14 @@
 from ..utils.validation import _num_samples
 from ..utils.parallel import delayed, Parallel
 from ..utils.metaestimators import _safe_split
+from ..utils._param_validation import (
+    HasMethods,
+    Integral,
+    StrOptions,
+    validate_params,
+)
 from ..metrics import check_scoring
+from ..metrics import get_scorer_names
 from ..metrics._scorer import _check_multimetric_scoring, _MultimetricScorer
 from ..exceptions import FitFailedWarning
 from ._split import check_cv
@@ -46,6 +54,31 @@
 ]


+@validate_params(
+    {
+        "estimator": [HasMethods("fit")],
+        "X": ["array-like", "sparse matrix"],
+        "y": ["array-like", None],
+        "groups": ["array-like", None],
+        "scoring": [
+            StrOptions(set(get_scorer_names())),
+            callable,
+            list,
+            tuple,
+            dict,
+            None,
+        ],
+        "cv": ["cv_object"],
+        "n_jobs": [Integral, None],
+        "verbose": ["verbose"],
+        "fit_params": [dict, None],
+        "pre_dispatch": [Integral, str],
+        "return_train_score": ["boolean"],
+        "return_estimator": ["boolean"],
+        "return_indices": ["boolean"],
+        "error_score": [StrOptions({"raise"}), Real],
+    }
+)
 def cross_validate(
     estimator,
     X,
@@ -72,7 +105,7 @@ def cross_validate(
     estimator : estimator object implementing 'fit'
         The object to use to fit the data.

-    X : array-like of shape (n_samples, n_features)
+    X : {array-like, sparse matrix} of shape (n_samples, n_features)
         The data to fit. Can be for example a list, or an array.

     y : array-like of shape (n_samples,) or (n_samples, n_outputs), default=None
@@ -141,11 +174,6 @@ def cross_validate(
         explosion of memory consumption when more jobs get dispatched
         than CPUs can process. This parameter can be:

-        - None, in which case all the jobs are immediately
-          created and spawned. Use this for lightweight and
-          fast-running jobs, to avoid delays due to on-demand
-          spawning of the jobs
-
         - An int, giving the exact number of total jobs that are
           spawned
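With this decorator applied, invalid arguments to cross_validate are rejected up front with a standardized message, which is what the test change below asserts. A minimal sketch of the new behavior, assuming a scikit-learn build that includes this commit:

from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_validate

X, y = make_classification(n_samples=100, random_state=0)

try:
    # error_score must be the string "raise" or a real number.
    cross_validate(LogisticRegression(), X, y, error_score="unvalid-string")
except ValueError as exc:  # InvalidParameterError subclasses ValueError
    print(exc)
    # "The 'error_score' parameter of cross_validate must be ...
    # Got 'unvalid-string' instead."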

sklearn/model_selection/tests/test_validation.py

Lines changed: 21 additions & 16 deletions
@@ -10,6 +10,7 @@
 import pytest
 import numpy as np
 from scipy.sparse import coo_matrix, csr_matrix
+from scipy.sparse import issparse
 from sklearn.exceptions import FitFailedWarning

 from sklearn.model_selection.tests.test_search import FailingClassifier
@@ -354,18 +355,10 @@ def test_cross_validate_invalid_scoring_param():
     with pytest.raises(ValueError, match=error_message_regexp):
         cross_validate(estimator, X, y, scoring=[[make_scorer(precision_score)]])

-    error_message_regexp = (
-        ".*scoring is invalid.*Refer to the scoring glossary for details:.*"
-    )
-
     # Empty dict should raise invalid scoring error
     with pytest.raises(ValueError, match="An empty dict"):
         cross_validate(estimator, X, y, scoring=(dict()))

-    # And so should any other invalid entry
-    with pytest.raises(ValueError, match=error_message_regexp):
-        cross_validate(estimator, X, y, scoring=5)
-
     multiclass_scorer = make_scorer(precision_recall_fscore_support)

     # Multiclass Scorers that return multiple values are not supported yet
@@ -382,9 +375,6 @@ def test_cross_validate_invalid_scoring_param():
     with pytest.warns(UserWarning, match=warning_message):
         cross_validate(estimator, X, y, scoring={"foo": multiclass_scorer})

-    with pytest.raises(ValueError, match="'mse' is not a valid scoring value."):
-        cross_validate(SVC(), X, y, scoring="mse")
-

 def test_cross_validate_nested_estimator():
     # Non-regression test to ensure that nested
@@ -405,7 +395,8 @@ def test_cross_validate_nested_estimator():
     assert all(isinstance(estimator, Pipeline) for estimator in estimators)


-def test_cross_validate():
+@pytest.mark.parametrize("use_sparse", [False, True])
+def test_cross_validate(use_sparse: bool):
     # Compute train and test mse/r2 scores
     cv = KFold()
@@ -417,6 +408,10 @@
     X_clf, y_clf = make_classification(n_samples=30, random_state=0)
     clf = SVC(kernel="linear", random_state=0)

+    if use_sparse:
+        X_reg = csr_matrix(X_reg)
+        X_clf = csr_matrix(X_clf)
+
     for X, y, est in ((X_reg, y_reg, reg), (X_clf, y_clf, clf)):
         # It's okay to evaluate regression metrics on classification too
         mse_scorer = check_scoring(est, scoring="neg_mean_squared_error")
@@ -510,7 +505,15 @@ def check_cross_validate_single_metric(clf, X, y, scores, cv):
         clf, X, y, scoring="neg_mean_squared_error", return_estimator=True, cv=cv
     )
     for k, est in enumerate(mse_scores_dict["estimator"]):
-        assert_almost_equal(est.coef_, fitted_estimators[k].coef_)
+        est_coef = est.coef_.copy()
+        if issparse(est_coef):
+            est_coef = est_coef.toarray()
+
+        fitted_est_coef = fitted_estimators[k].coef_.copy()
+        if issparse(fitted_est_coef):
+            fitted_est_coef = fitted_est_coef.toarray()
+
+        assert_almost_equal(est_coef, fitted_est_coef)
         assert_almost_equal(est.intercept_, fitted_estimators[k].intercept_)
@@ -2104,10 +2107,12 @@ def test_fit_and_score_failing():
         "error_score must be the string 'raise' or a numeric value. (Hint: if "
         "using 'raise', please make sure that it has been spelled correctly.)"
     )
-    with pytest.raises(ValueError, match=error_message):
-        cross_validate(failing_clf, X, cv=3, error_score="unvalid-string")

-    with pytest.raises(ValueError, match=error_message):
+    error_message_cross_validate = (
+        "The 'error_score' parameter of cross_validate must be .*. Got .* instead."
+    )
+
+    with pytest.raises(ValueError, match=error_message_cross_validate):
         cross_val_score(failing_clf, X, cv=3, error_score="unvalid-string")

     with pytest.raises(ValueError, match=error_message):
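The parametrized test above exercises the sparse input path that the docstring change advertises. A minimal usage sketch of cross-validating directly on a sparse matrix — the estimator and dataset here are illustrative:

from scipy.sparse import csr_matrix
from sklearn.datasets import make_classification
from sklearn.model_selection import cross_validate
from sklearn.svm import SVC

X_dense, y = make_classification(n_samples=100, random_state=0)
X_sparse = csr_matrix(X_dense)  # X is documented as {array-like, sparse matrix}

scores = cross_validate(SVC(kernel="linear"), X_sparse, y, cv=3)
print(scores["test_score"])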

sklearn/tests/test_public_functions.py

Lines changed: 1 addition & 0 deletions
@@ -246,6 +246,7 @@ def _check_function_param_validation(
     "sklearn.metrics.roc_curve",
     "sklearn.metrics.top_k_accuracy_score",
     "sklearn.metrics.zero_one_loss",
+    "sklearn.model_selection.cross_validate",
     "sklearn.model_selection.train_test_split",
     "sklearn.neighbors.sort_graph_by_row_values",
     "sklearn.preprocessing.add_dummy_feature",
