Skip to content

Commit 3b16b8f

Browse files
commented honesty module
1 parent 877a822 commit 3b16b8f

File tree

2 files changed

+26
-119
lines changed

2 files changed

+26
-119
lines changed

sklearn/tree/_honesty.pxd

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,19 @@
44

55
# See _honesty.pyx for details.
66

7+
# Here we cash in on the architectural changes/additions we made to Splitter and
8+
# TreeBuilder. We implement this as an honesty module not dependent on any particular
9+
# type of Tree so that it can be composed into any type of Tree.
10+
#
11+
# The general ideas are that we:
12+
# 1. inject honest split rejection criteria into Splitter
13+
# 2. listen to tree build events fired by TreeBuilder to build a shadow tree
14+
# which contains the honest sample
15+
#
16+
# So we implement honest split rejection criteria for injection into Splitter,
17+
# and event handlers which construct the shadow tree in response to events fired
18+
# by TreeBuilder.
19+
720
from ._events cimport EventData, EventHandler, EventHandlerEnv, EventType
821
from ._partitioner cimport Partitioner
922
from ._splitter cimport (
@@ -28,6 +41,10 @@ from ..utils._typedefs cimport float32_t, float64_t, intp_t, int32_t, uint32_t
2841
from libcpp.vector cimport vector
2942

3043

44+
# We use a much-simplified tree model, barely more than enough to define the
45+
# partition extents in the honest-masked data array corresponding to the node's
46+
# elements. We store it in a vector indexed by the corresponding node IDs in the
47+
# "structure" tree.
3148
cdef struct Interval:
3249
intp_t start_idx # index into samples
3350
intp_t n

sklearn/tree/_honesty.pyx

Lines changed: 9 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,9 @@ cdef class Honesty:
8282
X, samples, feature_values, missing_values_in_feature_mask
8383
)
8484

