Skip to content

Commit 151cb2d

Browse files
FEA Add _build_pruned_tree to tree.pxd file to allow cimports and a _build_pruned_tree_py to allow anyone to prune trees (scikit-learn#29590)
Signed-off-by: Adam Li <adam2392@gmail.com> Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
1 parent 8392e92 commit 151cb2d

File tree

4 files changed

+108
-1
lines changed

4 files changed

+108
-1
lines changed

sklearn/tree/_tree.pxd

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,3 +114,20 @@ cdef class TreeBuilder:
114114
const float64_t[:, ::1] y,
115115
const float64_t[:] sample_weight,
116116
)
117+
118+
119+
# =============================================================================
120+
# Tree pruning
121+
# =============================================================================
122+
123+
# The private function allows any external caller to prune the tree and return
124+
# a new tree with the pruned nodes. The pruned tree is a new tree object.
125+
#
126+
# .. warning:: this function is not backwards compatible and may change without
127+
# notice.
128+
cdef void _build_pruned_tree(
129+
Tree tree, # OUT
130+
Tree orig_tree,
131+
const uint8_t[:] leaves_in_subtree,
132+
intp_t capacity
133+
)

sklearn/tree/_tree.pyx

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1872,7 +1872,7 @@ cdef struct BuildPrunedRecord:
18721872
intp_t parent
18731873
bint is_left
18741874

