Skip to content

Commit 60e925d

Browse files
committed
Enable generalization of min_samples_split
Signed-off-by: Adam Li <adam2392@gmail.com>
1 parent f097d78 commit 60e925d

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

sklearn/tree/_classes.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ class BaseDecisionTree(MultiOutputMixin, BaseEstimator, metaclass=ABCMeta):
110110
"min_samples_split": [
111111
Interval(Integral, 2, None, closed="left"),
112112
Interval(RealNotInt, 0.0, 1.0, closed="right"),
113+
StrOptions({"sqrt", "log2"}),
113114
],
114115
"min_samples_leaf": [
115116
Interval(Integral, 1, None, closed="left"),
@@ -364,13 +365,18 @@ def _fit(
364365
else: # float
365366
min_samples_leaf = int(ceil(self.min_samples_leaf * n_samples))
366367

367-
if isinstance(self.min_samples_split, numbers.Integral):
368+
if isinstance(self.min_samples_split, str):
369+
if self.min_samples_split == "sqrt":
370+
min_samples_split = max(1, int(np.sqrt(self.n_features_in_)))
371+
elif self.min_samples_split == "log2":
372+
min_samples_split = max(1, int(np.log2(self.n_features_in_)))
373+
elif isinstance(self.min_samples_split, numbers.Integral):
368374
min_samples_split = self.min_samples_split
369375
else: # float
370376
min_samples_split = int(ceil(self.min_samples_split * n_samples))
371377
min_samples_split = max(2, min_samples_split)
372-
373378
min_samples_split = max(min_samples_split, 2 * min_samples_leaf)
379+
self.min_samples_split_ = min_samples_split
374380

375381
if isinstance(self.max_features, str):
376382
if self.max_features == "auto":

sklearn/tree/tests/test_tree.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2413,8 +2413,8 @@ def test_min_sample_split_1_error(Tree):
24132413
# min_samples_split=1 is invalid
24142414
tree = Tree(min_samples_split=1)
24152415
msg = (
2416-
r"'min_samples_split' .* must be an int in the range \[2, inf\) "
2417-
r"or a float in the range \(0.0, 1.0\]"
2416+
r"'min_samples_split' .* must be an int in the range \[2, inf\)"
2417+
r".* a float in the range \(0.0, 1.0\]"
24182418
)
24192419
with pytest.raises(ValueError, match=msg):
24202420
tree.fit(X, y)

0 commit comments

Comments
 (0)