Skip to content

Commit f655401

Browse files
commented changes to splitter
1 parent 5291fb1 commit f655401

File tree

3 files changed

+34
-107
lines changed

3 files changed

+34
-107
lines changed

sklearn/tree/_events.pyx

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,6 @@ cdef class EventBroker:
5050
cdef bint fire_event(self, EventType event_type, EventData event_data) noexcept nogil:
5151
cdef bint result = True
5252

53-
#with gil:
54-
# print(f"firing event {event_type}")
55-
# print(f"listeners.size = {self.listeners.size()}")
56-
5753
if event_type < self.listeners.size():
5854
for l in self.listeners[event_type]:
5955
result = result and l.f(event_type, l.e, event_data)

sklearn/tree/_splitter.pxd

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -34,17 +34,14 @@ cdef struct NodeSplitEventData:
3434
intp_t feature
3535
float64_t threshold
3636

37-
# NICE IDEAS THAT DON'T APPEAR POSSIBLE
38-
# - accessing elements of a memory view of cython extension types in a nogil block/function
39-
# - storing cython extension types in cpp vectors
40-
#
41-
# despite the fact that we can access scalar extension type properties in such a context,
42-
# as for instance node_split_best does with Criterion and Partition,
43-
# and we can access the elements of a memory view of primitive types in such a context
44-
#
45-
# SO WHERE DOES THAT LEAVE US
46-
# - we can transform these into cpp vectors of structs
47-
# and with some minor casting irritations everything else works ok
37+
# We wish to generalize Splitter so that arbitrary split rejection criteria can be
38+
# passed in dynamically at construction. The natural way to want to do this is to
39+
# pass in a list of lambdas, but as we are in cython, this is not so straightforward.
40+
# We want the convenience of being able to pass them in as a python list, and while it
41+
# would be nice to receive them as a memoryview, this is quite a nuisance with
42+
# cython extension types, so we do cpp vector instead. We do the same closure struct
43+
# pattern for execution speed, but they need to be wrapped in cython extension types
44+
# both for convenience and so that they can be stored in a python list.
4845
ctypedef void* SplitConditionEnv
4946
ctypedef bint (*SplitConditionFunction)(
5047
Splitter splitter,
@@ -79,6 +76,12 @@ cdef struct SplitRecord:
7976
unsigned char missing_go_to_left # Controls if missing values go to the left node.
8077
intp_t n_missing # Number of missing values for the feature being split on
8178

79+
80+
# In the neurodata fork of sklearn there was a hack added where SplitRecords are
81+
# created which queries splitter for pointer size and does an inline malloc. This
82+
# is to accommodate the ability to create extended SplitRecord types in Splitter
83+
# subclasses. We refactor that into a factory method again implemented as a closure
84+
# struct.
8285
ctypedef void* SplitRecordFactoryEnv
8386
ctypedef SplitRecord* (*SplitRecordFactory)(SplitRecordFactoryEnv env) except NULL nogil
8487

@@ -168,9 +171,13 @@ cdef class Splitter(BaseSplitter):
168171
cdef SplitCondition min_weight_leaf_condition
169172
cdef SplitCondition monotonic_constraint_condition
170173

174+
# split rejection criteria checked before split selection
171175
cdef vector[SplitConditionClosure] presplit_conditions
176+
177+
# split rejection criteria checked after split selection
172178
cdef vector[SplitConditionClosure] postsplit_conditions
173179

180+
# event broker for handling splitter events
174181
cdef EventBroker event_broker
175182

176183
cdef int init(

sklearn/tree/_splitter.pyx

Lines changed: 16 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ import numpy as np
3333
cdef float64_t INFINITY = np.inf
3434

3535

36+
# we refactor the inline min sample leaf split rejection criterion
37+
# into our injectable SplitCondition pattern
3638
cdef bint min_sample_leaf_condition(
3739
Splitter splitter,
3840
intp_t split_feature,
@@ -66,6 +68,9 @@ cdef class MinSamplesLeafCondition(SplitCondition):
6668
self.c.f = min_sample_leaf_condition
6769
self.c.e = NULL # min_samples is stored in splitter, which is already passed to f
6870

71+
72+
# we refactor the inline min weight leaf split rejection criterion
73+
# into our injectable SplitCondition pattern
6974
cdef bint min_weight_leaf_condition(
7075
Splitter splitter,
7176
intp_t split_feature,
@@ -91,6 +96,9 @@ cdef class MinWeightLeafCondition(SplitCondition):
9196
self.c.f = min_weight_leaf_condition
9297
self.c.e = NULL # min_weight_leaf is stored in splitter, which is already passed to f
9398

99+
100+
# we refactor the inline monotonic constraint split rejection criterion
101+
# into our injectable SplitCondition pattern
94102
cdef bint monotonic_constraint_condition(
95103
Splitter splitter,
96104
intp_t split_feature,
@@ -131,6 +139,7 @@ cdef inline void _init_split(SplitRecord* self, intp_t start_pos) noexcept nogil
131139
self.missing_go_to_left = False
132140
self.n_missing = 0
133141

142+
# the default SplitRecord factory method simply mallocs a SplitRecord
134143
cdef SplitRecord* _base_split_record_factory(SplitRecordFactoryEnv env) except NULL nogil:
135144
return <SplitRecord*>malloc(sizeof(SplitRecord));
136145

@@ -281,20 +290,6 @@ cdef class Splitter(BaseSplitter):
281290
self.min_samples_leaf_condition = MinSamplesLeafCondition()
282291
self.min_weight_leaf_condition = MinWeightLeafCondition()
283292

284-
#self.presplit_conditions.resize(
285-
# (len(presplit_conditions) if presplit_conditions is not None else 0)
286-
# + (2 if self.with_monotonic_cst else 1)
287-
#)
288-
#self.postsplit_conditions.resize(
289-
# (len(postsplit_conditions) if postsplit_conditions is not None else 0)
290-
# + (2 if self.with_monotonic_cst else 1)
291-
#)
292-
293-
#cdef int offset = 0
294-
#self.presplit_conditions[offset] = self.min_samples_leaf_condition.c
295-
#self.postsplit_conditions[offset] = self.min_weight_leaf_condition.c
296-
#offset += 1
297-
298293
l_pre = [self.min_samples_leaf_condition]
299294
l_post = [self.min_weight_leaf_condition]
300295

@@ -306,16 +301,11 @@ cdef class Splitter(BaseSplitter):
306301
#self.postsplit_conditions[offset] = self.monotonic_constraint_condition.c
307302
#offset += 1
308303

309-
#cdef int i
310304
if presplit_conditions is not None:
311305
l_pre += presplit_conditions
312-
#for i in range(len(presplit_conditions)):
313-
# self.presplit_conditions[i + offset] = presplit_conditions[i].c
314306

315307
if postsplit_conditions is not None:
316308
l_post += postsplit_conditions
317-
#for i in range(len(postsplit_conditions)):
318-
# self.postsplit_conditions[i + offset] = postsplit_conditions[i].c
319309

320310
self.presplit_conditions.resize(0)
321311
self.add_presplit_conditions(l_pre)
@@ -595,10 +585,6 @@ cdef inline intp_t node_split_best(
595585
Returns -1 in case of failure to allocate memory (and raise MemoryError)
596586
or 0 otherwise.
597587
"""
598-
#with gil:
599-
# print("")
600-
# print("in node_split_best")
601-
602588
cdef const int8_t[:] monotonic_cst = splitter.monotonic_cst
603589
cdef bint with_monotonic_cst = splitter.with_monotonic_cst
604590

@@ -648,19 +634,14 @@ cdef inline intp_t node_split_best(
648634

649635
cdef bint conditions_hold = True
650636

637+
# payloads for different node events
651638
cdef NodeSortFeatureEventData sort_event_data
652639
cdef NodeSplitEventData split_event_data
653640

654-
#with gil:
655-
# print("checkpoint 1")
656-
657641
_init_split(&best_split, end)
658642

659643
partitioner.init_node_split(start, end)
660644

661-
#with gil:
662-
# print("checkpoint 2")
663-
664645
# Sample up to max_features without replacement using a
665646
# Fisher-Yates-based algorithm (using the local variables `f_i` and
666647
# `f_j` to compute a permutation of the `features` array).
@@ -706,6 +687,7 @@ cdef inline intp_t node_split_best(
706687
current_split.feature = features[f_j]
707688
partitioner.sort_samples_and_feature_values(current_split.feature)
708689

690+
# notify any interested parties which feature we're investigating splits for now
709691
sort_event_data.feature = current_split.feature
710692
splitter.event_broker.fire_event(NodeSplitEvent.SORT_FEATURE, &sort_event_data)
711693

@@ -741,46 +723,28 @@ cdef inline intp_t node_split_best(
741723
n_searches = 2 if has_missing else 1
742724

743725
for i in range(n_searches):
744-
#with gil:
745-
# print(f"search {i}")
746-
747726
missing_go_to_left = i == 1
748727
criterion.missing_go_to_left = missing_go_to_left
749728
criterion.reset()
750729

751730
p = start
752731

753732
while p < end_non_missing:
754-
#with gil:
755-
# print("")
756-
# print("_node_split_best checkpoint 1")
757-
758733
partitioner.next_p(&p_prev, &p)
759734

760-
#with gil:
761-
# print("checkpoint 1.1")
762-
# print(f"end_non_missing = {end_non_missing}")
763-
# print(f"p = {<int32_t>p}")
764-
765735
if p >= end_non_missing:
766-
#with gil:
767-
# print("continuing")
768736
continue
769737

770-
#with gil:
771-
# print("_node_split_best checkpoint 1.2")
772-
773738
current_split.pos = p
739+
774740
# probably want to assign this to current_split.threshold later,
775741
# but the code is so stateful that Write Everything Twice is the
776742
# safer move here for now
777743
current_threshold = (
778744
feature_values[p_prev] / 2.0 + feature_values[p] / 2.0
779745
)
780746

781-
#with gil:
782-
# print("_node_split_best checkpoint 2")
783-
747+
# check pre split rejection criteria
784748
conditions_hold = True
785749
for condition in splitter.presplit_conditions:
786750
if not condition.f(
@@ -791,24 +755,18 @@ cdef inline intp_t node_split_best(
791755
conditions_hold = False
792756
break
793757

794-
#with gil:
795-
# print("_node_split_best checkpoint 3")
796-
797758
if not conditions_hold:
798759
continue
799760

800761
# Reject if min_samples_leaf is not guaranteed
762+
# this can probably (and should) be removed as it is generalized
763+
# by injectable split rejection criteria
801764
if splitter.check_presplit_conditions(&current_split, n_missing, missing_go_to_left) == 1:
802765
continue
803766

804-
#with gil:
805-
# print("_node_split_best checkpoint 4")
806-
807767
criterion.update(current_split.pos)
808768

809-
#with gil:
810-
# print("_node_split_best checkpoint 5")
811-
769+
# check post split rejection criteria
812770
conditions_hold = True
813771
for condition in splitter.postsplit_conditions:
814772
if not condition.f(
@@ -819,15 +777,9 @@ cdef inline intp_t node_split_best(
819777
conditions_hold = False
820778
break
821779

822-
#with gil:
823-
# print("_node_split_best checkpoint 6")
824-
825780
if not conditions_hold:
826781
continue
827782

828-
#with gil:
829-
# print("_node_split_best checkpoint 7")
830-
831783
current_proxy_improvement = criterion.proxy_impurity_improvement()
832784

833785
if current_proxy_improvement > best_proxy_improvement:
@@ -859,15 +811,9 @@ cdef inline intp_t node_split_best(
859811

860812
best_split = current_split # copy
861813

862-
#with gil:
863-
# print("_node_split_best checkpoint 8")
864-
865814
# Evaluate when there are missing values and all missing values go
866815
# to the right node and non-missing values go to the left node.
867816
if has_missing:
868-
#with gil:
869-
# print("has_missing = {has_missing}")
870-
871817
n_left, n_right = end - start - n_missing, n_missing
872818
p = end - n_missing
873819
missing_go_to_left = 0
@@ -888,24 +834,16 @@ cdef inline intp_t node_split_best(
888834
current_split.pos = p
889835
best_split = current_split
890836

891-
#with gil:
892-
# print("checkpoint 9")
893837

894838
# Reorganize into samples[start:best_split.pos] + samples[best_split.pos:end]
895839
if best_split.pos < end:
896-
#with gil:
897-
# print("checkpoint 10")
898-
899840
partitioner.partition_samples_final(
900841
best_split.pos,
901842
best_split.threshold,
902843
best_split.feature,
903844
best_split.n_missing
904845
)
905846

906-
#with gil:
907-
# print("checkpoint 11")
908-
909847
criterion.init_missing(best_split.n_missing)
910848
criterion.missing_go_to_left = best_split.missing_go_to_left
911849

@@ -920,37 +858,23 @@ cdef inline intp_t node_split_best(
920858
best_split.impurity_right
921859
)
922860

923-
#with gil:
924-
# print("checkpoint 12")
925-
926861
shift_missing_values_to_left_if_required(&best_split, samples, end)
927862

928-
#with gil:
929-
# print("checkpoint 13")
930863

931864
# Respect invariant for constant features: the original order of
932865
# elements in features[:n_known_constants] must be preserved for sibling
933866
# and child nodes
934867
memcpy(&features[0], &constant_features[0], sizeof(intp_t) * n_known_constants)
935868

936-
#with gil:
937-
# print("checkpoint 14")
938-
939869
# Copy newly found constant features
940870
memcpy(&constant_features[n_known_constants],
941871
&features[n_known_constants],
942872
sizeof(intp_t) * n_found_constants)
943873

944-
#with gil:
945-
# print("checkpoint 15")
946-
947874
# Return values
948875
parent_record.n_constant_features = n_total_constants
949876
split[0] = best_split
950877

951-
#with gil:
952-
# print("returning")
953-
954878
return 0
955879

956880

0 commit comments

Comments
 (0)