1875-
cdef _build_pruned_tree(
1875+
cdef void _build_pruned_tree(
18761876
Tree tree, # OUT
18771877
Tree orig_tree,
18781878
const uint8_t[:] leaves_in_subtree,
@@ -1931,6 +1931,15 @@ cdef _build_pruned_tree(
19311931
is_leaf = leaves_in_subtree[orig_node_id]
19321932
node = &orig_tree.nodes[orig_node_id]
19331933

1934+
# protect against an infinite loop as a runtime error, when leaves_in_subtree
1935+
# are improperly set where a node is not marked as a leaf, but is a node
1936+
# in the original tree. Thus, it violates the assumption that the node
1937+
# is a leaf in the pruned tree, or has a descendant that will be pruned.
1938+
if (not is_leaf and node.left_child == _TREE_LEAF
1939+
and node.right_child == _TREE_LEAF):
1940+
rc = -2
1941+
break
1942+
19341943
new_node_id = tree._add_node(
19351944
parent, is_left, is_leaf, node.feature, node.threshold,
19361945
node.impurity, node.n_node_samples,
@@ -1960,3 +1969,33 @@ cdef _build_pruned_tree(
19601969
tree.max_depth = max_depth_seen
19611970
if rc == -1:
19621971
raise MemoryError("pruning tree")
1972+
elif rc == -2:
1973+
raise ValueError(
1974+
"Node has reached a leaf in the original tree, but is not "
1975+
"marked as a leaf in the leaves_in_subtree mask."
1976+
)
1977+
1978+
1979+
def _build_pruned_tree_py(Tree tree, Tree orig_tree, const uint8_t[:] leaves_in_subtree):
1980+
"""Build a pruned tree.
1981+
1982+
Build a pruned tree from the original tree by transforming the nodes in
1983+
``leaves_in_subtree`` into leaves.
1984+
1985+
Parameters
1986+
----------
1987+
tree : Tree
1988+
Location to place the pruned tree
1989+
orig_tree : Tree
1990+
Original tree
1991+
leaves_in_subtree : uint8_t ndarray, shape=(node_count, )
1992+
Boolean mask for leaves to include in subtree. The array must have
1993+
the same size as the number of nodes in the original tree.
1994+
"""
1995+
if leaves_in_subtree.shape[0] != orig_tree.node_count:
1996+
raise ValueError(
1997+
f"The length of leaves_in_subtree {len(leaves_in_subtree)} must be "
1998+
f"equal to the number of nodes in the original tree {orig_tree.node_count}."
1999+
)
2000+
2001+
_build_pruned_tree(tree, orig_tree, leaves_in_subtree, orig_tree.node_count)

sklearn/tree/_utils.pxd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ from ._tree cimport Node
88
from ..neighbors._quad_tree cimport Cell
99
from ..utils._typedefs cimport float32_t, float64_t, intp_t, uint8_t, int32_t, uint32_t
1010

11+
1112
cdef enum:
1213
# Max value for our rand_r replacement (near the bottom).
1314
# We don't use RAND_MAX because it's different across platforms and

sklearn/tree/tests/test_tree.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
NODE_DTYPE,
4040
TREE_LEAF,
4141
TREE_UNDEFINED,
42+
_build_pruned_tree_py,
4243
_check_n_classes,
4344
_check_node_ndarray,
4445
_check_value_ndarray,
@@ -2783,3 +2784,52 @@ def test_classification_tree_missing_values_toy():
27832784
(tree.tree_.children_left == -1) & (tree.tree_.n_node_samples == 1)
27842785
)
27852786
assert_allclose(tree.tree_.impurity[leaves_idx], 0.0)
2787+
2788+
2789+
def test_build_pruned_tree_py():
2790+
"""Test pruning a tree with the Python caller of the Cythonized prune tree."""
2791+
tree = DecisionTreeClassifier(random_state=0, max_depth=1)
2792+
tree.fit(iris.data, iris.target)
2793+
2794+
n_classes = np.atleast_1d(tree.n_classes_)
2795+
pruned_tree = CythonTree(tree.n_features_in_, n_classes, tree.n_outputs_)
2796+
2797+
# only keep the root note
2798+
leave_in_subtree = np.zeros(tree.tree_.node_count, dtype=np.uint8)
2799+
leave_in_subtree[0] = 1
2800+
_build_pruned_tree_py(pruned_tree, tree.tree_, leave_in_subtree)
2801+
2802+
assert tree.tree_.node_count == 3
2803+
assert pruned_tree.node_count == 1
2804+
with pytest.raises(AssertionError):
2805+
assert_array_equal(tree.tree_.value, pruned_tree.value)
2806+
assert_array_equal(tree.tree_.value[0], pruned_tree.value[0])
2807+
2808+
# now keep all the leaves
2809+
pruned_tree = CythonTree(tree.n_features_in_, n_classes, tree.n_outputs_)
2810+
leave_in_subtree = np.zeros(tree.tree_.node_count, dtype=np.uint8)
2811+
leave_in_subtree[1:] = 1
2812+
2813+
# Prune the tree
2814+
_build_pruned_tree_py(pruned_tree, tree.tree_, leave_in_subtree)
2815+
assert tree.tree_.node_count == 3
2816+
assert pruned_tree.node_count == 3, pruned_tree.node_count
2817+
assert_array_equal(tree.tree_.value, pruned_tree.value)
2818+
2819+
2820+
def test_build_pruned_tree_infinite_loop():
2821+
"""Test pruning a tree does not result in an infinite loop."""
2822+
2823+
# Create a tree with root and two children
2824+
tree = DecisionTreeClassifier(random_state=0, max_depth=1)
2825+
tree.fit(iris.data, iris.target)
2826+
n_classes = np.atleast_1d(tree.n_classes_)
2827+
pruned_tree = CythonTree(tree.n_features_in_, n_classes, tree.n_outputs_)
2828+
2829+
# only keeping one child as a leaf results in an improper tree
2830+
leave_in_subtree = np.zeros(tree.tree_.node_count, dtype=np.uint8)
2831+
leave_in_subtree[1] = 1
2832+
with pytest.raises(
2833+
ValueError, match="Node has reached a leaf in the original tree"
2834+
):
2835+
_build_pruned_tree_py(pruned_tree, tree.tree_, leave_in_subtree)

0 commit comments

Comments
 (0)