@@ -11,12 +11,12 @@
 # Joly Arnaud <arnaud.v.joly@gmail.com>
 # Fares Hedayati <fares.hedayati@gmail.com>
 # Nelson Liu <nelson@nelsonliu.me>
+# Haoyin Xu <haoyinxu@gmail.com>
 #
 # License: BSD 3 clause

 import copy
 import numbers
-import warnings
 from abc import ABCMeta, abstractmethod
 from math import ceil
 from numbers import Integral, Real
@@ -35,7 +35,10 @@
 )
 from sklearn.utils import Bunch, check_random_state, compute_sample_weight
 from sklearn.utils._param_validation import Hidden, Interval, RealNotInt, StrOptions
-from sklearn.utils.multiclass import check_classification_targets
+from sklearn.utils.multiclass import (
+    _check_partial_fit_first_call,
+    check_classification_targets,
+)
 from sklearn.utils.validation import (
     _assert_all_finite_element_wise,
     _check_sample_weight,
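
For context, `_check_partial_fit_first_call` is scikit-learn's private helper that incremental estimators use to decide whether a `partial_fit` call is the first one; it records the class list on the estimator when it is. A minimal sketch of its contract (the toy estimator is illustrative, and the helper is private API subject to change):

    from sklearn.utils.multiclass import _check_partial_fit_first_call

    class ToyEstimator:
        """Stand-in for an estimator that has not seen data yet."""

    est = ToyEstimator()

    # First call: classes must be provided; the helper records them on the
    # estimator and returns True so the caller knows to do a full fit.
    print(_check_partial_fit_first_call(est, classes=[0, 1, 2]))  # True
    print(est.classes_)                                           # [0 1 2]

    # Later calls: classes may be omitted (or must match); returns False.
    print(_check_partial_fit_first_call(est))  # False
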
@@ -237,6 +240,7 @@ def _fit(
         self,
         X,
         y,
+        classes=None,
         sample_weight=None,
         check_input=True,
         missing_values_in_feature_mask=None,
@@ -291,7 +295,6 @@ def _fit(
         is_classification = False
         if y is not None:
             is_classification = is_classifier(self)
-
             y = np.atleast_1d(y)
             expanded_class_weight = None

@@ -313,10 +316,28 @@ def _fit(
                     y_original = np.copy(y)

                 y_encoded = np.zeros(y.shape, dtype=int)
-                for k in range(self.n_outputs_):
-                    classes_k, y_encoded[:, k] = np.unique(y[:, k], return_inverse=True)
-                    self.classes_.append(classes_k)
-                    self.n_classes_.append(classes_k.shape[0])
+                if classes is not None:
+                    classes = np.atleast_1d(classes)
+                    if classes.ndim == 1:
+                        classes = np.array([classes])
+
+                    for k in classes:
+                        self.classes_.append(np.array(k))
+                        self.n_classes_.append(np.array(k).shape[0])
+
+                    for i in range(n_samples):
+                        for j in range(self.n_outputs_):
+                            y_encoded[i, j] = np.where(self.classes_[j] == y[i, j])[0][
+                                0
+                            ]
+                else:
+                    for k in range(self.n_outputs_):
+                        classes_k, y_encoded[:, k] = np.unique(
+                            y[:, k], return_inverse=True
+                        )
+                        self.classes_.append(classes_k)
+                        self.n_classes_.append(classes_k.shape[0])
+
                 y = y_encoded

                 if self.class_weight is not None:
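
The `classes is not None` branch above maps each label to its index in the caller-supplied class list rather than deriving the classes from the current batch, so later batches that happen to miss some classes still encode consistently. A minimal NumPy sketch of that mapping (the arrays here are illustrative):

    import numpy as np

    classes = np.array([["a", "b", "c"]])  # one output, three possible classes
    y = np.array([["c"], ["a"], ["c"]])    # a batch that never shows "b"

    y_encoded = np.zeros(y.shape, dtype=int)
    for i in range(y.shape[0]):
        for j in range(y.shape[1]):
            # index of the observed label within the fixed class list
            y_encoded[i, j] = np.where(classes[j] == y[i, j])[0][0]

    print(y_encoded.ravel())  # [2 0 2] -- "b" keeps index 1 even though unseen
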
@@ -355,24 +376,8 @@ def _fit(
             if self.max_features == "auto":
                 if is_classification:
                     max_features = max(1, int(np.sqrt(self.n_features_in_)))
-                    warnings.warn(
-                        (
-                            "`max_features='auto'` has been deprecated in 1.1 "
-                            "and will be removed in 1.3. To keep the past behaviour, "
-                            "explicitly set `max_features='sqrt'`."
-                        ),
-                        FutureWarning,
-                    )
                 else:
                     max_features = self.n_features_in_
-                    warnings.warn(
-                        (
-                            "`max_features='auto'` has been deprecated in 1.1 "
-                            "and will be removed in 1.3. To keep the past behaviour, "
-                            "explicitly set `max_features=1.0'`."
-                        ),
-                        FutureWarning,
-                    )
             elif self.max_features == "sqrt":
                 max_features = max(1, int(np.sqrt(self.n_features_in_)))
             elif self.max_features == "log2":
@@ -538,7 +543,7 @@ def _build_tree(

         # Use BestFirst if max_leaf_nodes given; use DepthFirst otherwise
         if max_leaf_nodes < 0:
-            builder = DepthFirstTreeBuilder(
+            self.builder_ = DepthFirstTreeBuilder(
                 splitter,
                 min_samples_split,
                 min_samples_leaf,
@@ -548,7 +553,7 @@ def _build_tree(
                 self.store_leaf_values,
             )
         else:
-            builder = BestFirstTreeBuilder(
+            self.builder_ = BestFirstTreeBuilder(
                 splitter,
                 min_samples_split,
                 min_samples_leaf,
@@ -558,7 +563,9 @@ def _build_tree(
                 self.min_impurity_decrease,
                 self.store_leaf_values,
             )
-        builder.build(self.tree_, X, y, sample_weight, missing_values_in_feature_mask)
+        self.builder_.build(
+            self.tree_, X, y, sample_weight, missing_values_in_feature_mask
+        )

         if self.n_outputs_ == 1 and is_classifier(self):
            self.n_classes_ = self.n_classes_[0]
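
Promoting the builder from a local variable to the `builder_` attribute is what lets `partial_fit` below reuse the same configured builder to keep growing the existing `tree_` on later batches. A small sketch of the resulting behavior (this assumes the fork's `DecisionTreeClassifier` as patched in this diff; stock scikit-learn accepts no `classes` argument and exposes no `builder_`):

    import numpy as np
    from sklearn.tree import DecisionTreeClassifier  # patched fork's class

    X0 = np.array([[0.0], [1.0], [2.0], [3.0]])
    y0 = np.array([0, 0, 1, 1])

    clf = DecisionTreeClassifier(random_state=0)
    clf.fit(X0, y0, classes=[0, 1])

    # The configured builder now persists on the estimator; partial_fit
    # reuses it internally via initialize_node_queue(...) and build(...).
    print(type(clf.builder_).__name__)  # e.g. DepthFirstTreeBuilder
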
@@ -1119,6 +1126,9 @@ class DecisionTreeClassifier(ClassifierMixin, BaseDecisionTree):
     :ref:`sphx_glr_auto_examples_tree_plot_unveil_tree_structure.py`
     for basic usage of these attributes.

+    builder_ : TreeBuilder instance
+        The underlying TreeBuilder object.
+
     See Also
     --------
     DecisionTreeRegressor : A decision tree regressor.
@@ -1209,7 +1219,14 @@ def __init__(
         )

     @_fit_context(prefer_skip_nested_validation=True)
-    def fit(self, X, y, sample_weight=None, check_input=True):
+    def fit(
+        self,
+        X,
+        y,
+        sample_weight=None,
+        check_input=True,
+        classes=None,
+    ):
         """Build a decision tree classifier from the training set (X, y).

         Parameters
@@ -1233,6 +1250,11 @@ def fit(self, X, y, sample_weight=None, check_input=True):
             Allow to bypass several input checking.
             Don't use this parameter unless you know what you're doing.

+        classes : array-like of shape (n_classes,), default=None
+            List of all the classes that can possibly appear in the y vector.
+            Must be provided at the first call to partial_fit, can be omitted
+            in subsequent calls.
+
         Returns
         -------
         self : DecisionTreeClassifier
@@ -1243,9 +1265,112 @@ def fit(self, X, y, sample_weight=None, check_input=True):
             y,
             sample_weight=sample_weight,
             check_input=check_input,
+            classes=classes,
         )
         return self

+    def partial_fit(self, X, y, classes=None, sample_weight=None, check_input=True):
+        """Update a decision tree classifier from the training set (X, y).
+
+        Parameters
+        ----------
+        X : {array-like, sparse matrix} of shape (n_samples, n_features)
+            The training input samples. Internally, it will be converted to
+            ``dtype=np.float32`` and if a sparse matrix is provided
+            to a sparse ``csc_matrix``.
+
+        y : array-like of shape (n_samples,) or (n_samples, n_outputs)
+            The target values (class labels) as integers or strings.
+
+        classes : array-like of shape (n_classes,), default=None
+            List of all the classes that can possibly appear in the y vector.
+            Must be provided at the first call to partial_fit, can be omitted
+            in subsequent calls.
+
+        sample_weight : array-like of shape (n_samples,), default=None
+            Sample weights. If None, then samples are equally weighted. Splits
+            that would create child nodes with net zero or negative weight are
+            ignored while searching for a split in each node. Splits are also
+            ignored if they would result in any single class carrying a
+            negative weight in either child node.
+
+        check_input : bool, default=True
+            Allow to bypass several input checking.
+            Don't use this parameter unless you know what you're doing.
+
+        Returns
+        -------
+        self : DecisionTreeClassifier
+            Fitted estimator.
+        """
+        self._validate_params()
+
+        # validate input parameters
+        first_call = _check_partial_fit_first_call(self, classes=classes)
+
+        # Fit if no tree exists yet
+        if first_call:
+            self.fit(
+                X,
+                y,
+                sample_weight=sample_weight,
+                check_input=check_input,
+                classes=classes,
+            )
+            return self
+
+        if check_input:
+            # Need to validate separately here.
+            # We can't pass multi_output=True because that would allow y to be
+            # csr.
+            check_X_params = dict(dtype=DTYPE, accept_sparse="csc")
+            check_y_params = dict(ensure_2d=False, dtype=None)
+            X, y = self._validate_data(
+                X, y, reset=False, validate_separately=(check_X_params, check_y_params)
+            )
+            if issparse(X):
+                X.sort_indices()
+
+                if X.indices.dtype != np.intc or X.indptr.dtype != np.intc:
+                    raise ValueError(
+                        "No support for np.int64 index based sparse matrices"
+                    )
+
+        if X.shape[1] != self.n_features_in_:
+            msg = "Number of features %d does not match previous data %d."
+            raise ValueError(msg % (X.shape[1], self.n_features_in_))
+
+        y = np.atleast_1d(y)
+
+        if y.ndim == 1:
+            # reshape is necessary to preserve the data contiguity, which
+            # [:, np.newaxis] would not.
+            y = np.reshape(y, (-1, 1))
+
+        check_classification_targets(y)
+        y = np.copy(y)
+
+        classes = self.classes_
+        if self.n_outputs_ == 1:
+            classes = [classes]
+
+        y_encoded = np.zeros(y.shape, dtype=int)
+        for i in range(X.shape[0]):
+            for j in range(self.n_outputs_):
+                y_encoded[i, j] = np.where(classes[j] == y[i, j])[0][0]
+        y = y_encoded
+
+        if getattr(y, "dtype", None) != DOUBLE or not y.flags.contiguous:
+            y = np.ascontiguousarray(y, dtype=DOUBLE)
+
+        # Update tree
+        self.builder_.initialize_node_queue(self.tree_, X, y, sample_weight)
+        self.builder_.build(self.tree_, X, y, sample_weight)
+
+        self._prune_tree()
+
+        return self
+
     def predict_proba(self, X, check_input=True):
         """Predict class probabilities of the input samples X.

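Taken together, a minimal usage sketch of the new API (data and output illustrative; this assumes the fork's `DecisionTreeClassifier` as patched above, not stock scikit-learn):

    import numpy as np
    from sklearn.tree import DecisionTreeClassifier  # patched fork's class

    X1, y1 = np.array([[0.0], [1.0], [2.0]]), np.array([0, 0, 1])
    X2, y2 = np.array([[3.0], [4.0]]), np.array([2, 2])

    clf = DecisionTreeClassifier(random_state=0)
    # First call must enumerate every class a future batch may contain.
    clf.partial_fit(X1, y1, classes=[0, 1, 2])
    # Subsequent calls grow the existing tree; classes may be omitted.
    clf.partial_fit(X2, y2)
    print(clf.predict(np.array([[3.5]])))
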
@@ -1518,6 +1643,9 @@ class DecisionTreeRegressor(RegressorMixin, BaseDecisionTree):
     :ref:`sphx_glr_auto_examples_tree_plot_unveil_tree_structure.py`
     for basic usage of these attributes.

+    builder_ : TreeBuilder instance
+        The underlying TreeBuilder object.
+
     See Also
     --------
     DecisionTreeClassifier : A decision tree classifier.
@@ -1600,7 +1728,14 @@ def __init__(
         )

     @_fit_context(prefer_skip_nested_validation=True)
-    def fit(self, X, y, sample_weight=None, check_input=True):
+    def fit(
+        self,
+        X,
+        y,
+        sample_weight=None,
+        check_input=True,
+        classes=None,
+    ):
         """Build a decision tree regressor from the training set (X, y).

         Parameters
@@ -1623,6 +1758,9 @@ def fit(self, X, y, sample_weight=None, check_input=True):
             Allow to bypass several input checking.
             Don't use this parameter unless you know what you're doing.

+        classes : array-like of shape (n_classes,), default=None
+            List of all the classes that can possibly appear in the y vector.
+
         Returns
         -------
         self : DecisionTreeRegressor
@@ -1634,6 +1772,7 @@ def fit(self, X, y, sample_weight=None, check_input=True):
             y,
             sample_weight=sample_weight,
             check_input=check_input,
+            classes=classes,
         )
         return self

@@ -1885,6 +2024,9 @@ class ExtraTreeClassifier(DecisionTreeClassifier):
     :ref:`sphx_glr_auto_examples_tree_plot_unveil_tree_structure.py`
     for basic usage of these attributes.

+    builder_ : TreeBuilder instance
+        The underlying TreeBuilder object.
+
     See Also
     --------
     ExtraTreeRegressor : An extremely randomized tree regressor.
@@ -2147,6 +2289,9 @@ class ExtraTreeRegressor(DecisionTreeRegressor):
     :ref:`sphx_glr_auto_examples_tree_plot_unveil_tree_structure.py`
     for basic usage of these attributes.

+    builder_ : TreeBuilder instance
+        The underlying TreeBuilder object.
+
     See Also
     --------
     ExtraTreeClassifier : An extremely randomized tree classifier.