Skip to content

Commit b8cc636

Browse files
refactor baby step
1 parent 78c3a1b commit b8cc636

File tree

2 files changed

+138
-182
lines changed

2 files changed

+138
-182
lines changed

sklearn/tree/_tree.pxd

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -43,18 +43,6 @@ cdef struct ParentInfo:
4343
float64_t impurity # the impurity of the parent
4444
intp_t n_constant_features # the number of constant features found in parent
4545

46-
# Function-pointer type for the node-recording callback that was stored in
# BuildEnv.add_or_update_node. It takes the Tree plus the description of one
# node and returns an intp_t; errors are signalled with -1 (`except -1`), and
# it is callable without the GIL.
# These lines are on the deletion side of this diff.
ctypedef intp_t (*AddOrUpdateNodeFunc)(
47-
Tree tree,
48-
intp_t parent,
49-
bint is_left,
50-
bint is_leaf,
51-
SplitRecord* split_node,
52-
float64_t impurity,
53-
intp_t n_node_samples,
54-
float64_t weighted_n_node_samples,
55-
unsigned char missing_go_to_left
56-
) except -1 nogil
57-
5846
# A record on the stack for depth-first tree growing
5947
cdef struct StackRecord:
6048
intp_t start
@@ -114,8 +102,6 @@ cdef struct BuildEnv:
114102
StackRecord stack_record
115103

116104
ParentInfo parent_record
117-
118-
AddOrUpdateNodeFunc add_or_update_node
119105

120106

121107
cdef class BaseTree:

sklearn/tree/_tree.pyx

Lines changed: 138 additions & 168 deletions
Original file line numberDiff line numberDiff line change
@@ -153,44 +153,6 @@ cdef class TreeBuilder:
153153

154154

155155
# Depth first builder ---------------------------------------------------------
156-
157-
158-
# Free-function adapter: forwards every argument after `tree` unchanged to
# tree._add_node and returns its result. Its parameter list matches the
# AddOrUpdateNodeFunc pointer type, and it was assigned to
# e.add_or_update_node in the builder.
# These lines are on the deletion side of this diff.
cdef intp_t tree_add_node(
159-
Tree tree,
160-
intp_t parent,
161-
bint is_left,
162-
bint is_leaf,
163-
SplitRecord* split_node,
164-
float64_t impurity,
165-
intp_t n_node_samples,
166-
float64_t weighted_n_node_samples,
167-
unsigned char missing_go_to_left
168-
) except -1 nogil:
169-
return tree._add_node(
170-
parent, is_left, is_leaf,
171-
split_node, impurity,
172-
n_node_samples, weighted_n_node_samples,
173-
missing_go_to_left
174-
)
175-
176-
# Free-function adapter: forwards every argument after `tree` unchanged to
# tree._update_node and returns its result. Its parameter list matches the
# AddOrUpdateNodeFunc pointer type, and it was assigned to
# e.add_or_update_node in the builder.
# These lines are on the deletion side of this diff.
cdef intp_t tree_update_node(
177-
Tree tree,
178-
intp_t parent,
179-
bint is_left,
180-
bint is_leaf,
181-
SplitRecord* split_node,
182-
float64_t impurity,
183-
intp_t n_node_samples,
184-
float64_t weighted_n_node_samples,
185-
unsigned char missing_go_to_left
186-
) except -1 nogil:
187-
return tree._update_node(
188-
parent, is_left, is_leaf,
189-
split_node, impurity,
190-
n_node_samples, weighted_n_node_samples,
191-
missing_go_to_left
192-
)
193-
194156
cdef class DepthFirstTreeBuilder(TreeBuilder):
195157
"""Build a decision tree in depth-first fashion."""
196158

@@ -289,6 +251,141 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
289251
# convert dict to numpy array and store value
290252
self.initial_roots = np.array(list(false_roots.items()))
291253

