Skip to content

Commit 64f20aa

Browse files
authored
MAINT Simplify node split Cython API (scikit-learn#29322)
Signed-off-by: Adam Li <adam2392@gmail.com>
1 parent ce2d74c commit 64f20aa

File tree

1 file changed

+6
-12
lines changed

1 file changed

+6
-12
lines changed

sklearn/tree/_splitter.pyx

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -286,14 +286,15 @@ cdef inline int node_split_best(
286286
Criterion criterion,
287287
SplitRecord* split,
288288
ParentInfo* parent_record,
289-
bint with_monotonic_cst,
290-
const int8_t[:] monotonic_cst,
291289
) except -1 nogil:
292290
"""Find the best split on node samples[start:end]
293291
294292
Returns -1 in case of failure to allocate memory (and raise MemoryError)
295293
or 0 otherwise.
296294
"""
295+
cdef const int8_t[:] monotonic_cst = splitter.monotonic_cst
296+
cdef bint with_monotonic_cst = splitter.with_monotonic_cst
297+
297298
# Find the best split
298299
cdef intp_t start = splitter.start
299300
cdef intp_t end = splitter.end
@@ -667,14 +668,15 @@ cdef inline int node_split_random(
667668
Criterion criterion,
668669
SplitRecord* split,
669670
ParentInfo* parent_record,
670-
bint with_monotonic_cst,
671-
const int8_t[:] monotonic_cst,
672671
) except -1 nogil:
673672
"""Find the best random split on node samples[start:end]
674673
675674
Returns -1 in case of failure to allocate memory (and raise MemoryError)
676675
or 0 otherwise.
677676
"""
677+
cdef const int8_t[:] monotonic_cst = splitter.monotonic_cst
678+
cdef bint with_monotonic_cst = splitter.with_monotonic_cst
679+
678680
# Draw random splits and pick the best
679681
cdef intp_t start = splitter.start
680682
cdef intp_t end = splitter.end
@@ -1512,8 +1514,6 @@ cdef class BestSplitter(Splitter):
15121514
self.criterion,
15131515
split,
15141516
parent_record,
1515-
self.with_monotonic_cst,
1516-
self.monotonic_cst,
15171517
)
15181518

15191519
cdef class BestSparseSplitter(Splitter):
@@ -1542,8 +1542,6 @@ cdef class BestSparseSplitter(Splitter):
15421542
self.criterion,
15431543
split,
15441544
parent_record,
1545-
self.with_monotonic_cst,
1546-
self.monotonic_cst,
15471545
)
15481546

15491547
cdef class RandomSplitter(Splitter):
@@ -1572,8 +1570,6 @@ cdef class RandomSplitter(Splitter):
15721570
self.criterion,
15731571
split,
15741572
parent_record,
1575-
self.with_monotonic_cst,
1576-
self.monotonic_cst,
15771573
)
15781574

15791575
cdef class RandomSparseSplitter(Splitter):
@@ -1601,6 +1597,4 @@ cdef class RandomSparseSplitter(Splitter):
16011597
self.criterion,
16021598
split,
16031599
parent_record,
1604-
self.with_monotonic_cst,
1605-
self.monotonic_cst,
16061600
)

0 commit comments

Comments
 (0)