Skip to content

Commit f3951d4

Browse files
authored
Improve xgboost test. (#949)
- Add assertion. - Stopped using deprecated label encoder (removes the following warning: https://paste.googleplex.com/6273615185051648)
1 parent eadd3e1 commit f3951d4

File tree

1 file changed

+12
-5
lines changed

1 file changed

+12
-5
lines changed

tests/test_xgboost.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import unittest
22

3+
import numpy as np
34
import xgboost
45

56
from distutils.version import StrictVersion
6-
from sklearn import datasets
77
from xgboost import XGBClassifier
88

99
class TestXGBoost(unittest.TestCase):
@@ -12,8 +12,15 @@ def test_version(self):
1212
self.assertGreaterEqual(StrictVersion(xgboost.__version__), StrictVersion("1.2.1"))
1313

1414
def test_classifier(self):
15-
boston = datasets.load_boston()
16-
X, y = boston.data, boston.target
15+
X_train = np.random.random((100, 28))
16+
y_train = np.random.randint(10, size=(100, 1))
17+
X_test = np.random.random((100, 28))
18+
y_test = np.random.randint(10, size=(100, 1))
1719

18-
xgb1 = XGBClassifier(n_estimators=3)
19-
xgb1.fit(X[0:70],y[0:70])
20+
xgb1 = XGBClassifier(n_estimators=3, use_label_encoder=False)
21+
xgb1.fit(
22+
X_train, y_train,
23+
eval_set=[(X_train, y_train), (X_test, y_test)],
24+
eval_metric='mlogloss',
25+
)
26+
self.assertIn("validation_0", xgb1.evals_result())

0 commit comments

Comments
 (0)