Skip to content

Commit 351f14d

Browse files
adam2392 and PSSF23 committed
[ENH v2] Add partial fit to the correct branch for decisiontreeclassifier (#54)
Supersedes: #50 Implements partial_fit API for all classification decision trees. --------- Signed-off-by: Adam Li <adam2392@gmail.com> Co-authored-by: Haoyin Xu <haoyinxu@gmail.com>
1 parent f042577 commit 351f14d

File tree

3 files changed

+506
-45
lines changed

3 files changed

+506
-45
lines changed

sklearn/tree/_classes.py

Lines changed: 173 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,12 @@
1111
# Joly Arnaud <arnaud.v.joly@gmail.com>
1212
# Fares Hedayati <fares.hedayati@gmail.com>
1313
# Nelson Liu <nelson@nelsonliu.me>
14+
# Haoyin Xu <haoyinxu@gmail.com>
1415
#
1516
# License: BSD 3 clause
1617

1718
import copy
1819
import numbers
19-
import warnings
2020
from abc import ABCMeta, abstractmethod
2121
from math import ceil
2222
from numbers import Integral, Real
@@ -35,7 +35,10 @@
3535
)
3636
from sklearn.utils import Bunch, check_random_state, compute_sample_weight
3737
from sklearn.utils._param_validation import Hidden, Interval, RealNotInt, StrOptions
38-
from sklearn.utils.multiclass import check_classification_targets
38+
from sklearn.utils.multiclass import (
39+
_check_partial_fit_first_call,
40+
check_classification_targets,
41+
)
3942
from sklearn.utils.validation import (
4043
_assert_all_finite_element_wise,
4144
_check_sample_weight,
@@ -237,6 +240,7 @@ def _fit(
237240
self,
238241
X,
239242
y,
243+
classes=None,
240244
sample_weight=None,
241245
check_input=True,
242246
missing_values_in_feature_mask=None,
@@ -291,7 +295,6 @@ def _fit(
291295
is_classification = False
292296
if y is not None:
293297
is_classification = is_classifier(self)
294-
295298
y = np.atleast_1d(y)
296299
expanded_class_weight = None
297300

@@ -313,10 +316,28 @@ def _fit(
313316
y_original = np.copy(y)
314317

315318
y_encoded = np.zeros(y.shape, dtype=int)
316-
for k in range(self.n_outputs_):
317-
classes_k, y_encoded[:, k] = np.unique(y[:, k], return_inverse=True)
318-
self.classes_.append(classes_k)
319-
self.n_classes_.append(classes_k.shape[0])
319+
if classes is not None:
320+
classes = np.atleast_1d(classes)
321+
if classes.ndim == 1:
322+
classes = np.array([classes])
323+
324+
for k in classes:
325+
self.classes_.append(np.array(k))
326+
self.n_classes_.append(np.array(k).shape[0])
327+
328+
for i in range(n_samples):
329+
for j in range(self.n_outputs_):
330+
y_encoded[i, j] = np.where(self.classes_[j] == y[i, j])[0][
331+
0
332+
]
333+
else:
334+
for k in range(self.n_outputs_):
335+
classes_k, y_encoded[:, k] = np.unique(
336+
y[:, k], return_inverse=True
337+
)
338+
self.classes_.append(classes_k)
339+
self.n_classes_.append(classes_k.shape[0])
340+
320341
y = y_encoded
321342

322343
if self.class_weight is not None:
@@ -355,24 +376,8 @@ def _fit(
355376
if self.max_features == "auto":
356377
if is_classification:
357378
max_features = max(1, int(np.sqrt(self.n_features_in_)))
358-
warnings.warn(
359-
(
360-
"`max_features='auto'` has been deprecated in 1.1 "
361-
"and will be removed in 1.3. To keep the past behaviour, "
362-
"explicitly set `max_features='sqrt'`."
363-
),
364-
FutureWarning,
365-
)
366379
else:
367380
max_features = self.n_features_in_
368-
warnings.warn(
369-
(
370-
"`max_features='auto'` has been deprecated in 1.1 "
371-
"and will be removed in 1.3. To keep the past behaviour, "
372-
"explicitly set `max_features=1.0'`."
373-
),
374-
FutureWarning,
375-
)
376381
elif self.max_features == "sqrt":
377382
max_features = max(1, int(np.sqrt(self.n_features_in_)))
378383
elif self.max_features == "log2":
@@ -538,7 +543,7 @@ def _build_tree(
538543

539544
# Use BestFirst if max_leaf_nodes given; use DepthFirst otherwise
540545
if max_leaf_nodes < 0:
541-
builder = DepthFirstTreeBuilder(
546+
self.builder_ = DepthFirstTreeBuilder(
542547
splitter,
543548
min_samples_split,
544549
min_samples_leaf,
@@ -548,7 +553,7 @@ def _build_tree(
548553
self.store_leaf_values,
549554
)
550555
else:
551-
builder = BestFirstTreeBuilder(
556+
self.builder_ = BestFirstTreeBuilder(
552557
splitter,
553558
min_samples_split,
554559
min_samples_leaf,
@@ -558,7 +563,9 @@ def _build_tree(
558563
self.min_impurity_decrease,
559564
self.store_leaf_values,
560565
)
561-
builder.build(self.tree_, X, y, sample_weight, missing_values_in_feature_mask)
566+
self.builder_.build(
567+
self.tree_, X, y, sample_weight, missing_values_in_feature_mask
568+
)
562569

563570
if self.n_outputs_ == 1 and is_classifier(self):
564571
self.n_classes_ = self.n_classes_[0]
@@ -1119,6 +1126,9 @@ class DecisionTreeClassifier(ClassifierMixin, BaseDecisionTree):
11191126
:ref:`sphx_glr_auto_examples_tree_plot_unveil_tree_structure.py`
11201127
for basic usage of these attributes.
11211128
1129+
builder_ : TreeBuilder instance
1130+
The underlying TreeBuilder object.
1131+
11221132
See Also
11231133
--------
11241134
DecisionTreeRegressor : A decision tree regressor.
@@ -1209,7 +1219,14 @@ def __init__(
12091219
)
12101220

12111221
@_fit_context(prefer_skip_nested_validation=True)
1212-
def fit(self, X, y, sample_weight=None, check_input=True):
1222+
def fit(
1223+
self,
1224+
X,
1225+
y,
1226+
sample_weight=None,
1227+
check_input=True,
1228+
classes=None,
1229+
):
12131230
"""Build a decision tree classifier from the training set (X, y).
12141231
12151232
Parameters
@@ -1233,6 +1250,11 @@ def fit(self, X, y, sample_weight=None, check_input=True):
12331250
Allow to bypass several input checking.
12341251
Don't use this parameter unless you know what you're doing.
12351252
1253+
classes : array-like of shape (n_classes,), default=None
1254+
List of all the classes that can possibly appear in the y vector.
1255+
Must be provided at the first call to partial_fit, can be omitted
1256+
in subsequent calls.
1257+
12361258
Returns
12371259
-------
12381260
self : DecisionTreeClassifier
@@ -1243,9 +1265,112 @@ def fit(self, X, y, sample_weight=None, check_input=True):
12431265
y,
12441266
sample_weight=sample_weight,
12451267
check_input=check_input,
1268+
classes=classes,
12461269
)
12471270
return self
12481271

1272+
def partial_fit(self, X, y, classes=None, sample_weight=None, check_input=True):
    """Update a decision tree classifier from the training set (X, y).

    Incrementally grows an already-fitted tree with new samples. On the
    first call this delegates to ``fit`` (so ``classes`` must be given
    then); on later calls it re-enters the stored tree builder to extend
    the existing tree structure.

    Parameters
    ----------
    X : {array-like, sparse matrix} of shape (n_samples, n_features)
        The training input samples. Internally, it will be converted to
        ``dtype=np.float32`` and if a sparse matrix is provided
        to a sparse ``csc_matrix``.

    y : array-like of shape (n_samples,) or (n_samples, n_outputs)
        The target values (class labels) as integers or strings.

    classes : array-like of shape (n_classes,), default=None
        List of all the classes that can possibly appear in the y vector.
        Must be provided at the first call to partial_fit, can be omitted
        in subsequent calls.

    sample_weight : array-like of shape (n_samples,), default=None
        Sample weights. If None, then samples are equally weighted. Splits
        that would create child nodes with net zero or negative weight are
        ignored while searching for a split in each node. Splits are also
        ignored if they would result in any single class carrying a
        negative weight in either child node.

    check_input : bool, default=True
        Allow to bypass several input checking.
        Don't use this parameter unless you know what you're doing.

    Returns
    -------
    self : DecisionTreeClassifier
        Fitted estimator.
    """
    self._validate_params()

    # validate input parameters
    # _check_partial_fit_first_call also records self.classes_ from
    # `classes` on the first call and errors on inconsistent classes later.
    first_call = _check_partial_fit_first_call(self, classes=classes)

    # Fit if no tree exists yet: the first partial_fit call is just a
    # plain fit, with `classes` forwarded so the full label set is known.
    if first_call:
        self.fit(
            X,
            y,
            sample_weight=sample_weight,
            check_input=check_input,
            classes=classes,
        )
        return self

    if check_input:
        # Need to validate separately here.
        # We can't pass multi_output=True because that would allow y to be
        # csr.
        check_X_params = dict(dtype=DTYPE, accept_sparse="csc")
        check_y_params = dict(ensure_2d=False, dtype=None)
        X, y = self._validate_data(
            X, y, reset=False, validate_separately=(check_X_params, check_y_params)
        )
        if issparse(X):
            X.sort_indices()

            # The Cython tree code indexes with np.intc only.
            if X.indices.dtype != np.intc or X.indptr.dtype != np.intc:
                raise ValueError(
                    "No support for np.int64 index based sparse matrices"
                )

    # Feature count must match the data seen at the first fit.
    if X.shape[1] != self.n_features_in_:
        msg = "Number of features %d does not match previous data %d."
        raise ValueError(msg % (X.shape[1], self.n_features_in_))

    y = np.atleast_1d(y)

    if y.ndim == 1:
        # reshape is necessary to preserve the data contiguity against vs
        # [:, np.newaxis] that does not.
        y = np.reshape(y, (-1, 1))

    check_classification_targets(y)
    # Copy before encoding so the caller's y array is never mutated.
    y = np.copy(y)

    # Normalize classes_ to a per-output list so the single-output and
    # multi-output encoding loops below can share the same code path.
    classes = self.classes_
    if self.n_outputs_ == 1:
        classes = [classes]

    # Re-encode labels into the integer indices established at first fit.
    # NOTE(review): a label absent from the first-call `classes` makes
    # np.where(...) return an empty match and this raises IndexError —
    # presumably _check_partial_fit_first_call rejects that earlier; confirm.
    y_encoded = np.zeros(y.shape, dtype=int)
    for i in range(X.shape[0]):
        for j in range(self.n_outputs_):
            y_encoded[i, j] = np.where(classes[j] == y[i, j])[0][0]
    y = y_encoded

    # The tree builder requires C-contiguous float64 targets.
    if getattr(y, "dtype", None) != DOUBLE or not y.flags.contiguous:
        y = np.ascontiguousarray(y, dtype=DOUBLE)

    # Update tree
    # builder_ was stored by _build_tree during the initial fit; the new
    # samples are routed to their leaves and the build loop continues.
    # NOTE(review): sample_weight is forwarded unvalidated here (no
    # _check_sample_weight as in fit) — confirm upstream validation.
    self.builder_.initialize_node_queue(self.tree_, X, y, sample_weight)
    self.builder_.build(self.tree_, X, y, sample_weight)

    self._prune_tree()

    return self
1373+
12491374
def predict_proba(self, X, check_input=True):
12501375
"""Predict class probabilities of the input samples X.
12511376
@@ -1518,6 +1643,9 @@ class DecisionTreeRegressor(RegressorMixin, BaseDecisionTree):
15181643
:ref:`sphx_glr_auto_examples_tree_plot_unveil_tree_structure.py`
15191644
for basic usage of these attributes.
15201645
1646+
builder_ : TreeBuilder instance
1647+
The underlying TreeBuilder object.
1648+
15211649
See Also
15221650
--------
15231651
DecisionTreeClassifier : A decision tree classifier.
@@ -1600,7 +1728,14 @@ def __init__(
16001728
)
16011729

16021730
@_fit_context(prefer_skip_nested_validation=True)
1603-
def fit(self, X, y, sample_weight=None, check_input=True):
1731+
def fit(
1732+
self,
1733+
X,
1734+
y,
1735+
sample_weight=None,
1736+
check_input=True,
1737+
classes=None,
1738+
):
16041739
"""Build a decision tree regressor from the training set (X, y).
16051740
16061741
Parameters
@@ -1623,6 +1758,9 @@ def fit(self, X, y, sample_weight=None, check_input=True):
16231758
Allow to bypass several input checking.
16241759
Don't use this parameter unless you know what you're doing.
16251760
1761+
classes : array-like of shape (n_classes,), default=None
1762+
List of all the classes that can possibly appear in the y vector.
1763+
16261764
Returns
16271765
-------
16281766
self : DecisionTreeRegressor
@@ -1634,6 +1772,7 @@ def fit(self, X, y, sample_weight=None, check_input=True):
16341772
y,
16351773
sample_weight=sample_weight,
16361774
check_input=check_input,
1775+
classes=classes,
16371776
)
16381777
return self
16391778

@@ -1885,6 +2024,9 @@ class ExtraTreeClassifier(DecisionTreeClassifier):
18852024
:ref:`sphx_glr_auto_examples_tree_plot_unveil_tree_structure.py`
18862025
for basic usage of these attributes.
18872026
2027+
builder_ : TreeBuilder instance
2028+
The underlying TreeBuilder object.
2029+
18882030
See Also
18892031
--------
18902032
ExtraTreeRegressor : An extremely randomized tree regressor.
@@ -2147,6 +2289,9 @@ class ExtraTreeRegressor(DecisionTreeRegressor):
21472289
:ref:`sphx_glr_auto_examples_tree_plot_unveil_tree_structure.py`
21482290
for basic usage of these attributes.
21492291
2292+
builder_ : TreeBuilder instance
2293+
The underlying TreeBuilder object.
2294+
21502295
See Also
21512296
--------
21522297
ExtraTreeClassifier : An extremely randomized tree classifier.

0 commit comments

Comments
 (0)