@@ -110,6 +110,7 @@ class BaseDecisionTree(MultiOutputMixin, BaseEstimator, metaclass=ABCMeta):
110
110
"min_samples_split" : [
111
111
Interval (Integral , 2 , None , closed = "left" ),
112
112
Interval (RealNotInt , 0.0 , 1.0 , closed = "right" ),
113
+ StrOptions ({"sqrt" , "log2" }),
113
114
],
114
115
"min_samples_leaf" : [
115
116
Interval (Integral , 1 , None , closed = "left" ),
@@ -364,13 +365,18 @@ def _fit(
364
365
else : # float
365
366
min_samples_leaf = int (ceil (self .min_samples_leaf * n_samples ))
366
367
367
- if isinstance (self .min_samples_split , numbers .Integral ):
368
+ if isinstance (self .min_samples_split , str ):
369
+ if self .min_samples_split == "sqrt" :
370
+ min_samples_split = max (1 , int (np .sqrt (self .n_features_in_ )))
371
+ elif self .min_samples_split == "log2" :
372
+ min_samples_split = max (1 , int (np .log2 (self .n_features_in_ )))
373
+ elif isinstance (self .min_samples_split , numbers .Integral ):
368
374
min_samples_split = self .min_samples_split
369
375
else : # float
370
376
min_samples_split = int (ceil (self .min_samples_split * n_samples ))
371
377
min_samples_split = max (2 , min_samples_split )
372
-
373
378
min_samples_split = max (min_samples_split , 2 * min_samples_leaf )
379
+ self .min_samples_split_ = min_samples_split
374
380
375
381
if isinstance (self .max_features , str ):
376
382
if self .max_features == "auto" :
0 commit comments