Commit 46727ef

ENH add X_val and y_val to HGBT.fit (scikit-learn#27124)
1 parent 1eff92b commit 46727ef

File tree

3 files changed: +193, -18 lines changed
Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+- :class:`ensemble.HistGradientBoostingClassifier` and
+  :class:`ensemble.HistGradientBoostingRegressor` allow for more control over the
+  validation set used for early stopping. You can now pass data to be used for
+  validation directly to `fit` via the arguments `X_val`, `y_val` and
+  `sample_weight_val`.
+  By :user:`Christian Lorentzen <lorentzenchr>`.
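
For orientation, here is a minimal usage sketch of the new `fit` arguments described in this changelog entry (not part of the commit; the dataset and parameter values are illustrative). The validation set is passed explicitly instead of being carved out of the training data via `validation_fraction`, and `early_stopping` must be left at `'auto'` or set to `True`:

    from sklearn.datasets import make_regression
    from sklearn.ensemble import HistGradientBoostingRegressor
    from sklearn.model_selection import train_test_split

    X, y = make_regression(n_samples=1_000, random_state=0)
    X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=0)

    # Early stopping is scored on the user-provided validation set; the internal
    # train/validation split is skipped, so validation_fraction is ignored.
    model = HistGradientBoostingRegressor(early_stopping=True)
    model.fit(X_train, y_train, X_val=X_val, y_val=y_val)
    print(model.n_iter_)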

sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py

Lines changed: 92 additions & 17 deletions
@@ -421,8 +421,8 @@ def _check_categorical_features(self, X):
             )

         n_features = X.shape[1]
-        # At this point `_validate_data` was not called yet because we want to use the
-        # dtypes are used to discover the categorical features. Thus `feature_names_in_`
+        # At this point `validate_data` was not called yet because we use the original
+        # dtypes to discover the categorical features. Thus `feature_names_in_`
         # is not defined yet.
         feature_names_in_ = getattr(X, "columns", None)

@@ -508,7 +508,16 @@ def _check_interaction_cst(self, n_features):
         return constraints

     @_fit_context(prefer_skip_nested_validation=True)
-    def fit(self, X, y, sample_weight=None):
+    def fit(
+        self,
+        X,
+        y,
+        sample_weight=None,
+        *,
+        X_val=None,
+        y_val=None,
+        sample_weight_val=None,
+    ):
         """Fit the gradient boosting model.

         Parameters
@@ -524,6 +533,23 @@ def fit(self, X, y, sample_weight=None):

             .. versionadded:: 0.23

+        X_val : array-like of shape (n_val, n_features)
+            Additional sample of features for validation used in early stopping.
+            In a `Pipeline`, `X_val` can be transformed the same way as `X` with
+            `Pipeline(..., transform_input=["X_val"])`.
+
+            .. versionadded:: 1.7
+
+        y_val : array-like of shape (n_samples,)
+            Additional sample of target values for validation used in early stopping.
+
+            .. versionadded:: 1.7
+
+        sample_weight_val : array-like of shape (n_samples,) default=None
+            Additional weights for validation used in early stopping.
+
+            .. versionadded:: 1.7
+
         Returns
         -------
         self : object
@@ -548,6 +574,30 @@ def fit(self, X, y, sample_weight=None):

         sample_weight = self._finalize_sample_weight(sample_weight, y)

+        validation_data_provided = X_val is not None or y_val is not None
+        if validation_data_provided:
+            if y_val is None:
+                raise ValueError("X_val is provided, but y_val was not provided.")
+            if X_val is None:
+                raise ValueError("y_val is provided, but X_val was not provided.")
+            X_val = self._preprocess_X(X_val, reset=False)
+            y_val = _check_y(y_val, estimator=self)
+            y_val = self._encode_y_val(y_val)
+            check_consistent_length(X_val, y_val)
+            if sample_weight_val is not None:
+                sample_weight_val = _check_sample_weight(
+                    sample_weight_val, X_val, dtype=np.float64
+                )
+            if self.early_stopping is False:
+                raise ValueError(
+                    "X_val and y_val are passed to fit while at the same time "
+                    "early_stopping is False. When passing X_val and y_val to fit, "
+                    "early_stopping should be set to either 'auto' or True."
+                )
+
+        # Note: At this point, we could delete self._label_encoder if it exists.
+        # But we don't, to keep the code even simpler.
+
         rng = check_random_state(self.random_state)

         # When warm starting, we want to reuse the same seed that was used
@@ -598,13 +648,19 @@ def fit(self, X, y, sample_weight=None):
             self._loss = self.loss

         if self.early_stopping == "auto":
-            self.do_early_stopping_ = n_samples > 10000
+            self.do_early_stopping_ = n_samples > 10_000
         else:
             self.do_early_stopping_ = self.early_stopping

         # create validation data if needed
-        self._use_validation_data = self.validation_fraction is not None
-        if self.do_early_stopping_ and self._use_validation_data:
+        self._use_validation_data = (
+            self.validation_fraction is not None or validation_data_provided
+        )
+        if (
+            self.do_early_stopping_
+            and self._use_validation_data
+            and not validation_data_provided
+        ):
             # stratify for classification
             # instead of checking predict_proba, loss.n_classes >= 2 would also work
             stratify = y if hasattr(self._loss, "predict_proba") else None
@@ -642,7 +698,8 @@ def fit(self, X, y, sample_weight=None):
                 )
         else:
             X_train, y_train, sample_weight_train = X, y, sample_weight
-            X_val = y_val = sample_weight_val = None
+            if not validation_data_provided:
+                X_val = y_val = sample_weight_val = None

         # Bin the data
         # For ease of use of the API, the user-facing GBDT classes accept the
@@ -1397,7 +1454,11 @@ def _get_loss(self, sample_weight):

     @abstractmethod
     def _encode_y(self, y=None):
-        pass
+        pass  # pragma: no cover
+
+    @abstractmethod
+    def _encode_y_val(self, y=None):
+        pass  # pragma: no cover

     @property
     def n_iter_(self):
@@ -1574,8 +1635,8 @@ class HistGradientBoostingRegressor(RegressorMixin, BaseHistGradientBoosting):
         See :term:`the Glossary <warm_start>`.
     early_stopping : 'auto' or bool, default='auto'
         If 'auto', early stopping is enabled if the sample size is larger than
-        10000. If True, early stopping is enabled, otherwise early stopping is
-        disabled.
+        10000 or if `X_val` and `y_val` are passed to `fit`. If True, early stopping
+        is enabled, otherwise early stopping is disabled.

         .. versionadded:: 0.23

@@ -1593,7 +1654,9 @@ class HistGradientBoostingRegressor(RegressorMixin, BaseHistGradientBoosting):
     validation_fraction : int or float or None, default=0.1
         Proportion (or absolute size) of training data to set aside as
         validation data for early stopping. If None, early stopping is done on
-        the training data. Only used if early stopping is performed.
+        the training data.
+        The value is ignored if either early stopping is not performed, e.g.
+        `early_stopping=False`, or if `X_val` and `y_val` are passed to fit.
     n_iter_no_change : int, default=10
         Used to determine when to "early stop". The fitting process is
         stopped when none of the last ``n_iter_no_change`` scores are better
@@ -1795,6 +1858,9 @@ def _encode_y(self, y):
                 )
         return y

+    def _encode_y_val(self, y=None):
+        return self._encode_y(y)
+
     def _get_loss(self, sample_weight):
         if self.loss == "quantile":
             return _LOSSES[self.loss](
@@ -1963,8 +2029,8 @@ class HistGradientBoostingClassifier(ClassifierMixin, BaseHistGradientBoosting):
         See :term:`the Glossary <warm_start>`.
     early_stopping : 'auto' or bool, default='auto'
         If 'auto', early stopping is enabled if the sample size is larger than
-        10000. If True, early stopping is enabled, otherwise early stopping is
-        disabled.
+        10000 or if `X_val` and `y_val` are passed to `fit`. If True, early stopping
+        is enabled, otherwise early stopping is disabled.

         .. versionadded:: 0.23

@@ -1981,7 +2047,9 @@ class HistGradientBoostingClassifier(ClassifierMixin, BaseHistGradientBoosting):
     validation_fraction : int or float or None, default=0.1
         Proportion (or absolute size) of training data to set aside as
         validation data for early stopping. If None, early stopping is done on
-        the training data. Only used if early stopping is performed.
+        the training data.
+        The value is ignored if either early stopping is not performed, e.g.
+        `early_stopping=False`, or if `X_val` and `y_val` are passed to fit.
     n_iter_no_change : int, default=10
         Used to determine when to "early stop". The fitting process is
         stopped when none of the last ``n_iter_no_change`` scores are better
@@ -2272,20 +2340,27 @@ def staged_decision_function(self, X):
             yield staged_decision

     def _encode_y(self, y):
+        """Create self._label_encoder and encode y correspondingly."""
         # encode classes into 0 ... n_classes - 1 and sets attributes classes_
         # and n_trees_per_iteration_
         check_classification_targets(y)

-        label_encoder = LabelEncoder()
-        encoded_y = label_encoder.fit_transform(y)
-        self.classes_ = label_encoder.classes_
+        # We need to store the label encoder in case y_val needs to be label encoded,
+        # too.
+        self._label_encoder = LabelEncoder()
+        encoded_y = self._label_encoder.fit_transform(y)
+        self.classes_ = self._label_encoder.classes_
         n_classes = self.classes_.shape[0]
         # only 1 tree for binary classification. For multiclass classification,
         # we build 1 tree per class.
         self.n_trees_per_iteration_ = 1 if n_classes <= 2 else n_classes
         encoded_y = encoded_y.astype(Y_DTYPE, copy=False)
         return encoded_y

+    def _encode_y_val(self, y):
+        encoded_y = self._label_encoder.transform(y)
+        return encoded_y.astype(Y_DTYPE, copy=False)
+
     def _get_loss(self, sample_weight):
         # At this point self.loss == "log_loss"
         if self.n_trees_per_iteration_ == 1:
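
The updated docstring above notes that, inside a `Pipeline`, `X_val` can be transformed like `X` via `Pipeline(..., transform_input=["X_val"])`. Below is a hedged sketch of how that could look; the metadata-routing details (`sklearn.set_config(enable_metadata_routing=True)` and `set_fit_request(X_val=True, y_val=True)`) are assumptions about the surrounding scikit-learn machinery, not something established by this commit:

    import sklearn
    from sklearn.datasets import make_regression
    from sklearn.ensemble import HistGradientBoostingRegressor
    from sklearn.model_selection import train_test_split
    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import StandardScaler

    # Assumption: routing X_val/y_val through a Pipeline requires metadata routing.
    sklearn.set_config(enable_metadata_routing=True)

    X, y = make_regression(n_samples=1_000, random_state=0)
    X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=0)

    # Assumption: the estimator must request X_val/y_val for them to be routed to fit.
    hgbt = HistGradientBoostingRegressor(early_stopping=True).set_fit_request(
        X_val=True, y_val=True
    )
    pipe = Pipeline(
        [("scale", StandardScaler()), ("hgbt", hgbt)],
        # X_val is transformed by the earlier steps before reaching the final fit.
        transform_input=["X_val"],
    )
    pipe.fit(X_train, y_train, X_val=X_val, y_val=y_val)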

sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py

Lines changed: 95 additions & 1 deletion
@@ -35,7 +35,7 @@
 from sklearn.model_selection import cross_val_score, train_test_split
 from sklearn.pipeline import make_pipeline
 from sklearn.preprocessing import KBinsDiscretizer, MinMaxScaler, OneHotEncoder
-from sklearn.utils import shuffle
+from sklearn.utils import check_random_state, shuffle
 from sklearn.utils._openmp_helpers import _openmp_effective_n_threads
 from sklearn.utils._testing import _convert_container
 from sklearn.utils.fixes import _IS_32BIT
@@ -1450,6 +1450,100 @@ def test_unknown_category_that_are_negative():
     assert_allclose(hist.predict(X_test_neg), hist.predict(X_test_nan))


+@pytest.mark.parametrize(
+    ("GradientBoosting", "make_X_y"),
+    [
+        (HistGradientBoostingClassifier, make_classification),
+        (HistGradientBoostingRegressor, make_regression),
+    ],
+)
+@pytest.mark.parametrize("sample_weight", [False, True])
+def test_X_val_in_fit(GradientBoosting, make_X_y, sample_weight, global_random_seed):
+    """Test that passing X_val, y_val in fit is same as validation fraction."""
+    rng = np.random.RandomState(42)
+    n_samples = 100
+    X, y = make_X_y(n_samples=n_samples, random_state=rng)
+    if sample_weight:
+        sample_weight = np.abs(rng.normal(size=n_samples))
+        data = (X, y, sample_weight)
+    else:
+        sample_weight = None
+        data = (X, y)
+    rng_seed = global_random_seed
+
+    # Fit with validation fraction and early stopping.
+    m1 = GradientBoosting(
+        early_stopping=True,
+        validation_fraction=0.5,
+        random_state=rng_seed,
+    )
+    m1.fit(X, y, sample_weight)
+
+    # Do train-test split ourselves.
+    rng = check_random_state(rng_seed)
+    # We do the same as in the fit method.
+    stratify = y if isinstance(m1, HistGradientBoostingClassifier) else None
+    random_seed = rng.randint(np.iinfo(np.uint32).max, dtype="u8")
+    X_train, X_val, y_train, y_val, *sw = train_test_split(
+        *data,
+        test_size=0.5,
+        stratify=stratify,
+        random_state=random_seed,
+    )
+    if sample_weight is not None:
+        sample_weight_train = sw[0]
+        sample_weight_val = sw[1]
+    else:
+        sample_weight_train = None
+        sample_weight_val = None
+    m2 = GradientBoosting(
+        early_stopping=True,
+        random_state=rng_seed,
+    )
+    m2.fit(
+        X_train,
+        y_train,
+        sample_weight=sample_weight_train,
+        X_val=X_val,
+        y_val=y_val,
+        sample_weight_val=sample_weight_val,
+    )
+
+    assert_allclose(m2.n_iter_, m1.n_iter_)
+    assert_allclose(m2.predict(X), m1.predict(X))
+
+
+def test_X_val_raises_missing_y_val():
+    """Test that an error is raised if X_val given but y_val None."""
+    X, y = make_classification(n_samples=4)
+    X, X_val = X[:2], X[2:]
+    y, y_val = y[:2], y[2:]
+    with pytest.raises(
+        ValueError,
+        match="X_val is provided, but y_val was not provided",
+    ):
+        HistGradientBoostingClassifier().fit(X, y, X_val=X_val)
+    with pytest.raises(
+        ValueError,
+        match="y_val is provided, but X_val was not provided",
+    ):
+        HistGradientBoostingClassifier().fit(X, y, y_val=y_val)
+
+
+def test_X_val_raises_with_early_stopping_false():
+    """Test that an error is raised if X_val given but early_stopping is False."""
+    X, y = make_regression(n_samples=4)
+    X, X_val = X[:2], X[2:]
+    y, y_val = y[:2], y[2:]
+    with pytest.raises(
+        ValueError,
+        match="X_val and y_val are passed to fit while at the same time",
+    ):
+        HistGradientBoostingRegressor(early_stopping=False).fit(
+            X, y, X_val=X_val, y_val=y_val
+        )
+
+
 @pytest.mark.parametrize("dataframe_lib", ["pandas", "polars"])
 @pytest.mark.parametrize(
     "HistGradientBoosting",

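One consequence of `_encode_y_val` reusing the label encoder fitted on the training target (see the classifier changes above): classes that appear only in `y_val` cannot be encoded. A small illustrative sketch, not part of the commit; the data is made up and the exact error text would come from `LabelEncoder`, not from this change:

    import numpy as np
    from sklearn.ensemble import HistGradientBoostingClassifier

    X_train = np.array([[0.0], [1.0], [2.0], [3.0]])
    y_train = np.array([0, 0, 1, 1])
    X_val = np.array([[4.0]])
    y_val = np.array([2])  # class 2 was never seen during training

    clf = HistGradientBoostingClassifier(early_stopping=True)
    # _encode_y_val applies the label encoder fitted on y_train, so an unseen
    # class in y_val is expected to raise a ValueError from LabelEncoder.transform.
    clf.fit(X_train, y_train, X_val=X_val, y_val=y_val)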