Skip to content

Commit 4551602

Browse files
committed
Fix partial fit
Signed-off-by: Adam Li <adam2392@gmail.com>
1 parent 960b589 commit 4551602

File tree

2 files changed

+24
-4
lines changed

2 files changed

+24
-4
lines changed

sklearn/ensemble/_forest.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1254,7 +1254,15 @@ def partial_fit(self, X, y, sample_weight=None, classes=None):
12541254
self : object
12551255
Returns the instance itself.
12561256
"""
1257-
self._validate_params()
1257+
X, y = validate_data(
1258+
self,
1259+
X,
1260+
y,
1261+
multi_output=True,
1262+
accept_sparse="csc",
1263+
dtype=DTYPE,
1264+
ensure_all_finite=False,
1265+
)
12581266

12591267
# validate input parameters
12601268
first_call = _check_partial_fit_first_call(self, classes=classes)

sklearn/tree/_classes.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ def _fit(
252252
dtype=DTYPE, accept_sparse="csc", ensure_all_finite=False
253253
)
254254
check_y_params = dict(ensure_2d=False, dtype=None)
255-
if y is not None or self.__sklearn_tags__().requires_y:
255+
if y is not None or self.__sklearn_tags__().required:
256256
X, y = validate_data(
257257
self, X, y, validate_separately=(check_X_params, check_y_params)
258258
)
@@ -1375,7 +1375,15 @@ def partial_fit(self, X, y, sample_weight=None, check_input=True, classes=None):
13751375
self : DecisionTreeClassifier
13761376
Fitted estimator.
13771377
"""
1378-
self._validate_params()
1378+
X, y = validate_data(
1379+
self,
1380+
X,
1381+
y,
1382+
multi_output=True,
1383+
accept_sparse="csc",
1384+
dtype=DTYPE,
1385+
ensure_all_finite=False,
1386+
)
13791387

13801388
# validate input parameters
13811389
first_call = _check_partial_fit_first_call(self, classes=classes)
@@ -1398,7 +1406,11 @@ def partial_fit(self, X, y, sample_weight=None, check_input=True, classes=None):
13981406
check_X_params = dict(dtype=DTYPE, accept_sparse="csc")
13991407
check_y_params = dict(ensure_2d=False, dtype=None)
14001408
X, y = validate_data(
1401-
self, X, y, reset=False, validate_separately=(check_X_params, check_y_params)
1409+
self,
1410+
X,
1411+
y,
1412+
reset=False,
1413+
validate_separately=(check_X_params, check_y_params),
14021414
)
14031415
if issparse(X):
14041416
X.sort_indices()

0 commit comments

Comments
 (0)