Skip to content

Commit 877a822

Browse files
commented changes to tree
1 parent f655401 commit 877a822

File tree

2 files changed

+15
-57
lines changed

2 files changed

+15
-57
lines changed

sklearn/tree/_tree.pxd

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,9 @@ cdef extern from "<stack>" namespace "std" nogil:
6969
void push(T&) except + # Raise c++ exception for bad_alloc -> MemoryError
7070
T& top()
7171

72+
# A large portion of the tree build function was duplicated almost verbatim in the
73+
# neurodata fork of sklearn. We refactor that out into its own function, and it's
74+
# most convenient to encapsulate all the tree build state into its own env struct.
7275
cdef enum TreeBuildStatus:
7376
OK = 0
7477
EXCEPTION_OR_MEMORY_ERROR = -1
@@ -113,6 +116,9 @@ cdef struct BuildEnv:
113116

114117
ParentInfo parent_record
115118

119+
120+
# We add tree build events to notify interested parties of tree build state.
121+
# Only currently relevant events are implemented.
116122
cdef enum TreeBuildEvent:
117123
ADD_NODE = 1
118124
UPDATE_NODE = 2
@@ -263,6 +269,7 @@ cdef class TreeBuilder:
263269

264270
cdef unsigned char store_leaf_values # Whether to store leaf values
265271

272+
# event broker for distributing tree build events
266273
cdef EventBroker event_broker
267274

268275

sklearn/tree/_tree.pyx

Lines changed: 8 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -269,13 +269,11 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
269269

270270
cdef void _build_body(self, EventBroker broker, Tree tree, Splitter splitter, BuildEnv* e, bint update) noexcept nogil:
271271
cdef TreeBuildEvent evt
272+
273+
# payloads for different tree build events
272274
cdef TreeBuildSetActiveParentEventData parent_event_data
273275
cdef TreeBuildAddNodeEventData add_update_node_data
274276

275-
#with gil:
276-
# print("")
277-
# print("_build_body")
278-
279277
while not e.target_stack.empty():
280278
e.stack_record = e.target_stack.top()
281279
e.target_stack.pop()
@@ -295,15 +293,10 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
295293
parent_event_data.parent_node_id = e.stack_record.parent
296294
parent_event_data.child_is_left = e.stack_record.is_left
297295

298-
#with gil:
299-
# print(f"start {e.start}")
300-
# print(f"end {e.end}")
301-
# print(f"parent {<int>e.parent}")
302-
# print(f"is_left {e.is_left}")
303-
# print(f"n_node_samples {e.n_node_samples}")
304-
# print(f"parent_node_id {parent_event_data.parent_node_id}")
305-
# print(f"child_is_left {parent_event_data.child_is_left}")
306-
296+
# Tree build state is awkward as implemented: a child's node id is
297+
# only assigned after the child node has been created, so all situational
298+
# awareness during creation is expressed relative to the parent node.
299+
# We therefore fire an event indicating the currently active parent.
307300
if not broker.fire_event(TreeBuildEvent.SET_ACTIVE_PARENT, &parent_event_data):
308301
e.rc = TreeBuildStatus.EVENT_ERROR
309302
break
@@ -315,29 +308,13 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
315308
e.n_node_samples < 2 * e.min_samples_leaf or
316309
e.weighted_n_node_samples < 2 * e.min_weight_leaf)
317310

318-
#with gil:
319-
# print("")
320-
# print(f"*** IS_LEAF ***")
321-
# print(f"is_leaf = {e.is_leaf}")
322-
# print(f"depth = {e.depth}")
323-
# print(f"max_depth = {e.max_depth}")
324-
# print(f"n_node_samples = {e.n_node_samples}")
325-
# print(f"min_samples_split = {e.min_samples_split}")
326-
# print(f"min_samples_leaf = {e.min_samples_leaf}")
327-
# print(f"weighted_n_node_samples = {e.weighted_n_node_samples}")
328-
# print(f"min_weight_leaf = {e.min_weight_leaf}")
329-
330311
if e.first:
331312
e.parent_record.impurity = splitter.node_impurity()
332313
e.first = 0
333314

334315
# impurity == 0 with tolerance due to rounding errors
335316
e.is_leaf = e.is_leaf or e.parent_record.impurity <= EPSILON
336317

337-
#with gil:
338-
# print(f"is_leaf 2 = {e.is_leaf}")
339-
# print(f"parent_record.impurity = {e.parent_record.impurity}")
340-
341318
add_update_node_data.parent_node_id = e.parent
342319
add_update_node_data.is_left = e.is_left
343320
add_update_node_data.feature = -1
@@ -349,9 +326,6 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
349326
e.split,
350327
)
351328

352-
#with gil:
353-
# print("_build_body checkpoint 1")
354-
355329
# If EPSILON=0 in the below comparison, float precision
356330
# issues stop splitting, producing trees that are
357331
# dissimilar to v0.18
@@ -363,14 +337,6 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
363337
add_update_node_data.feature = e.split.feature
364338
add_update_node_data.split_point = e.split.threshold
365339

366-
#with gil:
367-
# print("_build_body checkpoint 2")
368-
# print(f"is_leaf 3 = {e.is_leaf}")
369-
# print(f"split.pos = {e.split.pos}")
370-
# print(f"end = {e.end}")
371-
# print(f"split.improvement = {e.split.improvement}")
372-
# print(f"min_impurity_decrease = {e.min_impurity_decrease}")
373-
# print(f"feature = {e.split.feature}")
374340

375341
if update == 1:
376342
e.node_id = tree._update_node(
@@ -387,29 +353,17 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
387353
)
388354
evt = TreeBuildEvent.ADD_NODE
389355

390-
#with gil:
391-
# print("_build_body checkpoint 3")
392-
393356
if e.node_id == INTPTR_MAX:
394-
#with gil:
395-
# print("_build_body checkpoint 3.25")
396357
e.rc = TreeBuildStatus.EXCEPTION_OR_MEMORY_ERROR
397358
break
398359

399-
#with gil:
400-
# print("_build_body checkpoint 3.5")
401-
402360
add_update_node_data.node_id = e.node_id
403361
add_update_node_data.is_leaf = e.is_leaf
404362

405-
#with gil:
406-
# print("_build_body checkpoint 3.6")
407-
363+
# now that all relevant information has been accumulated,
364+
# notify interested parties that a node has been added/updated
408365
broker.fire_event(evt, &add_update_node_data)
409366

410-
#with gil:
411-
# print("_build_body checkpoint 4")
412-
413367
# Store value for all nodes, to facilitate tree/model
414368
# inspection and interpretation
415369
splitter.node_value(tree.value + e.node_id * tree.value_stride)
@@ -420,9 +374,6 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
420374
e.parent_record.upper_bound
421375
)
422376

423-
#with gil:
424-
# print("_build_body checkpoint 5")
425-
426377
if not e.is_leaf:
427378
if (
428379
not splitter.with_monotonic_cst or

0 commit comments

Comments
 (0)