Skip to content

Commit f225658

Browse files
update node refactor more baby steps
1 parent b8cc636 commit f225658

File tree

1 file changed

+1
-126
lines changed

1 file changed

+1
-126
lines changed

sklearn/tree/_tree.pyx

Lines changed: 1 addition & 126 deletions
Original file line numberDiff line numberDiff line change
@@ -479,132 +479,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
479479
self._build_body(tree, splitter, &e, 1)
480480

481481
e.target_stack = &e.builder_stack
482-
while not e.target_stack.empty():
483-
e.stack_record = e.target_stack.top()
484-
e.target_stack.pop()
485-
486-
e.start = e.stack_record.start
487-
e.end = e.stack_record.end
488-
e.depth = e.stack_record.depth
489-
e.parent = e.stack_record.parent
490-
e.is_left = e.stack_record.is_left
491-
e.parent_record.impurity = e.stack_record.impurity
492-
e.parent_record.n_constant_features = e.stack_record.n_constant_features
493-
e.parent_record.lower_bound = e.stack_record.lower_bound
494-
e.parent_record.upper_bound = e.stack_record.upper_bound
495-
496-
e.n_node_samples = e.end - e.start
497-
splitter.node_reset(e.start, e.end, &e.weighted_n_node_samples)
498-
499-
e.is_leaf = (e.depth >= e.max_depth or
500-
e.n_node_samples < e.min_samples_split or
501-
e.n_node_samples < 2 * e.min_samples_leaf or
502-
e.weighted_n_node_samples < 2 * e.min_weight_leaf)
503-
504-
if e.first:
505-
e.parent_record.impurity = splitter.node_impurity()
506-
e.first=0
507-
508-
# impurity == 0 with tolerance due to rounding errors
509-
e.is_leaf = e.is_leaf or e.parent_record.impurity <= EPSILON
510-
511-
if not e.is_leaf:
512-
splitter.node_split(
513-
&e.parent_record,
514-
e.split,
515-
)
516-
517-
# If EPSILON=0 in the below comparison, float precision
518-
# issues stop splitting, producing trees that are
519-
# dissimilar to v0.18
520-
e.is_leaf = (e.is_leaf or e.split.pos >= e.end or
521-
(e.split.improvement + EPSILON <
522-
e.min_impurity_decrease))
523-
524-
e.node_id = tree._add_node(
525-
e.parent, e.is_left, e.is_leaf, e.split,
526-
e.parent_record.impurity, e.n_node_samples,
527-
e.weighted_n_node_samples, e.split.missing_go_to_left
528-
)
529-
530-
if e.node_id == INTPTR_MAX:
531-
e.rc = -1
532-
break
533-
534-
# Store value for all nodes, to facilitate tree/model
535-
# inspection and interpretation
536-
splitter.node_value(tree.value + e.node_id * tree.value_stride)
537-
if splitter.with_monotonic_cst:
538-
splitter.clip_node_value(
539-
tree.value + e.node_id * tree.value_stride,
540-
e.parent_record.lower_bound,
541-
e.parent_record.upper_bound
542-
)
543-
544-
if not e.is_leaf:
545-
if (
546-
not splitter.with_monotonic_cst or
547-
splitter.monotonic_cst[e.split.feature] == 0
548-
):
549-
# Split on a feature with no monotonicity constraint
550-
551-
# Current bounds must always be propagated to both children.
552-
# If a monotonic constraint is active, bounds are used in
553-
# node value clipping.
554-
e.left_child_min = e.right_child_min = e.parent_record.lower_bound
555-
e.left_child_max = e.right_child_max = e.parent_record.upper_bound
556-
elif splitter.monotonic_cst[e.split.feature] == 1:
557-
# Split on a feature with monotonic increase constraint
558-
e.left_child_min = e.parent_record.lower_bound
559-
e.right_child_max = e.parent_record.upper_bound
560-
561-
# Lower bound for right child and upper bound for left child
562-
# are set to the same value.
563-
e.middle_value = splitter.criterion.middle_value()
564-
e.right_child_min = e.middle_value
565-
e.left_child_max = e.middle_value
566-
else: # i.e. splitter.monotonic_cst[e.split.feature] == -1
567-
# Split on a feature with monotonic decrease constraint
568-
e.right_child_min = e.parent_record.lower_bound
569-
e.left_child_max = e.parent_record.upper_bound
570-
571-
# Lower bound for left child and upper bound for right child
572-
# are set to the same value.
573-
e.middle_value = splitter.criterion.middle_value()
574-
e.left_child_min = e.middle_value
575-
e.right_child_max = e.middle_value
576-
577-
# Push right child on stack
578-
e.builder_stack.push({
579-
"start": e.split.pos,
580-
"end": e.end,
581-
"depth": e.depth + 1,
582-
"parent": e.node_id,
583-
"is_left": 0,
584-
"impurity": e.split.impurity_right,
585-
"n_constant_features": e.parent_record.n_constant_features,
586-
"lower_bound": e.right_child_min,
587-
"upper_bound": e.right_child_max,
588-
})
589-
590-
# Push left child on stack
591-
e.builder_stack.push({
592-
"start": e.start,
593-
"end": e.split.pos,
594-
"depth": e.depth + 1,
595-
"parent": e.node_id,
596-
"is_left": 1,
597-
"impurity": e.split.impurity_left,
598-
"n_constant_features": e.parent_record.n_constant_features,
599-
"lower_bound": e.left_child_min,
600-
"upper_bound": e.left_child_max,
601-
})
602-
elif e.store_leaf_values and e.is_leaf:
603-
# copy leaf values to leaf_values array
604-
splitter.node_samples(tree.value_samples[e.node_id])
605-
606-
if e.depth > e.max_depth_seen:
607-
e.max_depth_seen = e.depth
482+
self._build_body(tree, splitter, &e, 0)
608483

609484
if e.rc >= 0:
610485
e.rc = tree._resize_c(tree.node_count)

0 commit comments

Comments
 (0)