254+
# Drain e.target_stack: each popped StackRecord is turned into one tree node.
#
# Parameters
#   tree     : Tree that receives the nodes and their values.
#   splitter : Splitter used for node_reset / node_impurity / node_split /
#              node_value (and clip_node_value / node_samples when enabled).
#   e        : caller-owned BuildEnv holding all loop state; it is read and
#              written in place (stacks, split record, counters, e.rc).
#   update   : when set, the node is recorded with tree._update_node;
#              otherwise with tree._add_node. Both get identical arguments.
#
# Returns nothing. Failure is reported only through e.rc, which is set to -1
# when the node call yields INTPTR_MAX.
#
# NOTE(review): the function is declared `noexcept`, so an exception raised by
# a callee cannot propagate to the caller from here. Confirm that
# tree._update_node / tree._add_node signal failure only via INTPTR_MAX.
# NOTE(review): indentation is not preserved in this rendering, so the exact
# nesting of the statements below cannot be read off this text; verify against
# the real file before relying on it.
cdef void _build_body(self, Tree tree, Splitter splitter, BuildEnv* e, bint update) noexcept nogil:
255+
while not e.target_stack.empty():
256+
e.stack_record = e.target_stack.top()
257+
e.target_stack.pop()
258+
# Copy the popped record into the working fields of e; the impurity,
# constant-feature count and value bounds go into e.parent_record, which is
# what splitter.node_split receives below.
259+
e.start = e.stack_record.start
260+
e.end = e.stack_record.end
261+
e.depth = e.stack_record.depth
262+
e.parent = e.stack_record.parent
263+
e.is_left = e.stack_record.is_left
264+
e.parent_record.impurity = e.stack_record.impurity
265+
e.parent_record.n_constant_features = e.stack_record.n_constant_features
266+
e.parent_record.lower_bound = e.stack_record.lower_bound
267+
e.parent_record.upper_bound = e.stack_record.upper_bound
268+
269+
e.n_node_samples = e.end - e.start
270+
splitter.node_reset(e.start, e.end, &e.weighted_n_node_samples)
271+
# Pre-split stopping rules: depth has reached e.max_depth, fewer samples than
# e.min_samples_split, fewer than twice e.min_samples_leaf, or weighted sample
# count below twice e.min_weight_leaf.
272+
e.is_leaf = (e.depth >= e.max_depth or
273+
e.n_node_samples < e.min_samples_split or
274+
e.n_node_samples < 2 * e.min_samples_leaf or
275+
e.weighted_n_node_samples < 2 * e.min_weight_leaf)
276+
# While e.first is set, the impurity copied from the stack record is replaced
# by splitter.node_impurity(), and the flag is then cleared.
277+
if e.first:
278+
e.parent_record.impurity = splitter.node_impurity()
279+
e.first = 0
280+
281+
# impurity == 0 with tolerance due to rounding errors
282+
e.is_leaf = e.is_leaf or e.parent_record.impurity <= EPSILON
283+
284+
if not e.is_leaf:
285+
splitter.node_split(
286+
&e.parent_record,
287+
e.split,
288+
)
289+
290+
# If EPSILON=0 in the below comparison, float precision
291+
# issues stop splitting, producing trees that are
292+
# dissimilar to v0.18
293+
e.is_leaf = (e.is_leaf or e.split.pos >= e.end or
294+
(e.split.improvement + EPSILON <
295+
e.min_impurity_decrease))
296+
# `update` selects which Tree method records this node; the argument lists of
# the two calls are identical.
297+
if update == 1:
298+
e.node_id = tree._update_node(
299+
e.parent, e.is_left, e.is_leaf, e.split,
300+
e.parent_record.impurity, e.n_node_samples, e.weighted_n_node_samples,
301+
e.split.missing_go_to_left
302+
)
303+
else:
304+
e.node_id = tree._add_node(
305+
e.parent, e.is_left, e.is_leaf, e.split,
306+
e.parent_record.impurity, e.n_node_samples, e.weighted_n_node_samples,
307+
e.split.missing_go_to_left
308+
)
309+
# INTPTR_MAX from the node call is treated as failure: record it in e.rc and
# stop draining the stack.
310+
if e.node_id == INTPTR_MAX:
311+
e.rc = -1
312+
break
313+
314+
# Store value for all nodes, to facilitate tree/model
315+
# inspection and interpretation
316+
splitter.node_value(tree.value + e.node_id * tree.value_stride)
317+
if splitter.with_monotonic_cst:
318+
splitter.clip_node_value(
319+
tree.value + e.node_id * tree.value_stride,
320+
e.parent_record.lower_bound,
321+
e.parent_record.upper_bound
322+
)
323+
324+
if not e.is_leaf:
325+
if (
326+
not splitter.with_monotonic_cst or
327+
splitter.monotonic_cst[e.split.feature] == 0
328+
):
329+
# Split on a feature with no monotonicity constraint
330+
331+
# Current bounds must always be propagated to both children.
332+
# If a monotonic constraint is active, bounds are used in
333+
# node value clipping.
334+
e.left_child_min = e.right_child_min = e.parent_record.lower_bound
335+
e.left_child_max = e.right_child_max = e.parent_record.upper_bound
336+
elif splitter.monotonic_cst[e.split.feature] == 1:
337+
# Split on a feature with monotonic increase constraint
338+
e.left_child_min = e.parent_record.lower_bound
339+
e.right_child_max = e.parent_record.upper_bound
340+
341+
# Lower bound for right child and upper bound for left child
342+
# are set to the same value.
343+
e.middle_value = splitter.criterion.middle_value()
344+
e.right_child_min = e.middle_value
345+
e.left_child_max = e.middle_value
346+
else: # i.e. splitter.monotonic_cst[e.split.feature] == -1
347+
# Split on a feature with monotonic decrease constraint
348+
e.right_child_min = e.parent_record.lower_bound
349+
e.left_child_max = e.parent_record.upper_bound
350+
351+
# Lower bound for left child and upper bound for right child
352+
# are set to the same value.
353+
e.middle_value = splitter.criterion.middle_value()
354+
e.left_child_min = e.middle_value
355+
e.right_child_max = e.middle_value
356+
# NOTE: records are popped from e.target_stack, but both children are pushed
# onto e.builder_stack. The right child is pushed first and the left child
# last.
357+
# Push right child on stack
358+
e.builder_stack.push({
359+
"start": e.split.pos,
360+
"end": e.end,
361+
"depth": e.depth + 1,
362+
"parent": e.node_id,
363+
"is_left": 0,
364+
"impurity": e.split.impurity_right,
365+
"n_constant_features": e.parent_record.n_constant_features,
366+
"lower_bound": e.right_child_min,
367+
"upper_bound": e.right_child_max,
368+
})
369+
370+
# Push left child on stack
371+
e.builder_stack.push({
372+
"start": e.start,
373+
"end": e.split.pos,
374+
"depth": e.depth + 1,
375+
"parent": e.node_id,
376+
"is_left": 1,
377+
"impurity": e.split.impurity_left,
378+
"n_constant_features": e.parent_record.n_constant_features,
379+
"lower_bound": e.left_child_min,
380+
"upper_bound": e.left_child_max,
381+
})
382+
elif e.store_leaf_values and e.is_leaf:
383+
# copy leaf values to leaf_values array
384+
splitter.node_samples(tree.value_samples[e.node_id])
385+
# Keep e.max_depth_seen at the largest depth processed so far.
386+
if e.depth > e.max_depth_seen:
387+
e.max_depth_seen = e.depth
388+
292389
cpdef build(
293390
self,
294391
Tree tree,
@@ -379,136 +476,9 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
379476

380477
with nogil:
381478
e.target_stack = &e.update_stack
382-
e.add_or_update_node = tree_update_node
383-
while not e.target_stack.empty():
384-
e.stack_record = e.target_stack.top()
385-
e.target_stack.pop()
386-
387-
e.start = e.stack_record.start
388-
e.end = e.stack_record.end
389-
e.depth = e.stack_record.depth
390-
e.parent = e.stack_record.parent
391-
e.is_left = e.stack_record.is_left
392-
e.parent_record.impurity = e.stack_record.impurity
393-
e.parent_record.n_constant_features = e.stack_record.n_constant_features
394-
e.parent_record.lower_bound = e.stack_record.lower_bound
395-
e.parent_record.upper_bound = e.stack_record.upper_bound
396-
397-
e.n_node_samples = e.end - e.start
398-
splitter.node_reset(e.start, e.end, &e.weighted_n_node_samples)
399-
400-
e.is_leaf = (e.depth >= e.max_depth or
401-
e.n_node_samples < e.min_samples_split or
402-
e.n_node_samples < 2 * e.min_samples_leaf or
403-
e.weighted_n_node_samples < 2 * e.min_weight_leaf)
404-
405-
if e.first:
406-
e.parent_record.impurity = splitter.node_impurity()
407-
e.first = 0
408-
409-
# impurity == 0 with tolerance due to rounding errors
410-
e.is_leaf = e.is_leaf or e.parent_record.impurity <= EPSILON
411-
412-
if not e.is_leaf:
413-
splitter.node_split(
414-
&e.parent_record,
415-
e.split,
416-
)
417-
418-
# If EPSILON=0 in the below comparison, float precision
419-
# issues stop splitting, producing trees that are
420-
# dissimilar to v0.18
421-
e.is_leaf = (e.is_leaf or e.split.pos >= e.end or
422-
(e.split.improvement + EPSILON <
423-
e.min_impurity_decrease))
424-
425-
e.node_id = e.add_or_update_node(
426-
tree, e.parent, e.is_left, e.is_leaf, e.split,
427-
e.parent_record.impurity, e.n_node_samples, e.weighted_n_node_samples,
428-
e.split.missing_go_to_left
429-
)
430-
431-
if e.node_id == INTPTR_MAX:
432-
e.rc = -1
433-
break
434-
435-
# Store value for all nodes, to facilitate tree/model
436-
# inspection and interpretation
437-
splitter.node_value(tree.value + e.node_id * tree.value_stride)
438-
if splitter.with_monotonic_cst:
439-
splitter.clip_node_value(
440-
tree.value + e.node_id * tree.value_stride,
441-
e.parent_record.lower_bound,
442-
e.parent_record.upper_bound
443-
)
444-
445-
if not e.is_leaf:
446-
if (
447-
not splitter.with_monotonic_cst or
448-
splitter.monotonic_cst[e.split.feature] == 0
449-
):
450-
# Split on a feature with no monotonicity constraint
451-
452-
# Current bounds must always be propagated to both children.
453-
# If a monotonic constraint is active, bounds are used in
454-
# node value clipping.
455-
e.left_child_min = e.right_child_min = e.parent_record.lower_bound
456-
e.left_child_max = e.right_child_max = e.parent_record.upper_bound
457-
elif splitter.monotonic_cst[e.split.feature] == 1:
458-
# Split on a feature with monotonic increase constraint
459-
e.left_child_min = e.parent_record.lower_bound
460-
e.right_child_max = e.parent_record.upper_bound
461-
462-
# Lower bound for right child and upper bound for left child
463-
# are set to the same value.
464-
e.middle_value = splitter.criterion.middle_value()
465-
e.right_child_min = e.middle_value
466-
e.left_child_max = e.middle_value
467-
else: # i.e. splitter.monotonic_cst[e.split.feature] == -1
468-
# Split on a feature with monotonic decrease constraint
469-
e.right_child_min = e.parent_record.lower_bound
470-
e.left_child_max = e.parent_record.upper_bound
471-
472-
# Lower bound for left child and upper bound for right child
473-
# are set to the same value.
474-
e.middle_value = splitter.criterion.middle_value()
475-
e.left_child_min = e.middle_value
476-
e.right_child_max = e.middle_value
477-
478-
# Push right child on stack
479-
e.builder_stack.push({
480-
"start": e.split.pos,
481-
"end": e.end,
482-
"depth": e.depth + 1,
483-
"parent": e.node_id,
484-
"is_left": 0,
485-
"impurity": e.split.impurity_right,
486-
"n_constant_features": e.parent_record.n_constant_features,
487-
"lower_bound": e.right_child_min,
488-
"upper_bound": e.right_child_max,
489-
})
490-
491-
# Push left child on stack
492-
e.builder_stack.push({
493-
"start": e.start,
494-
"end": e.split.pos,
495-
"depth": e.depth + 1,
496-
"parent": e.node_id,
497-
"is_left": 1,
498-
"impurity": e.split.impurity_left,
499-
"n_constant_features": e.parent_record.n_constant_features,
500-
"lower_bound": e.left_child_min,
501-
"upper_bound": e.left_child_max,
502-
})
503-
elif e.store_leaf_values and e.is_leaf:
504-
# copy leaf values to leaf_values array
505-
splitter.node_samples(tree.value_samples[e.node_id])
506-
507-
if e.depth > e.max_depth_seen:
508-
e.max_depth_seen = e.depth
479+
self._build_body(tree, splitter, &e, 1)
509480

510481
e.target_stack = &e.builder_stack
511-
e.add_or_update_node = tree_add_node
512482
while not e.target_stack.empty():
513483
e.stack_record = e.target_stack.top()
514484
e.target_stack.pop()
@@ -551,8 +521,8 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
551521
(e.split.improvement + EPSILON <
552522
e.min_impurity_decrease))
553523

554-
e.node_id = e.add_or_update_node(
555-
tree, e.parent, e.is_left, e.is_leaf, e.split,
524+
e.node_id = tree._add_node(
525+
e.parent, e.is_left, e.is_leaf, e.split,
556526
e.parent_record.impurity, e.n_node_samples,
557527
e.weighted_n_node_samples, e.split.missing_go_to_left
558528
)

0 commit comments

Comments
 (0)