@@ -82,6 +82,9 @@ cdef class Honesty:
82
82
X, samples, feature_values, missing_values_in_feature_mask
83
83
)
84
84
85
+ # The Criterion classes are quite stateful, and since we wish to reuse them
86
+ # to maintain behavior consistent with them, we have to do some implementational
87
+ # shenanigans like this.
85
88
def init_criterion (
86
89
self ,
87
90
Criterion criterion ,
@@ -158,10 +161,6 @@ cdef bint _handle_set_active_parent(
158
161
EventHandlerEnv handler_env,
159
162
EventData event_data
160
163
) noexcept nogil:
161
- # with gil:
162
- # print("")
163
- # print("in _handle_set_active_parent")
164
-
165
164
if event_type != TreeBuildEvent.SET_ACTIVE_PARENT:
166
165
return True
167
166
@@ -178,10 +177,6 @@ cdef bint _handle_set_active_parent(
178
177
node.split_idx = 0
179
178
node.split_value = NAN
180
179
181
- # with gil:
182
- # print(f"data = {data.parent_node_id}")
183
- # print(f"env = {env.tree.size()}")
184
-
185
180
if data.parent_node_id < 0 :
186
181
env.active_parent = NULL
187
182
node.start_idx = 0
@@ -195,20 +190,8 @@ cdef bint _handle_set_active_parent(
195
190
node.start_idx = env.active_parent.split_idx
196
191
node.n = env.active_parent.n - env.active_parent.split_idx
197
192
198
- # with gil:
199
- # print("in _handle_set_active_parent")
200
- # print(f"data = {data.parent_node_id}")
201
- # print(f"env = {env.tree.size()}")
202
- # print(f"active_is_left = {env.active_is_left}")
203
- # print(f"node.start_idx = {node.start_idx}")
204
- # print(f"node.n = {node.n}")
205
-
206
193
(< Views> env.data_views).partitioner.init_node_split(node.start_idx, node.start_idx + node.n)
207
194
208
- # with gil:
209
- # print("returning")
210
- # print("")
211
-
212
195
return True
213
196
214
197
cdef class SetActiveParentHandler(EventHandler):
@@ -224,10 +207,6 @@ cdef bint _handle_sort_feature(
224
207
EventHandlerEnv handler_env,
225
208
EventData event_data
226
209
) noexcept nogil:
227
- # with gil:
228
- # print("")
229
- # print("in _handle_sort_feature")
230
-
231
210
if event_type != NodeSplitEvent.SORT_FEATURE:
232
211
return True
233
212
@@ -239,20 +218,11 @@ cdef bint _handle_sort_feature(
239
218
node.split_idx = 0
240
219
node.split_value = NAN
241
220
242
- # with gil:
243
- # print(f"data.feature = {data.feature}")
244
- # print(f"node.feature = {node.feature}")
245
- # print(f"node.split_idx = {node.split_idx}")
246
- # print(f"node.split_value = {node.split_value}")
247
-
248
221
(< Views> env.data_views).partitioner.sort_samples_and_feature_values(node.feature)
249
222
250
- # with gil:
251
- # print("returning")
252
- # print("")
253
-
254
223
return True
255
224
225
+ # When the structure tree sorts by a feature, we must do the same
256
226
cdef class NodeSortFeatureHandler(EventHandler):
257
227
def __cinit__ (self , Honesty h ):
258
228
self .event_types = np.array([NodeSplitEvent.SORT_FEATURE], dtype = np.int32)
@@ -266,15 +236,9 @@ cdef bint _handle_add_node(
266
236
EventHandlerEnv handler_env,
267
237
EventData event_data
268
238
) noexcept nogil:
269
- # with gil:
270
- # print("_handle_add_node checkpoint 1")
271
-
272
239
if event_type != TreeBuildEvent.ADD_NODE:
273
240
return True
274
241
275
- # with gil:
276
- # print("_handle_add_node checkpoint 2")
277
-
278
242
cdef HonestEnv* env = < HonestEnv* > handler_env
279
243
cdef const float32_t[:, :] X = (< Views> env.data_views).X
280
244
cdef intp_t[::1 ] samples = (< Views> env.data_views).samples
@@ -284,36 +248,15 @@ cdef bint _handle_add_node(
284
248
cdef Interval * interval = NULL
285
249
cdef Interval * parent = NULL
286
250
287
- # with gil:
288
- # print("_handle_add_node checkpoint 3")
289
-
290
251
if data.node_id >= size:
291
- # with gil:
292
- # print("resizing")
293
- # print(f"node_id = {data.node_id}")
294
- # print(f"old tree.size = {env.tree.size()}")
295
252
# as a heuristic, assume a complete tree and add a level
296
253
h = floor(fmax(0 , log2(size)))
297
254
env.tree.resize(size + < intp_t> pow (2 , h + 1 ))
298
255
299
- # with gil:
300
- # print(f"h = {h}")
301
- # print(f"log2(size) = {log2(size)}")
302
- # print(f"new size = {size + <intp_t>pow(2, h + 1)}")
303
- # print(f"new tree.size = {env.tree.size()}")
304
-
305
- # with gil:
306
- # print("_handle_add_node checkpoint 4")
307
- # print(f"node_id = {data.node_id}")
308
- # print(f"tree.size = {env.tree.size()}")
309
-
310
256
interval = & (env.tree[data.node_id])
311
257
interval.feature = data.feature
312
258
interval.split_value = data.split_point
313
259
314
- # with gil:
315
- # print("_handle_add_node checkpoint 5")
316
-
317
260
if data.parent_node_id < 0 :
318
261
# the node being added is the tree root
319
262
interval.start_idx = 0
@@ -328,34 +271,22 @@ cdef bint _handle_add_node(
328
271
interval.start_idx = parent.split_idx
329
272
interval.n = parent.n - (parent.split_idx - parent.start_idx)
330
273
331
- # with gil:
332
- # print("_handle_add_node checkpoint 6")
333
-
334
- # *we* don't need to sort to find the split pos we'll need for partitioning,
335
- # but the partitioner internals are so stateful we had better just do it
336
- # to ensure that it's in the expected state
274
+ # We also reuse Partitioner. *We* don't need to sort to find the split pos we'll
275
+ # need for partitioning, but the partitioner internals are so stateful we had
276
+ # better just do it to ensure that it's in the expected state
337
277
(< Views> env.data_views).partitioner.init_node_split(interval.start_idx, interval.start_idx + interval.n)
338
278
(< Views> env.data_views).partitioner.sort_samples_and_feature_values(interval.feature)
339
279
340
- # with gil:
341
- # print("_handle_add_node checkpoint 7")
342
-
343
280
# count n_left to find split pos
344
281
n_left = 0
345
282
i = interval.start_idx
346
283
feature_value = X[samples[i], interval.feature]
347
284
348
- # with gil:
349
- # print("_handle_add_node checkpoint 8")
350
-
351
285
while (not isnan(feature_value)) and feature_value < interval.split_value and i < interval.start_idx + interval.n:
352
286
n_left += 1
353
287
i += 1
354
288
feature_value = X[samples[i], interval.feature]
355
289
356
- # with gil:
357
- # print("_handle_add_node checkpoint 9")
358
-
359
290
interval.split_idx = interval.start_idx + n_left
360
291
361
292
(< Views> env.data_views).partitioner.partition_samples_final(
@@ -364,26 +295,6 @@ cdef bint _handle_add_node(
364
295
365
296
env.node_count += 1
366
297
367
- # with gil:
368
- # #print("_handle_add_node checkpoint 10")
369
- # print("")
370
- # print(f"parent_node_id = {data.parent_node_id}")
371
- # print(f"node_id = {data.node_id}")
372
- # print(f"is_leaf = {data.is_leaf}")
373
- # print(f"is_left = {data.is_left}")
374
- # print(f"feature = {data.feature}")
375
- # print(f"split_point = {data.split_point}")
376
- # print("---")
377
- # print(f"start_idx = {interval.start_idx}")
378
- # if parent is not NULL:
379
- # print(f"parent.start_idx = {parent.start_idx}")
380
- # print(f"parent.split_idx = {parent.split_idx}")
381
- # print(f"parent.n = {parent.n}")
382
- # print(f"n = {interval.n}")
383
- # print(f"feature = {interval.feature}")
384
- # print(f"split_idx = {interval.split_idx}")
385
- # print(f"split_value = {interval.split_value}")
386
-
387
298
388
299
cdef class AddNodeHandler(EventHandler):
389
300
def __cinit__ (self , Honesty h ):
@@ -404,9 +315,6 @@ cdef bint _trivial_condition(
404
315
float64_t upper_bound,
405
316
SplitConditionEnv split_condition_env
406
317
) noexcept nogil:
407
- # with gil:
408
- # print("TrivialCondition called")
409
-
410
318
return True
411
319
412
320
cdef class TrivialCondition(SplitCondition):
@@ -448,34 +356,16 @@ cdef bint _honest_min_sample_leaf_condition(
448
356
n_left = node.split_idx - node.start_idx
449
357
n_right = end_non_missing - node.split_idx + n_missing
450
358
451
- # with gil:
452
- # print("")
453
- # print("in _honest_min_sample_leaf_condition")
454
- # print(f"min_samples_leaf = {min_samples_leaf}")
455
- # print(f"feature = {node.feature}")
456
- # print(f"start_idx = {node.start_idx}")
457
- # print(f"split_idx = {node.split_idx}")
458
- # print(f"n = {node.n}")
459
- # print(f"n_missing = {n_missing}")
460
- # print(f"end_non_missing = {end_non_missing}")
461
- # print(f"n_left = {n_left}")
462
- # print(f"n_right = {n_right}")
463
- # print(f"split_value = {split_value}")
464
- # if node.split_idx > 0:
465
- # print(f"X.feature_value left = {(<Views>env.honest_env.data_views).X[(<Views>env.honest_env.data_views).samples[node.split_idx - 1], node.feature]}")
466
- # print(f"X.feature_value right = {(<Views>env.honest_env.data_views).X[(<Views>env.honest_env.data_views).samples[node.split_idx], node.feature]}")
467
-
468
359
# Reject if min_samples_leaf is not guaranteed
469
360
if n_left < min_samples_leaf or n_right < min_samples_leaf:
470
361
# with gil:
471
362
# print("returning False")
472
363
return False
473
364
474
- # with gil:
475
- # print("returning True")
476
-
477
365
return True
478
366
367
+ # Check that the honest set will have sufficient samples on each side of this
368
+ # candidate split.
479
369
cdef class HonestMinSamplesLeafCondition(SplitCondition):
480
370
def __cinit__ (self , Honesty h , intp_t min_samples ):
481
371
self ._env.min_samples = min_samples
0 commit comments