Skip to content

support xgboost early stopping #424

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 4, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions test/gbdt/test_gbdt.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,30 @@ def test_gbdt_with_save_load(gbdt_cls, stypes, task_type_and_metric):
assert (0 <= score <= 1)
elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
assert (0 <= score <= 1)


def test_xgboost_early_stopping_rounds():
dataset: Dataset = FakeDataset(
num_rows=30,
with_nan=True,
stypes=[stype.numerical, stype.categorical],
create_split=True,
task_type=TaskType.BINARY_CLASSIFICATION,
col_to_text_embedder_cfg=TextEmbedderConfig(
text_embedder=HashTextEmbedder(8)),
)
dataset.materialize()
gbdt = XGBoost(
task_type=TaskType.BINARY_CLASSIFICATION,
metric=Metric.ROCAUC,
)

gbdt.tune(
tf_train=dataset.tensor_frame,
tf_val=dataset.tensor_frame,
num_trials=2,
num_boost_round=100,
early_stopping_rounds=2,
)

assert gbdt.model.best_iteration is not None
30 changes: 22 additions & 8 deletions torch_frame/gbdt/tuned_xgboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def objective(
dtrain: Any, # xgboost.DMatrix
dvalid: Any, # xgboost.DMatrix
num_boost_round: int,
early_stopping_rounds: int,
) -> float:
r"""Objective function to be optimized.

Expand All @@ -101,6 +102,8 @@ def objective(
dtrain (xgboost.DMatrix): Train data.
dvalid (xgboost.DMatrix): Validation data.
num_boost_round (int): Number of boosting round.
early_stopping_rounds (int): Number of early stopping
rounds.

Returns:
float: Best objective value. Root mean squared error for
Expand Down Expand Up @@ -174,10 +177,15 @@ def objective(

boost = xgboost.train(self.params, dtrain,
num_boost_round=num_boost_round,
early_stopping_rounds=50, verbose_eval=False,
evals=[(dvalid, 'validation')],
callbacks=[pruning_callback])
pred = boost.predict(dvalid)
early_stopping_rounds=early_stopping_rounds,
verbose_eval=False, evals=[
(dvalid, 'validation')
], callbacks=[pruning_callback])
if boost.best_iteration:
iteration_range = (0, boost.best_iteration + 1)
else:
iteration_range = None
pred = boost.predict(dvalid, iteration_range)
score = self.compute_metric(torch.from_numpy(dvalid.get_label()),
torch.from_numpy(pred))
return score
Expand All @@ -188,6 +196,7 @@ def _tune(
tf_val: TensorFrame,
num_trials: int,
num_boost_round: int = 2000,
early_stopping_rounds: int = 50,
):
import optuna
import xgboost
Expand All @@ -207,13 +216,14 @@ def _tune(
feature_types=val_feat_type,
enable_categorical=True)
study.optimize(
lambda trial: self.objective(trial, dtrain, dvalid, num_boost_round
), num_trials)
lambda trial: self.objective(
trial, dtrain, dvalid, num_boost_round, early_stopping_rounds),
num_trials)
self.params.update(study.best_params)

self.model = xgboost.train(self.params, dtrain,
num_boost_round=num_boost_round,
early_stopping_rounds=50,
early_stopping_rounds=early_stopping_rounds,
verbose_eval=False,
evals=[(dvalid, 'validation')])

Expand All @@ -225,7 +235,11 @@ def _predict(self, tf_test: TensorFrame) -> Tensor:
dtest = xgboost.DMatrix(test_feat, label=test_y,
feature_types=test_feat_type,
enable_categorical=True)
pred = self.model.predict(dtest)
if self.model.best_iteration is not None:
iteration_range = self.model.best_iteration
else:
iteration_range = None
pred = self.model.predict(dtest, iteration_range)
return torch.from_numpy(pred).to(device)

def _load(self, path: str) -> None:
Expand Down
Loading