diff --git a/test/gbdt/test_gbdt.py b/test/gbdt/test_gbdt.py
index 42787976c..b216f04c5 100644
--- a/test/gbdt/test_gbdt.py
+++ b/test/gbdt/test_gbdt.py
@@ -54,12 +54,18 @@ def test_gbdt_with_save_load(gbdt_cls, stypes, task_type_and_metric):
         with pytest.raises(RuntimeError, match="is not yet fitted"):
             gbdt.save(path)
 
-        gbdt.tune(
-            tf_train=dataset.tensor_frame,
-            tf_val=dataset.tensor_frame,
-            num_trials=2,
-            num_boost_round=2,
-        )
+        if issubclass(gbdt_cls, XGBoost):
+            gbdt.tune(tf_train=dataset.tensor_frame,
+                      tf_val=dataset.tensor_frame, num_trials=2,
+                      num_boost_round=1000, early_stopping_rounds=2)
+            assert gbdt.model.best_iteration is not None
+        else:
+            gbdt.tune(
+                tf_train=dataset.tensor_frame,
+                tf_val=dataset.tensor_frame,
+                num_trials=2,
+                num_boost_round=2,
+            )
         gbdt.save(path)
 
         loaded_gbdt = gbdt_cls(
diff --git a/torch_frame/gbdt/tuned_xgboost.py b/torch_frame/gbdt/tuned_xgboost.py
index 9b939d324..50ea2ce52 100644
--- a/torch_frame/gbdt/tuned_xgboost.py
+++ b/torch_frame/gbdt/tuned_xgboost.py
@@ -93,6 +93,7 @@ def objective(
         dtrain: Any,  # xgboost.DMatrix
         dvalid: Any,  # xgboost.DMatrix
         num_boost_round: int,
+        early_stopping_rounds: int,
     ) -> float:
         r"""Objective function to be optimized.
 
@@ -101,6 +102,8 @@ def objective(
             dtrain (xgboost.DMatrix): Train data.
             dvalid (xgboost.DMatrix): Validation data.
             num_boost_round (int): Number of boosting round.
+            early_stopping_rounds (int): Number of early stopping
+                rounds.
 
         Returns:
             float: Best objective value. Root mean squared error for
@@ -174,12 +177,28 @@ def objective(
 
         boost = xgboost.train(self.params, dtrain,
                               num_boost_round=num_boost_round,
-                              early_stopping_rounds=50, verbose_eval=False,
-                              evals=[(dvalid, 'validation')],
-                              callbacks=[pruning_callback])
-        pred = boost.predict(dvalid)
-        score = self.compute_metric(torch.from_numpy(dvalid.get_label()),
-                                    torch.from_numpy(pred))
+                              early_stopping_rounds=early_stopping_rounds,
+                              verbose_eval=False, evals=[
+                                  (dvalid, 'validation')
+                              ], callbacks=[pruning_callback])
+        # Only use the trees up to the best early-stopped round. (0, 0) is
+        # the xgboost default and means "use every tree".
+        best_iteration = getattr(boost, 'best_iteration', None)
+        if best_iteration is not None:
+            iteration_range = (0, best_iteration + 1)
+        else:
+            iteration_range = (0, 0)
+        # Pass by keyword: the second positional parameter of
+        # `Booster.predict` is `output_margin`, not `iteration_range`.
+        pred = torch.from_numpy(
+            boost.predict(dvalid, iteration_range=iteration_range))
+
+        # For multiclass classification, collapse per-class scores of shape
+        # (batch_size, num_classes) into predicted class indices.
+        if (self.task_type == TaskType.MULTICLASS_CLASSIFICATION
+                and pred.ndim == 2):
+            pred = torch.argmax(pred, dim=1)
+        score = self.compute_metric(torch.from_numpy(dvalid.get_label()), pred)
         return score
 
     def _tune(
@@ -188,6 +207,7 @@ def _tune(
         tf_val: TensorFrame,
         num_trials: int,
         num_boost_round: int = 2000,
+        early_stopping_rounds: int = 50,
     ):
         import optuna
         import xgboost
@@ -207,13 +227,14 @@ def _tune(
                                  feature_types=val_feat_type,
                                  enable_categorical=True)
         study.optimize(
-            lambda trial: self.objective(trial, dtrain, dvalid, num_boost_round
-                                         ), num_trials)
+            lambda trial: self.objective(
+                trial, dtrain, dvalid, num_boost_round, early_stopping_rounds),
+            num_trials)
         self.params.update(study.best_params)
 
         self.model = xgboost.train(self.params, dtrain,
                                    num_boost_round=num_boost_round,
-                                   early_stopping_rounds=50,
+                                   early_stopping_rounds=early_stopping_rounds,
                                    verbose_eval=False,
                                    evals=[(dvalid, 'validation')])
 
@@ -225,8 +246,24 @@ def _predict(self, tf_test: TensorFrame) -> Tensor:
         dtest = xgboost.DMatrix(test_feat, label=test_y,
                                 feature_types=test_feat_type,
                                 enable_categorical=True)
-        pred = self.model.predict(dtest)
-        return torch.from_numpy(pred).to(device)
+        # Only use the trees up to the best early-stopped round. (0, 0) is
+        # the xgboost default and means "use every tree".
+        best_iteration = getattr(self.model, 'best_iteration', None)
+        if best_iteration is not None:
+            iteration_range = (0, best_iteration + 1)
+        else:
+            iteration_range = (0, 0)
+        # Pass by keyword: the second positional parameter of
+        # `Booster.predict` is `output_margin`, not `iteration_range`.
+        pred = torch.from_numpy(
+            self.model.predict(dtest, iteration_range=iteration_range))
+
+        # For multiclass classification, collapse per-class scores of shape
+        # (batch_size, num_classes) into predicted class indices.
+        if (self.task_type == TaskType.MULTICLASS_CLASSIFICATION
+                and pred.ndim == 2):
+            pred = torch.argmax(pred, dim=1)
+        return pred.to(device)
 
     def _load(self, path: str) -> None:
         import xgboost