85+
# The Criterion classes are quite stateful, and since we wish to reuse them
86+
# to keep behavior consistent with them, we have to resort to some implementation
87+
# shenanigans like this.
8588
def init_criterion(
8689
self,
8790
Criterion criterion,
@@ -158,10 +161,6 @@ cdef bint _handle_set_active_parent(
158161
EventHandlerEnv handler_env,
159162
EventData event_data
160163
) noexcept nogil:
161-
#with gil:
162-
# print("")
163-
# print("in _handle_set_active_parent")
164-
165164
if event_type != TreeBuildEvent.SET_ACTIVE_PARENT:
166165
return True
167166

@@ -178,10 +177,6 @@ cdef bint _handle_set_active_parent(
178177
node.split_idx = 0
179178
node.split_value = NAN
180179

181-
#with gil:
182-
# print(f"data = {data.parent_node_id}")
183-
# print(f"env = {env.tree.size()}")
184-
185180
if data.parent_node_id < 0:
186181
env.active_parent = NULL
187182
node.start_idx = 0
@@ -195,20 +190,8 @@ cdef bint _handle_set_active_parent(
195190
node.start_idx = env.active_parent.split_idx
196191
node.n = env.active_parent.n - env.active_parent.split_idx
197192

198-
#with gil:
199-
# print("in _handle_set_active_parent")
200-
# print(f"data = {data.parent_node_id}")
201-
# print(f"env = {env.tree.size()}")
202-
# print(f"active_is_left = {env.active_is_left}")
203-
# print(f"node.start_idx = {node.start_idx}")
204-
# print(f"node.n = {node.n}")
205-
206193
(<Views>env.data_views).partitioner.init_node_split(node.start_idx, node.start_idx + node.n)
207194

208-
#with gil:
209-
# print("returning")
210-
# print("")
211-
212195
return True
213196

214197
cdef class SetActiveParentHandler(EventHandler):
@@ -224,10 +207,6 @@ cdef bint _handle_sort_feature(
224207
EventHandlerEnv handler_env,
225208
EventData event_data
226209
) noexcept nogil:
227-
#with gil:
228-
# print("")
229-
# print("in _handle_sort_feature")
230-
231210
if event_type != NodeSplitEvent.SORT_FEATURE:
232211
return True
233212

@@ -239,20 +218,11 @@ cdef bint _handle_sort_feature(
239218
node.split_idx = 0
240219
node.split_value = NAN
241220

242-
#with gil:
243-
# print(f"data.feature = {data.feature}")
244-
# print(f"node.feature = {node.feature}")
245-
# print(f"node.split_idx = {node.split_idx}")
246-
# print(f"node.split_value = {node.split_value}")
247-
248221
(<Views>env.data_views).partitioner.sort_samples_and_feature_values(node.feature)
249222

250-
#with gil:
251-
# print("returning")
252-
# print("")
253-
254223
return True
255224

225+
# When the structure tree sorts by a feature, we must do the same
256226
cdef class NodeSortFeatureHandler(EventHandler):
257227
def __cinit__(self, Honesty h):
258228
self.event_types = np.array([NodeSplitEvent.SORT_FEATURE], dtype=np.int32)
@@ -266,15 +236,9 @@ cdef bint _handle_add_node(
266236
EventHandlerEnv handler_env,
267237
EventData event_data
268238
) noexcept nogil:
269-
#with gil:
270-
# print("_handle_add_node checkpoint 1")
271-
272239
if event_type != TreeBuildEvent.ADD_NODE:
273240
return True
274241

275-
#with gil:
276-
#print("_handle_add_node checkpoint 2")
277-
278242
cdef HonestEnv* env = <HonestEnv*>handler_env
279243
cdef const float32_t[:, :] X = (<Views>env.data_views).X
280244
cdef intp_t[::1] samples = (<Views>env.data_views).samples
@@ -284,36 +248,15 @@ cdef bint _handle_add_node(
284248
cdef Interval *interval = NULL
285249
cdef Interval *parent = NULL
286250

287-
#with gil:
288-
# print("_handle_add_node checkpoint 3")
289-
290251
if data.node_id >= size:
291-
#with gil:
292-
# print("resizing")
293-
# print(f"node_id = {data.node_id}")
294-
# print(f"old tree.size = {env.tree.size()}")
295252
# as a heuristic, assume a complete tree and add a level
296253
h = floor(fmax(0, log2(size)))
297254
env.tree.resize(size + <intp_t>pow(2, h + 1))
298255

299-
#with gil:
300-
# print(f"h = {h}")
301-
# print(f"log2(size) = {log2(size)}")
302-
# print(f"new size = {size + <intp_t>pow(2, h + 1)}")
303-
# print(f"new tree.size = {env.tree.size()}")
304-
305-
#with gil:
306-
# print("_handle_add_node checkpoint 4")
307-
# print(f"node_id = {data.node_id}")
308-
# print(f"tree.size = {env.tree.size()}")
309-
310256
interval = &(env.tree[data.node_id])
311257
interval.feature = data.feature
312258
interval.split_value = data.split_point
313259

314-
#with gil:
315-
# print("_handle_add_node checkpoint 5")
316-
317260
if data.parent_node_id < 0:
318261
# the node being added is the tree root
319262
interval.start_idx = 0
@@ -328,34 +271,22 @@ cdef bint _handle_add_node(
328271
interval.start_idx = parent.split_idx
329272
interval.n = parent.n - (parent.split_idx - parent.start_idx)
330273

331-
#with gil:
332-
# print("_handle_add_node checkpoint 6")
333-
334-
# *we* don't need to sort to find the split pos we'll need for partitioning,
335-
# but the partitioner internals are so stateful we had better just do it
336-
# to ensure that it's in the expected state
274+
# We also reuse Partitioner. *We* don't need to sort to find the split pos we'll
275+
# need for partitioning, but the partitioner internals are so stateful we had
276+
# better just do it to ensure that it's in the expected state
337277
(<Views>env.data_views).partitioner.init_node_split(interval.start_idx, interval.start_idx + interval.n)
338278
(<Views>env.data_views).partitioner.sort_samples_and_feature_values(interval.feature)
339279

340-
#with gil:
341-
# print("_handle_add_node checkpoint 7")
342-
343280
# count n_left to find split pos
344281
n_left = 0
345282
i = interval.start_idx
346283
feature_value = X[samples[i], interval.feature]
347284

348-
#with gil:
349-
# print("_handle_add_node checkpoint 8")
350-
351285
while (not isnan(feature_value)) and feature_value < interval.split_value and i < interval.start_idx + interval.n:
352286
n_left += 1
353287
i += 1
354288
feature_value = X[samples[i], interval.feature]
355289

356-
#with gil:
357-
# print("_handle_add_node checkpoint 9")
358-
359290
interval.split_idx = interval.start_idx + n_left
360291

361292
(<Views>env.data_views).partitioner.partition_samples_final(
@@ -364,26 +295,6 @@ cdef bint _handle_add_node(
364295

365296
env.node_count += 1
366297

367-
#with gil:
368-
# #print("_handle_add_node checkpoint 10")
369-
# print("")
370-
# print(f"parent_node_id = {data.parent_node_id}")
371-
# print(f"node_id = {data.node_id}")
372-
# print(f"is_leaf = {data.is_leaf}")
373-
# print(f"is_left = {data.is_left}")
374-
# print(f"feature = {data.feature}")
375-
# print(f"split_point = {data.split_point}")
376-
# print("---")
377-
# print(f"start_idx = {interval.start_idx}")
378-
# if parent is not NULL:
379-
# print(f"parent.start_idx = {parent.start_idx}")
380-
# print(f"parent.split_idx = {parent.split_idx}")
381-
# print(f"parent.n = {parent.n}")
382-
# print(f"n = {interval.n}")
383-
# print(f"feature = {interval.feature}")
384-
# print(f"split_idx = {interval.split_idx}")
385-
# print(f"split_value = {interval.split_value}")
386-
387298

388299
cdef class AddNodeHandler(EventHandler):
389300
def __cinit__(self, Honesty h):
@@ -404,9 +315,6 @@ cdef bint _trivial_condition(
404315
float64_t upper_bound,
405316
SplitConditionEnv split_condition_env
406317
) noexcept nogil:
407-
#with gil:
408-
# print("TrivialCondition called")
409-
410318
return True
411319

412320
cdef class TrivialCondition(SplitCondition):
@@ -448,34 +356,16 @@ cdef bint _honest_min_sample_leaf_condition(
448356
n_left = node.split_idx - node.start_idx
449357
n_right = end_non_missing - node.split_idx + n_missing
450358

451-
#with gil:
452-
# print("")
453-
# print("in _honest_min_sample_leaf_condition")
454-
# print(f"min_samples_leaf = {min_samples_leaf}")
455-
# print(f"feature = {node.feature}")
456-
# print(f"start_idx = {node.start_idx}")
457-
# print(f"split_idx = {node.split_idx}")
458-
# print(f"n = {node.n}")
459-
# print(f"n_missing = {n_missing}")
460-
# print(f"end_non_missing = {end_non_missing}")
461-
# print(f"n_left = {n_left}")
462-
# print(f"n_right = {n_right}")
463-
# print(f"split_value = {split_value}")
464-
# if node.split_idx > 0:
465-
# print(f"X.feature_value left = {(<Views>env.honest_env.data_views).X[(<Views>env.honest_env.data_views).samples[node.split_idx - 1], node.feature]}")
466-
# print(f"X.feature_value right = {(<Views>env.honest_env.data_views).X[(<Views>env.honest_env.data_views).samples[node.split_idx], node.feature]}")
467-
468359
# Reject if min_samples_leaf is not guaranteed
469360
if n_left < min_samples_leaf or n_right < min_samples_leaf:
470361
#with gil:
471362
# print("returning False")
472363
return False
473364

474-
#with gil:
475-
# print("returning True")
476-
477365
return True
478366

367+
# Check that the honest set will have sufficient samples on each side of this
368+
# candidate split.
479369
cdef class HonestMinSamplesLeafCondition(SplitCondition):
480370
def __cinit__(self, Honesty h, intp_t min_samples):
481371
self._env.min_samples = min_samples

0 commit comments

Comments
 (0)