Skip to content

Commit d455aa1

Browse files
committed
cimport _build pruned tree
Signed-off-by: Adam Li <adam2392@gmail.com>
1 parent f0f69be commit d455aa1

File tree

2 files changed

+19
-34
lines changed

2 files changed

+19
-34
lines changed

sklearn/tree/_splitter.pyx

Lines changed: 11 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,8 @@ cdef class Splitter(BaseSplitter):
341341
This is typically a metric that is cheaply computed given the
342342
current proposed split, which is stored as the `current_split`
343343
argument.
344+
345+
Returns 1 if not a valid split, and 0 if it is.
344346
"""
345347
cdef intp_t min_samples_leaf = self.min_samples_leaf
346348
cdef intp_t end_non_missing = self.end - n_missing
@@ -418,8 +420,6 @@ cdef inline intp_t node_split_best(
418420
Criterion criterion,
419421
SplitRecord* split,
420422
ParentInfo* parent_record,
421-
# bint with_monotonic_cst,
422-
# const int8_t[:] monotonic_cst,
423423
) except -1 nogil:
424424
"""Find the best split on node samples[start:end]
425425
@@ -566,25 +566,7 @@ cdef inline intp_t node_split_best(
566566

567567
current_split.pos = p
568568

569-
# Reject if monotonicity constraints are not satisfied
570-
if (
571-
with_monotonic_cst and
572-
monotonic_cst[current_split.feature] != 0 and
573-
not criterion.check_monotonicity(
574-
monotonic_cst[current_split.feature],
575-
lower_bound,
576-
upper_bound,
577-
)
578-
):
579-
continue
580-
581569
# Reject if min_samples_leaf is not guaranteed
582-
if missing_go_to_left:
583-
n_left = current_split.pos - splitter.start + n_missing
584-
n_right = end_non_missing - current_split.pos
585-
else:
586-
n_left = current_split.pos - splitter.start
587-
n_right = end_non_missing - current_split.pos + n_missing
588570
if splitter.check_presplit_conditions(&current_split, n_missing, missing_go_to_left) == 1:
589571
continue
590572

@@ -624,6 +606,13 @@ cdef inline intp_t node_split_best(
624606

625607
current_split.n_missing = n_missing
626608
if n_missing == 0:
609+
if missing_go_to_left:
610+
n_left = current_split.pos - splitter.start + n_missing
611+
n_right = end_non_missing - current_split.pos
612+
else:
613+
n_left = current_split.pos - splitter.start
614+
n_right = end_non_missing - current_split.pos + n_missing
615+
627616
current_split.missing_go_to_left = n_left > n_right
628617
else:
629618
current_split.missing_go_to_left = missing_go_to_left
@@ -938,10 +927,6 @@ cdef inline int node_split_random(
938927
criterion.reset()
939928
criterion.update(current_split.pos)
940929

941-
# Reject if min_weight_leaf is not satisfied
942-
if splitter.check_postsplit_conditions() == 1:
943-
continue
944-
945930
# Reject if monotonicity constraints are not satisfied
946931
if (
947932
with_monotonic_cst and
@@ -954,16 +939,8 @@ cdef inline int node_split_random(
954939
):
955940
continue
956941

957-
# Reject if monotonicity constraints are not satisfied
958-
if (
959-
with_monotonic_cst and
960-
monotonic_cst[current_split.feature] != 0 and
961-
not criterion.check_monotonicity(
962-
monotonic_cst[current_split.feature],
963-
lower_bound,
964-
upper_bound,
965-
)
966-
):
942+
# Reject if min_weight_leaf is not satisfied
943+
if splitter.check_postsplit_conditions() == 1:
967944
continue
968945

969946
current_proxy_improvement = criterion.proxy_impurity_improvement()

sklearn/tree/_tree.pxd

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,3 +191,11 @@ cdef class TreeBuilder:
191191
const float64_t[:, ::1] y,
192192
const float64_t[:] sample_weight,
193193
)
194+
195+
196+
cdef _build_pruned_tree(
197+
Tree tree, # OUT
198+
Tree orig_tree,
199+
const unsigned char[:] leaves_in_subtree,
200+
intp_t capacity
201+
)

0 commit comments

Comments
 (0)