Commit b948fdb

DOC Improve target encoder User Guide (scikit-learn#26643)

1 parent 7af0a18 · commit b948fdb
4 files changed: +45 −39 lines

doc/modules/preprocessing.rst
Lines changed: 26 additions & 20 deletions
@@ -886,10 +886,11 @@ binary classification target, the target encoding is given by:
     S_i = \lambda_i\frac{n_{iY}}{n_i} + (1 - \lambda_i)\frac{n_Y}{n}
 
 where :math:`S_i` is the encoding for category :math:`i`, :math:`n_{iY}` is the
-number of observations with :math:`Y=1` with category :math:`i`, :math:`n_i` is
+number of observations with :math:`Y=1` and category :math:`i`, :math:`n_i` is
 the number of observations with category :math:`i`, :math:`n_Y` is the number of
 observations with :math:`Y=1`, :math:`n` is the number of observations, and
-:math:`\lambda_i` is a shrinkage factor. The shrinkage factor is given by:
+:math:`\lambda_i` is a shrinkage factor for category :math:`i`. The shrinkage
+factor is given by:
 
 .. math::
     \lambda_i = \frac{n_i}{m + n_i}
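
Editor's aside (not part of the commit): a quick numeric sanity check of the two formulas above, with made-up counts — a category `i` seen `n_i = 30` times, of which `n_iY = 24` have `Y=1`, in a dataset of `n = 1000` observations with `n_Y = 500` positives, and smoothing parameter `m = 10`:

```python
# Hand-computing the shrunk encoding S_i from the User Guide's formulas.
# All counts below are invented for illustration.
n_i, n_iY = 30, 24      # observations of category i, and those with Y=1
n, n_Y = 1000, 500      # total observations, and those with Y=1
m = 10                  # smoothing parameter

lam = n_i / (m + n_i)                              # shrinkage factor = 0.75
S_i = lam * (n_iY / n_i) + (1 - lam) * (n_Y / n)   # 0.75*0.8 + 0.25*0.5
print(lam, S_i)                                    # 0.75 0.725
```

The rarer the category (small `n_i`), the smaller :math:`\lambda_i`, so the encoding is pulled further toward the global mean :math:`n_Y/n`.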
@@ -906,31 +907,36 @@ For continuous targets, the formulation is similar to binary classification:
 .. math::
     S_i = \lambda_i\frac{\sum_{k\in L_i}Y_k}{n_i} + (1 - \lambda_i)\frac{\sum_{k=1}^{n}Y_k}{n}
 
-where :math:`L_i` is the set of observations for which :math:`X=X_i` and
-:math:`n_i` is the cardinality of :math:`L_i`.
-
-:meth:`~TargetEncoder.fit_transform` internally relies on a cross validation
-scheme to prevent information from the target from leaking into the train-time
-representation for non-informative high-cardinality categorical variables and
-help prevent the downstream model to overfit spurious correlations. Note that
-as a result, `fit(X, y).transform(X)` does not equal `fit_transform(X, y)`. In
-:meth:`~TargetEncoder.fit_transform`, the training data is split into multiple
-folds and encodes each fold by using the encodings trained on the other folds.
-After cross validation is complete in :meth:`~TargetEncoder.fit_transform`, the
-target encoder learns one final encoding on the whole training set. This final
-encoding is used to encode categories in :meth:`~TargetEncoder.transform`. The
-following diagram shows the cross validation scheme in
-:meth:`~TargetEncoder.fit_transform` with the default `cv=5`:
+where :math:`L_i` is the set of observations with category :math:`i` and
+:math:`n_i` is the number of observations with category :math:`i`.
+
+:meth:`~TargetEncoder.fit_transform` internally relies on a cross fitting
+scheme to prevent target information from leaking into the train-time
+representation, especially for non-informative high-cardinality categorical
+variables, and help prevent the downstream model from overfitting spurious
+correlations. Note that as a result, `fit(X, y).transform(X)` does not equal
+`fit_transform(X, y)`. In :meth:`~TargetEncoder.fit_transform`, the training
+data is split into *k* folds (determined by the `cv` parameter) and encodes each
+fold using the encodings trained on the other *k-1* folds. The following diagram
+shows the cross fitting scheme in :meth:`~TargetEncoder.fit_transform` with
+the default `cv=5`:
 
 .. image:: ../images/target_encoder_cross_validation.svg
    :width: 600
    :align: center
 
-The :meth:`~TargetEncoder.fit` method does **not** use any cross validation
+:meth:`~TargetEncoder.fit_transform` also learns a 'full data' encoding using
+the whole training set. This is never used in
+:meth:`~TargetEncoder.fit_transform` but is saved to the attribute `encodings_`,
+for use when :meth:`~TargetEncoder.transform` is called. Note that the encodings
+learned for each fold during the cross fitting scheme are not saved to an
+attribute.
+
+The :meth:`~TargetEncoder.fit` method does **not** use any cross fitting
 schemes and learns one encoding on the entire training set, which is used to
 encode categories in :meth:`~TargetEncoder.transform`.
-:meth:`~TargetEncoder.fit`'s one encoding is the same as the final encoding
-learned in :meth:`~TargetEncoder.fit_transform`.
+This encoding is the same as the 'full data'
+encoding learned in :meth:`~TargetEncoder.fit_transform`.
 
 .. note::
     :class:`TargetEncoder` considers missing values, such as `np.nan` or `None`,
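
Editor's aside: the behavioral difference this hunk documents is easy to observe directly. A minimal sketch (toy data made up, not code from the commit) showing that `fit_transform(X, y)` (cross fitting) and the 'full data' encoding applied by `transform` generally disagree on the training data:

```python
import numpy as np
from sklearn.preprocessing import TargetEncoder

rng = np.random.default_rng(0)
X = rng.choice(["a", "b", "c"], size=(100, 1))  # one categorical feature
y = rng.integers(0, 2, size=100)                # binary target

enc = TargetEncoder(cv=5, random_state=0)
X_cross_fit = enc.fit_transform(X, y)   # per-fold encodings (cross fitting)
X_full = enc.transform(X)               # 'full data' encoding from encodings_

print(np.allclose(X_cross_fit, X_full))  # typically False
print(enc.encodings_)                     # the saved 'full data' encodings
```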

examples/preprocessing/plot_target_encoder.py
Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@
 
 .. note::
     `fit(X, y).transform(X)` does not equal `fit_transform(X, y)` because a
-    cross-validation scheme is used in `fit_transform` for encoding. See the
+    cross fitting scheme is used in `fit_transform` for encoding. See the
     :ref:`User Guide <target_encoder>`. for details.
 """

examples/preprocessing/plot_target_encoder_cross_val.py
Lines changed: 14 additions & 14 deletions
@@ -1,16 +1,16 @@
 """
-==========================================
-Target Encoder's Internal Cross Validation
-==========================================
+=======================================
+Target Encoder's Internal Cross fitting
+=======================================
 
 .. currentmodule:: sklearn.preprocessing
 
 The :class:`TargetEnocoder` replaces each category of a categorical feature with
 the mean of the target variable for that category. This method is useful
 in cases where there is a strong relationship between the categorical feature
 and the target. To prevent overfitting, :meth:`TargetEncoder.fit_transform` uses
-interval cross validation to encode the training data to be used by a downstream
-model. In this example, we demonstrate the importance of the cross validation
+an internal cross fitting scheme to encode the training data to be used by a
+downstream model. In this example, we demonstrate the importance of the cross fitting
 procedure to prevent overfitting.
 """
@@ -49,11 +49,11 @@
 
 # %%
 # The uninformative feature with high cardinality is generated so that is independent of
-# the target variable. We will show that target encoding without cross validation will
+# the target variable. We will show that target encoding without cross fitting will
 # cause catastrophic overfitting for the downstream regressor. These high cardinality
 # features are basically unique identifiers for samples which should generally be
 # removed from machine learning dataset. In this example, we generate them to show how
-# :class:`TargetEncoder`'s default cross validation behavior mitigates the overfitting
+# :class:`TargetEncoder`'s default cross fitting behavior mitigates the overfitting
 # issue automatically.
 X_near_unique_categories = rng.choice(
     int(0.9 * n_samples), size=n_samples, replace=True
@@ -79,7 +79,7 @@
 # ==========================
 # In this section, we train a ridge regressor on the dataset with and without
 # encoding and explore the influence of target encoder with and without the
-# interval cross validation. First, we see the Ridge model trained on the
+# internal cross fitting. First, we see the Ridge model trained on the
 # raw features will have low performance, because the order of the informative
 # feature is not informative:
 import sklearn
@@ -96,7 +96,7 @@
 
 # %%
 # Next, we create a pipeline with the target encoder and ridge model. The pipeline
-# uses :meth:`TargetEncoder.fit_transform` which uses cross validation. We see that
+# uses :meth:`TargetEncoder.fit_transform` which uses cross fitting. We see that
 # the model fits the data well and generalizes to the test set:
 from sklearn.pipeline import make_pipeline
 from sklearn.preprocessing import TargetEncoder
@@ -120,11 +120,11 @@
 _ = coefs_cv.plot(kind="barh")
 
 # %%
-# While :meth:`TargetEncoder.fit_transform` uses an interval cross validation,
-# :meth:`TargetEncoder.transform` itself does not perform any cross validation.
+# While :meth:`TargetEncoder.fit_transform` uses an internal cross fitting scheme,
+# :meth:`TargetEncoder.transform` itself does not perform any cross fitting.
 # It uses the aggregation of the complete training set to transform the categorical
 # features. Thus, we can use :meth:`TargetEncoder.fit` followed by
-# :meth:`TargetEncoder.transform` to disable the cross validation. This encoding
+# :meth:`TargetEncoder.transform` to disable the cross fitting. This encoding
 # is then passed to the ridge model.
 target_encoder = TargetEncoder(random_state=0)
 target_encoder.fit(X_train, y_train)
@@ -154,8 +154,8 @@
 # %%
 # Conclusion
 # ==========
-# This example demonstrates the importance of :class:`TargetEncoder`'s interval cross
-# validation. It is important to use :meth:`TargetEncoder.fit_transform` to encode
+# This example demonstrates the importance of :class:`TargetEncoder`'s internal cross
+# fitting. It is important to use :meth:`TargetEncoder.fit_transform` to encode
 # training data before passing it to a machine learning model. When a
 # :class:`TargetEncoder` is a part of a :class:`~sklearn.pipeline.Pipeline` and the
 # pipeline is fitted, the pipeline will correctly call
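
Editor's aside: for readers skimming the diff, the pipeline pattern the conclusion refers to looks roughly like this (a sketch with synthetic data, not code from the commit) — fitting the pipeline calls `TargetEncoder.fit_transform`, so the Ridge model is trained on cross-fitted encodings:

```python
import numpy as np
from sklearn.linear_model import Ridge
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import TargetEncoder

rng = np.random.default_rng(0)
X = rng.choice(list("abcdefgh"), size=(500, 1))  # categorical feature
y = rng.normal(size=500)                         # continuous target

# Pipeline.fit calls TargetEncoder.fit_transform (cross fitting) on the
# training data before the encoded values reach Ridge.
model = make_pipeline(TargetEncoder(random_state=0), Ridge())
model.fit(X, y)
print(model.score(X, y))
```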

sklearn/preprocessing/_target_encoder.py
Lines changed: 4 additions & 4 deletions
@@ -27,7 +27,7 @@ class TargetEncoder(OneToOneFeatureMixin, _BaseEncoder):
 
     .. note::
         `fit(X, y).transform(X)` does not equal `fit_transform(X, y)` because a
-        cross-validation scheme is used in `fit_transform` for encoding. See the
+        cross fitting scheme is used in `fit_transform` for encoding. See the
         :ref:`User Guide <target_encoder>` for details.
 
     .. versionadded:: 1.3
@@ -68,7 +68,7 @@ class TargetEncoder(OneToOneFeatureMixin, _BaseEncoder):
         If `"auto"`, then `smooth` is set to an empirical Bayes estimate.
 
     cv : int, default=5
-        Determines the number of folds in the cross-validation strategy used in
+        Determines the number of folds in the cross fitting strategy used in
        :meth:`fit_transform`. For classification targets, `StratifiedKFold` is used
        and for continuous targets, `KFold` is used.
 
@@ -204,7 +204,7 @@ def fit_transform(self, X, y):
 
        .. note::
            `fit(X, y).transform(X)` does not equal `fit_transform(X, y)` because a
-            cross-validation scheme is used in `fit_transform` for encoding. See the
+            cross fitting scheme is used in `fit_transform` for encoding. See the
            :ref:`User Guide <target_encoder>`. for details.
 
        Parameters
@@ -260,7 +260,7 @@ def transform(self, X):
 
        .. note::
            `fit(X, y).transform(X)` does not equal `fit_transform(X, y)` because a
-            cross-validation scheme is used in `fit_transform` for encoding. See the
+            cross fitting scheme is used in `fit_transform` for encoding. See the
            :ref:`User Guide <target_encoder>`. for details.
 
        Parameters
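
Editor's aside: a small usage sketch of the `cv` parameter described in this docstring (illustrative only) — with a classification target the folds are stratified (`StratifiedKFold`), with a continuous target plain `KFold` is used:

```python
from sklearn.preprocessing import TargetEncoder

# 10-fold cross fitting instead of the default cv=5; shuffle and random_state
# control how the folds are drawn in fit_transform.
enc = TargetEncoder(cv=10, shuffle=True, random_state=0)
```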
