@@ -33,6 +33,8 @@ import numpy as np
33
33
cdef float64_t INFINITY = np.inf
34
34
35
35
36
+ # we refactor the inline min sample leaf split rejection criterion
37
+ # into our injectable SplitCondition pattern
36
38
cdef bint min_sample_leaf_condition(
37
39
Splitter splitter,
38
40
intp_t split_feature,
@@ -66,6 +68,9 @@ cdef class MinSamplesLeafCondition(SplitCondition):
66
68
self .c.f = min_sample_leaf_condition
67
69
self .c.e = NULL # min_samples is stored in splitter, which is already passed to f
68
70
71
+
72
+ # we refactor the inline min weight leaf split rejection criterion
73
+ # into our injectable SplitCondition pattern
69
74
cdef bint min_weight_leaf_condition(
70
75
Splitter splitter,
71
76
intp_t split_feature,
@@ -91,6 +96,9 @@ cdef class MinWeightLeafCondition(SplitCondition):
91
96
self .c.f = min_weight_leaf_condition
92
97
self .c.e = NULL # min_weight_leaf is stored in splitter, which is already passed to f
93
98
99
+
100
+ # we refactor the inline monotonic constraint split rejection criterion
101
+ # into our injectable SplitCondition pattern
94
102
cdef bint monotonic_constraint_condition(
95
103
Splitter splitter,
96
104
intp_t split_feature,
@@ -131,6 +139,7 @@ cdef inline void _init_split(SplitRecord* self, intp_t start_pos) noexcept nogil
131
139
self .missing_go_to_left = False
132
140
self .n_missing = 0
133
141
142
+ # the default SplitRecord factory method simply mallocs a SplitRecord
134
143
cdef SplitRecord* _base_split_record_factory(SplitRecordFactoryEnv env) except NULL nogil:
135
144
return < SplitRecord* > malloc(sizeof(SplitRecord));
136
145
@@ -281,20 +290,6 @@ cdef class Splitter(BaseSplitter):
281
290
self .min_samples_leaf_condition = MinSamplesLeafCondition()
282
291
self .min_weight_leaf_condition = MinWeightLeafCondition()
283
292
284
- # self.presplit_conditions.resize(
285
- # (len(presplit_conditions) if presplit_conditions is not None else 0)
286
- # + (2 if self.with_monotonic_cst else 1)
287
- # )
288
- # self.postsplit_conditions.resize(
289
- # (len(postsplit_conditions) if postsplit_conditions is not None else 0)
290
- # + (2 if self.with_monotonic_cst else 1)
291
- # )
292
-
293
- # cdef int offset = 0
294
- # self.presplit_conditions[offset] = self.min_samples_leaf_condition.c
295
- # self.postsplit_conditions[offset] = self.min_weight_leaf_condition.c
296
- # offset += 1
297
-
298
293
l_pre = [self .min_samples_leaf_condition]
299
294
l_post = [self .min_weight_leaf_condition]
300
295
@@ -306,16 +301,11 @@ cdef class Splitter(BaseSplitter):
306
301
# self.postsplit_conditions[offset] = self.monotonic_constraint_condition.c
307
302
# offset += 1
308
303
309
- # cdef int i
310
304
if presplit_conditions is not None :
311
305
l_pre += presplit_conditions
312
- # for i in range(len(presplit_conditions)):
313
- # self.presplit_conditions[i + offset] = presplit_conditions[i].c
314
306
315
307
if postsplit_conditions is not None :
316
308
l_post += postsplit_conditions
317
- # for i in range(len(postsplit_conditions)):
318
- # self.postsplit_conditions[i + offset] = postsplit_conditions[i].c
319
309
320
310
self .presplit_conditions.resize(0 )
321
311
self .add_presplit_conditions(l_pre)
@@ -595,10 +585,6 @@ cdef inline intp_t node_split_best(
595
585
Returns -1 in case of failure to allocate memory (and raise MemoryError)
596
586
or 0 otherwise.
597
587
"""
598
- # with gil:
599
- # print("")
600
- # print("in node_split_best")
601
-
602
588
cdef const int8_t[:] monotonic_cst = splitter.monotonic_cst
603
589
cdef bint with_monotonic_cst = splitter.with_monotonic_cst
604
590
@@ -648,19 +634,14 @@ cdef inline intp_t node_split_best(
648
634
649
635
cdef bint conditions_hold = True
650
636
637
+ # payloads for different node events
651
638
cdef NodeSortFeatureEventData sort_event_data
652
639
cdef NodeSplitEventData split_event_data
653
640
654
- # with gil:
655
- # print("checkpoint 1")
656
-
657
641
_init_split(& best_split, end)
658
642
659
643
partitioner.init_node_split(start, end)
660
644
661
- # with gil:
662
- # print("checkpoint 2")
663
-
664
645
# Sample up to max_features without replacement using a
665
646
# Fisher-Yates-based algorithm (using the local variables `f_i` and
666
647
# `f_j` to compute a permutation of the `features` array).
@@ -706,6 +687,7 @@ cdef inline intp_t node_split_best(
706
687
current_split.feature = features[f_j]
707
688
partitioner.sort_samples_and_feature_values(current_split.feature)
708
689
690
+ # notify any interested parties which feature we're investigating splits for now
709
691
sort_event_data.feature = current_split.feature
710
692
splitter.event_broker.fire_event(NodeSplitEvent.SORT_FEATURE, & sort_event_data)
711
693
@@ -741,46 +723,28 @@ cdef inline intp_t node_split_best(
741
723
n_searches = 2 if has_missing else 1
742
724
743
725
for i in range (n_searches):
744
- # with gil:
745
- # print(f"search {i}")
746
-
747
726
missing_go_to_left = i == 1
748
727
criterion.missing_go_to_left = missing_go_to_left
749
728
criterion.reset()
750
729
751
730
p = start
752
731
753
732
while p < end_non_missing:
754
- # with gil:
755
- # print("")
756
- # print("_node_split_best checkpoint 1")
757
-
758
733
partitioner.next_p(& p_prev, & p)
759
734
760
- # with gil:
761
- # print("checkpoint 1.1")
762
- # print(f"end_non_missing = {end_non_missing}")
763
- # print(f"p = {<int32_t>p}")
764
-
765
735
if p >= end_non_missing:
766
- # with gil:
767
- # print("continuing")
768
736
continue
769
737
770
- # with gil:
771
- # print("_node_split_best checkpoint 1.2")
772
-
773
738
current_split.pos = p
739
+
774
740
# probably want to assign this to current_split.threshold later,
775
741
# but the code is so stateful that Write Everything Twice is the
776
742
# safer move here for now
777
743
current_threshold = (
778
744
feature_values[p_prev] / 2.0 + feature_values[p] / 2.0
779
745
)
780
746
781
- # with gil:
782
- # print("_node_split_best checkpoint 2")
783
-
747
+ # check pre split rejection criteria
784
748
conditions_hold = True
785
749
for condition in splitter.presplit_conditions:
786
750
if not condition.f(
@@ -791,24 +755,18 @@ cdef inline intp_t node_split_best(
791
755
conditions_hold = False
792
756
break
793
757
794
- # with gil:
795
- # print("_node_split_best checkpoint 3")
796
-
797
758
if not conditions_hold:
798
759
continue
799
760
800
761
# Reject if min_samples_leaf is not guaranteed
762
+ # this can probably (and should) be removed as it is generalized
763
+ # by injectable split rejection criteria
801
764
if splitter.check_presplit_conditions(& current_split, n_missing, missing_go_to_left) == 1 :
802
765
continue
803
766
804
- # with gil:
805
- # print("_node_split_best checkpoint 4")
806
-
807
767
criterion.update(current_split.pos)
808
768
809
- # with gil:
810
- # print("_node_split_best checkpoint 5")
811
-
769
+ # check post split rejection criteria
812
770
conditions_hold = True
813
771
for condition in splitter.postsplit_conditions:
814
772
if not condition.f(
@@ -819,15 +777,9 @@ cdef inline intp_t node_split_best(
819
777
conditions_hold = False
820
778
break
821
779
822
- # with gil:
823
- # print("_node_split_best checkpoint 6")
824
-
825
780
if not conditions_hold:
826
781
continue
827
782
828
- # with gil:
829
- # print("_node_split_best checkpoint 7")
830
-
831
783
current_proxy_improvement = criterion.proxy_impurity_improvement()
832
784
833
785
if current_proxy_improvement > best_proxy_improvement:
@@ -859,15 +811,9 @@ cdef inline intp_t node_split_best(
859
811
860
812
best_split = current_split # copy
861
813
862
- # with gil:
863
- # print("_node_split_best checkpoint 8")
864
-
865
814
# Evaluate when there are missing values and all missing values go
866
815
# to the right node and non-missing values go to the left node.
867
816
if has_missing:
868
- # with gil:
869
- # print("has_missing = {has_missing}")
870
-
871
817
n_left, n_right = end - start - n_missing, n_missing
872
818
p = end - n_missing
873
819
missing_go_to_left = 0
@@ -888,24 +834,16 @@ cdef inline intp_t node_split_best(
888
834
current_split.pos = p
889
835
best_split = current_split
890
836
891
- # with gil:
892
- # print("checkpoint 9")
893
837
894
838
# Reorganize into samples[start:best_split.pos] + samples[best_split.pos:end]
895
839
if best_split.pos < end:
896
- # with gil:
897
- # print("checkpoint 10")
898
-
899
840
partitioner.partition_samples_final(
900
841
best_split.pos,
901
842
best_split.threshold,
902
843
best_split.feature,
903
844
best_split.n_missing
904
845
)
905
846
906
- # with gil:
907
- # print("checkpoint 11")
908
-
909
847
criterion.init_missing(best_split.n_missing)
910
848
criterion.missing_go_to_left = best_split.missing_go_to_left
911
849
@@ -920,37 +858,23 @@ cdef inline intp_t node_split_best(
920
858
best_split.impurity_right
921
859
)
922
860
923
- # with gil:
924
- # print("checkpoint 12")
925
-
926
861
shift_missing_values_to_left_if_required(& best_split, samples, end)
927
862
928
- # with gil:
929
- # print("checkpoint 13")
930
863
931
864
# Respect invariant for constant features: the original order of
932
865
# elements in features[:n_known_constants] must be preserved for sibling
933
866
# and child nodes
934
867
memcpy(& features[0 ], & constant_features[0 ], sizeof(intp_t) * n_known_constants)
935
868
936
- # with gil:
937
- # print("checkpoint 14")
938
-
939
869
# Copy newly found constant features
940
870
memcpy(& constant_features[n_known_constants],
941
871
& features[n_known_constants],
942
872
sizeof(intp_t) * n_found_constants)
943
873
944
- # with gil:
945
- # print("checkpoint 15")
946
-
947
874
# Return values
948
875
parent_record.n_constant_features = n_total_constants
949
876
split[0 ] = best_split
950
877
951
- # with gil:
952
- # print("returning")
953
-
954
878
return 0
955
879
956
880
0 commit comments