Skip to content

Commit e4b9728

Browse files
committed
Adding unit test for partial fit
Signed-off-by: Adam Li <adam2392@gmail.com>
1 parent 504770a commit e4b9728

File tree

2 files changed

+22
-3
lines changed

2 files changed

+22
-3
lines changed

sklearn/tree/_classes.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ def _fit(
252252
dtype=DTYPE, accept_sparse="csc", ensure_all_finite=False
253253
)
254254
check_y_params = dict(ensure_2d=False, dtype=None)
255-
if y is not None or self.__sklearn_tags__().required:
255+
if y is not None or self.__sklearn_tags__().target_tags.required:
256256
X, y = validate_data(
257257
self, X, y, validate_separately=(check_X_params, check_y_params)
258258
)
@@ -1398,7 +1398,11 @@ def partial_fit(self, X, y, sample_weight=None, check_input=True, classes=None):
13981398
check_X_params = dict(dtype=DTYPE, accept_sparse="csc")
13991399
check_y_params = dict(ensure_2d=False, dtype=None)
14001400
X, y = validate_data(
1401-
self, X, y, reset=False, validate_separately=(check_X_params, check_y_params)
1401+
self,
1402+
X,
1403+
y,
1404+
reset=False,
1405+
validate_separately=(check_X_params, check_y_params),
14021406
)
14031407
if issparse(X):
14041408
X.sort_indices()

sklearn/tree/tests/test_tree.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,10 @@
5454
ignore_warnings,
5555
skip_if_32bit,
5656
)
57-
from sklearn.utils.estimator_checks import check_sample_weights_invariance
57+
from sklearn.utils.estimator_checks import (
58+
check_sample_weights_invariance,
59+
parametrize_with_checks,
60+
)
5861
from sklearn.utils.fixes import (
5962
_IS_32BIT,
6063
COO_CONTAINERS,
@@ -235,6 +238,18 @@ def assert_tree_equal(d, s, message):
235238
)
236239

237240

241+
@parametrize_with_checks(
242+
[
243+
DecisionTreeClassifier(),
244+
DecisionTreeRegressor(),
245+
ExtraTreeClassifier(),
246+
ExtraTreeRegressor(),
247+
]
248+
)
249+
def test_sklearn_compatible_estimator(estimator, check):
250+
check(estimator)
251+
252+
238253
def test_classification_toy():
239254
# Check classification on a toy dataset.
240255
for name, Tree in CLF_TREES.items():

0 commit comments

Comments
 (0)