From e1abd4e557f7a98df7ecac627809fb0639f95c43 Mon Sep 17 00:00:00 2001
From: yiweny
Date: Sun, 4 Aug 2024 19:01:46 +0000
Subject: [PATCH 1/3] support xgboost early stopping

---
 test/gbdt/test_gbdt.py            | 27 +++++++++++++++++++++++++++
 torch_frame/gbdt/tuned_xgboost.py | 30 ++++++++++++++++++++++--------
 2 files changed, 49 insertions(+), 8 deletions(-)

diff --git a/test/gbdt/test_gbdt.py b/test/gbdt/test_gbdt.py
index 42787976c..34e7a8b40 100644
--- a/test/gbdt/test_gbdt.py
+++ b/test/gbdt/test_gbdt.py
@@ -86,3 +86,30 @@ def test_gbdt_with_save_load(gbdt_cls, stypes, task_type_and_metric):
         assert (0 <= score <= 1)
     elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
         assert (0 <= score <= 1)
+
+
+def test_xgboost_early_stopping_rounds():
+    dataset: Dataset = FakeDataset(
+        num_rows=30,
+        with_nan=True,
+        stypes=[stype.numerical, stype.categorical],
+        create_split=True,
+        task_type=TaskType.BINARY_CLASSIFICATION,
+        col_to_text_embedder_cfg=TextEmbedderConfig(
+            text_embedder=HashTextEmbedder(8)),
+    )
+    dataset.materialize()
+    gbdt = XGBoost(
+        task_type=TaskType.BINARY_CLASSIFICATION,
+        metric=Metric.ROCAUC,
+    )
+
+    gbdt.tune(
+        tf_train=dataset.tensor_frame,
+        tf_val=dataset.tensor_frame,
+        num_trials=2,
+        num_boost_round=100,
+        early_stopping_rounds=2,
+    )
+
+    assert gbdt.model.best_iteration is not None
diff --git a/torch_frame/gbdt/tuned_xgboost.py b/torch_frame/gbdt/tuned_xgboost.py
index 9b939d324..1b13eaf1f 100644
--- a/torch_frame/gbdt/tuned_xgboost.py
+++ b/torch_frame/gbdt/tuned_xgboost.py
@@ -93,6 +93,7 @@ def objective(
         dtrain: Any,  # xgboost.DMatrix
         dvalid: Any,  # xgboost.DMatrix
         num_boost_round: int,
+        early_stopping_rounds: int,
     ) -> float:
         r"""Objective function to be optimized.
 
@@ -101,6 +102,8 @@ def objective(
             dtrain (xgboost.DMatrix): Train data.
             dvalid (xgboost.DMatrix): Validation data.
             num_boost_round (int): Number of boosting round.
+            early_stopping_rounds (int): Number of early stopping
+                rounds.
 
         Returns:
             float: Best objective value. Root mean squared error for
@@ -174,10 +177,15 @@ def objective(
 
         boost = xgboost.train(self.params, dtrain,
                               num_boost_round=num_boost_round,
-                              early_stopping_rounds=50, verbose_eval=False,
-                              evals=[(dvalid, 'validation')],
-                              callbacks=[pruning_callback])
-        pred = boost.predict(dvalid)
+                              early_stopping_rounds=early_stopping_rounds,
+                              verbose_eval=False,
+                              evals=[(dvalid, 'validation')],
+                              callbacks=[pruning_callback])
+        if boost.best_iteration:
+            iteration_range = (0, boost.best_iteration + 1)
+        else:
+            iteration_range = None
+        pred = boost.predict(dvalid, iteration_range=iteration_range)
         score = self.compute_metric(torch.from_numpy(dvalid.get_label()),
                                     torch.from_numpy(pred))
         return score
@@ -188,6 +196,7 @@ def _tune(
         tf_val: TensorFrame,
         num_trials: int,
         num_boost_round: int = 2000,
+        early_stopping_rounds: int = 50,
     ):
         import optuna
         import xgboost
@@ -207,13 +216,14 @@ def _tune(
                                 feature_types=val_feat_type,
                                 enable_categorical=True)
         study.optimize(
-            lambda trial: self.objective(trial, dtrain, dvalid, num_boost_round
-                                         ), num_trials)
+            lambda trial: self.objective(
+                trial, dtrain, dvalid, num_boost_round, early_stopping_rounds),
+            num_trials)
 
         self.params.update(study.best_params)
         self.model = xgboost.train(self.params, dtrain,
                                    num_boost_round=num_boost_round,
-                                   early_stopping_rounds=50,
+                                   early_stopping_rounds=early_stopping_rounds,
                                    verbose_eval=False,
                                    evals=[(dvalid, 'validation')])
 
@@ -225,7 +235,11 @@ def _predict(self, tf_test: TensorFrame) -> Tensor:
         dtest = xgboost.DMatrix(test_feat, label=test_y,
                                 feature_types=test_feat_type,
                                 enable_categorical=True)
-        pred = self.model.predict(dtest)
+        if self.model.best_iteration is not None:
+            iteration_range = (0, self.model.best_iteration + 1)
+        else:
+            iteration_range = None
+        pred = self.model.predict(dtest, iteration_range=iteration_range)
         return torch.from_numpy(pred).to(device)
 
     def _load(self, path: str) -> None:
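A note for readers: the early-stopping pattern this patch wires into `tuned_xgboost.py` can be reproduced standalone. Below is a minimal sketch against the plain xgboost API, using synthetic data and made-up hyperparameter values (nothing here is taken from the patch itself). It shows why `iteration_range` must be passed as a keyword argument and as a half-open `(0, best_iteration + 1)` tuple:

```python
import numpy as np
import xgboost

# Synthetic binary-classification data, for illustration only.
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 8))
y = rng.integers(0, 2, size=200)
dtrain = xgboost.DMatrix(X[:150], label=y[:150])
dvalid = xgboost.DMatrix(X[150:], label=y[150:])

# Boosting halts once the validation metric fails to improve for
# `early_stopping_rounds` consecutive rounds.
booster = xgboost.train(
    {"objective": "binary:logistic", "eval_metric": "auc"},
    dtrain,
    num_boost_round=100,
    evals=[(dvalid, "validation")],
    early_stopping_rounds=2,
    verbose_eval=False,
)

# Predict with only the trees up to the best iteration. `iteration_range`
# is half-open, hence the `+ 1`; passed positionally it would bind to
# `output_margin` instead, which is the call-signature pitfall the patch
# has to avoid.
pred = booster.predict(dvalid,
                       iteration_range=(0, booster.best_iteration + 1))
print(pred.shape, booster.best_iteration)
```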
From 12a188c2ace0a28ab70f6064fabb4d44a8d34352 Mon Sep 17 00:00:00 2001
From: yiweny
Date: Sun, 4 Aug 2024 21:04:07 +0000
Subject: [PATCH 2/3] fix test cases

---
 test/gbdt/test_gbdt.py            | 45 +++++++++----------------------
 torch_frame/gbdt/tuned_xgboost.py | 18 ++++++++++---
 2 files changed, 27 insertions(+), 36 deletions(-)

diff --git a/test/gbdt/test_gbdt.py b/test/gbdt/test_gbdt.py
index 34e7a8b40..b216f04c5 100644
--- a/test/gbdt/test_gbdt.py
+++ b/test/gbdt/test_gbdt.py
@@ -54,12 +54,18 @@ def test_gbdt_with_save_load(gbdt_cls, stypes, task_type_and_metric):
     with pytest.raises(RuntimeError, match="is not yet fitted"):
         gbdt.save(path)
 
-    gbdt.tune(
-        tf_train=dataset.tensor_frame,
-        tf_val=dataset.tensor_frame,
-        num_trials=2,
-        num_boost_round=2,
-    )
+    if issubclass(gbdt_cls, XGBoost):
+        gbdt.tune(tf_train=dataset.tensor_frame,
+                  tf_val=dataset.tensor_frame, num_trials=2,
+                  num_boost_round=1000, early_stopping_rounds=2)
+        assert gbdt.model.best_iteration is not None
+    else:
+        gbdt.tune(
+            tf_train=dataset.tensor_frame,
+            tf_val=dataset.tensor_frame,
+            num_trials=2,
+            num_boost_round=2,
+        )
 
     gbdt.save(path)
     loaded_gbdt = gbdt_cls(
@@ -86,30 +92,3 @@ def test_gbdt_with_save_load(gbdt_cls, stypes, task_type_and_metric):
         assert (0 <= score <= 1)
     elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
         assert (0 <= score <= 1)
-
-
-def test_xgboost_early_stopping_rounds():
-    dataset: Dataset = FakeDataset(
-        num_rows=30,
-        with_nan=True,
-        stypes=[stype.numerical, stype.categorical],
-        create_split=True,
-        task_type=TaskType.BINARY_CLASSIFICATION,
-        col_to_text_embedder_cfg=TextEmbedderConfig(
-            text_embedder=HashTextEmbedder(8)),
-    )
-    dataset.materialize()
-    gbdt = XGBoost(
-        task_type=TaskType.BINARY_CLASSIFICATION,
-        metric=Metric.ROCAUC,
-    )
-
-    gbdt.tune(
-        tf_train=dataset.tensor_frame,
-        tf_val=dataset.tensor_frame,
-        num_trials=2,
-        num_boost_round=100,
-        early_stopping_rounds=2,
-    )
-
-    assert gbdt.model.best_iteration is not None
diff --git a/torch_frame/gbdt/tuned_xgboost.py b/torch_frame/gbdt/tuned_xgboost.py
index 1b13eaf1f..7e50fed88 100644
--- a/torch_frame/gbdt/tuned_xgboost.py
+++ b/torch_frame/gbdt/tuned_xgboost.py
@@ -186,8 +186,13 @@ def objective(
         else:
             iteration_range = None
         pred = boost.predict(dvalid, iteration_range=iteration_range)
-        score = self.compute_metric(torch.from_numpy(dvalid.get_label()),
-                                    torch.from_numpy(pred))
+        if (boost.best_iteration
+                and self.task_type == TaskType.MULTICLASS_CLASSIFICATION):
+            assert pred.shape[1] == self.params["num_class"]
+            pred = torch.argmax(torch.from_numpy(pred), dim=1)
+        else:
+            pred = torch.from_numpy(pred)
+        score = self.compute_metric(torch.from_numpy(dvalid.get_label()), pred)
         return score
 
     def _tune(
@@ -240,7 +245,14 @@ def _predict(self, tf_test: TensorFrame) -> Tensor:
         else:
             iteration_range = None
         pred = self.model.predict(dtest, iteration_range=iteration_range)
-        return torch.from_numpy(pred).to(device)
+
+        if (self.model.best_iteration
+                and self.task_type == TaskType.MULTICLASS_CLASSIFICATION):
+            assert pred.shape[1] == self._num_classes
+            pred = torch.argmax(torch.from_numpy(pred), dim=1)
+        else:
+            pred = torch.from_numpy(pred)
+        return pred.to(device)
 
     def _load(self, path: str) -> None:
         import xgboost
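The multiclass branch added here can be hard to visualize in diff form. The sketch below reproduces the same normalization in isolation; it is a hypothetical helper, not part of the patch, and it keys on the prediction's dimensionality rather than on `best_iteration`. It shows how a `(batch_size, num_classes)` score matrix collapses to flat class indices while 1-D label predictions pass through unchanged:

```python
import numpy as np
import torch

def normalize_multiclass_pred(pred: np.ndarray,
                              num_classes: int) -> torch.Tensor:
    # A 2-D prediction of shape (batch_size, num_classes) is argmax-ed;
    # 1-D class labels are returned as-is.
    if pred.ndim == 2:
        assert pred.shape[1] == num_classes
        return torch.argmax(torch.from_numpy(pred), dim=1)
    return torch.from_numpy(pred)

scores = np.array([[0.1, 0.7, 0.2],
                   [0.5, 0.3, 0.2]])  # per-class scores, shape (2, 3)
labels = np.array([1.0, 0.0])         # already flat class labels
print(normalize_multiclass_pred(scores, num_classes=3))  # tensor([1, 0])
print(normalize_multiclass_pred(labels, num_classes=3))
# tensor([1., 0.], dtype=torch.float64)
```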
From 36e913e0d2754526db8e802e7b7392f5aa3756aa Mon Sep 17 00:00:00 2001
From: yiweny
Date: Sun, 4 Aug 2024 21:22:21 +0000
Subject: [PATCH 3/3] add comment

---
 torch_frame/gbdt/tuned_xgboost.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/torch_frame/gbdt/tuned_xgboost.py b/torch_frame/gbdt/tuned_xgboost.py
index 7e50fed88..50ea2ce52 100644
--- a/torch_frame/gbdt/tuned_xgboost.py
+++ b/torch_frame/gbdt/tuned_xgboost.py
@@ -186,6 +186,10 @@ def objective(
         else:
             iteration_range = None
         pred = boost.predict(dvalid, iteration_range=iteration_range)
+
+        # If xgboost early stops on a multiclass classification task,
+        # the output shape is (batch_size, num_classes), so we take the
+        # argmax to obtain the final class predictions.
         if (boost.best_iteration
                 and self.task_type == TaskType.MULTICLASS_CLASSIFICATION):
             assert pred.shape[1] == self.params["num_class"]
@@ -246,6 +250,9 @@ def _predict(self, tf_test: TensorFrame) -> Tensor:
         else:
             iteration_range = None
         pred = self.model.predict(dtest, iteration_range=iteration_range)
+        # If xgboost early stops on a multiclass classification task,
+        # the output shape is (batch_size, num_classes), so we take the
+        # argmax to obtain the final class predictions.
         if (self.model.best_iteration
                 and self.task_type == TaskType.MULTICLASS_CLASSIFICATION):
             assert pred.shape[1] == self._num_classes
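Taken together, the series makes early stopping a first-class knob on the tuned XGBoost wrapper. Below is a sketch of how downstream code would exercise it, mirroring the test that patch 2 folded into `test_gbdt_with_save_load`; the import paths follow the test suite, and `gbdt.predict` is assumed to be the public counterpart of `_predict`:

```python
from torch_frame import Metric, TaskType, stype
from torch_frame.config.text_embedder import TextEmbedderConfig
from torch_frame.datasets import FakeDataset
from torch_frame.gbdt import XGBoost
from torch_frame.testing.text_embedder import HashTextEmbedder

dataset = FakeDataset(
    num_rows=30,
    with_nan=True,
    stypes=[stype.numerical, stype.categorical],
    create_split=True,
    task_type=TaskType.BINARY_CLASSIFICATION,
    col_to_text_embedder_cfg=TextEmbedderConfig(
        text_embedder=HashTextEmbedder(8)),
)
dataset.materialize()

gbdt = XGBoost(task_type=TaskType.BINARY_CLASSIFICATION,
               metric=Metric.ROCAUC)

# A generous num_boost_round paired with a small early_stopping_rounds
# lets the booster stop as soon as the validation ROC-AUC plateaus.
gbdt.tune(
    tf_train=dataset.tensor_frame,
    tf_val=dataset.tensor_frame,
    num_trials=2,
    num_boost_round=1000,
    early_stopping_rounds=2,
)
assert gbdt.model.best_iteration is not None
pred = gbdt.predict(tf_test=dataset.tensor_frame)
```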