Skip to content

Commit 6801508

Browse files
PSSF23 and adam2392 authored
FIX correct node splitting order & remove class weight (#56)
<!--
Thanks for contributing a pull request! Please ensure you have taken a look at
the contribution guidelines: https://github.com/scikit-learn/scikit-learn/blob/main/CONTRIBUTING.md
-->

#### Reference Issues/PRs
<!--
Example: Fixes scikit-learn#1234. See also scikit-learn#3456.
Please use keywords (e.g., Fixes) to create link to the issues or pull requests
you resolved, so that they will automatically be closed when your pull request
is merged. See https://github.com/blog/1506-closing-issues-via-pull-requests
-->
neurodata/treeple#107

#### What does this implement/fix? Explain your changes.

#### Any other comments?
<!--
Please be aware that we are a loose team of volunteers so patience is
necessary; assistance handling other issues is very welcome. We value
all user contributions, no matter how minor they are. If we are slow to
review, either the pull request needs some benchmarking, tinkering,
convincing, etc. or more likely the reviewers are simply busy. In either
case, we ask for your understanding during the review process.
For more information, see our FAQ on this topic:
http://scikit-learn.org/dev/faq.html#why-is-my-pull-request-not-getting-any-attention.

Thanks for contributing!
-->

---------

Signed-off-by: Adam Li <adam2392@gmail.com>
Co-authored-by: Adam Li <adam2392@gmail.com>
1 parent 4424f98 commit 6801508

File tree

4 files changed

+22
-19
lines changed

4 files changed

+22
-19
lines changed

sklearn/ensemble/_forest.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1252,7 +1252,11 @@ def partial_fit(self, X, y, sample_weight=None, classes=None):
12521252

12531253
self.n_outputs_ = y.shape[1]
12541254

1255-
y, expanded_class_weight = self._validate_y_class_weight(y)
1255+
classes = self.classes_
1256+
if self.n_outputs_ == 1:
1257+
classes = [classes]
1258+
1259+
y, expanded_class_weight = self._validate_y_class_weight(y, classes)
12561260

12571261
if getattr(y, "dtype", None) != DOUBLE or not y.flags.contiguous:
12581262
y = np.ascontiguousarray(y, dtype=DOUBLE)
@@ -1341,7 +1345,7 @@ def partial_fit(self, X, y, sample_weight=None, classes=None):
13411345
verbose=self.verbose,
13421346
class_weight=self.class_weight,
13431347
n_samples_bootstrap=n_samples_bootstrap,
1344-
classes=classes,
1348+
classes=classes[0],
13451349
)
13461350
for i, t in enumerate(self.estimators_)
13471351
)

sklearn/tree/_splitter.pxd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ cdef class Splitter(BaseSplitter):
126126
# Methods that allow modifications to stopping conditions
127127
cdef bint check_presplit_conditions(
128128
self,
129-
SplitRecord current_split,
129+
SplitRecord* current_split,
130130
SIZE_t n_missing,
131131
bint missing_go_to_left,
132132
) noexcept nogil

sklearn/tree/_splitter.pyx

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -327,9 +327,9 @@ cdef class Splitter(BaseSplitter):
327327

328328
return self.criterion.node_impurity()
329329

330-
cdef bint check_presplit_conditions(
330+
cdef inline bint check_presplit_conditions(
331331
self,
332-
SplitRecord current_split,
332+
SplitRecord* current_split,
333333
SIZE_t n_missing,
334334
bint missing_go_to_left,
335335
) noexcept nogil:
@@ -356,7 +356,7 @@ cdef class Splitter(BaseSplitter):
356356

357357
return 0
358358

359-
cdef bint check_postsplit_conditions(
359+
cdef inline bint check_postsplit_conditions(
360360
self
361361
) noexcept nogil:
362362
"""Check stopping conditions after evaluating the split.
@@ -571,7 +571,7 @@ cdef inline int node_split_best(
571571
else:
572572
n_left = current_split.pos - splitter.start
573573
n_right = end_non_missing - current_split.pos + n_missing
574-
if splitter.check_presplit_conditions(current_split, n_missing, missing_go_to_left) == 1:
574+
if splitter.check_presplit_conditions(&current_split, n_missing, missing_go_to_left) == 1:
575575
continue
576576

577577
criterion.update(current_split.pos)
@@ -914,7 +914,7 @@ cdef inline int node_split_random(
914914
current_split.pos = partitioner.partition_samples(current_split.threshold)
915915

916916
# Reject if min_samples_leaf is not guaranteed
917-
if splitter.check_presplit_conditions(current_split, 0, 0) == 1:
917+
if splitter.check_presplit_conditions(&current_split, 0, 0) == 1:
918918
continue
919919

920920
# Evaluate split

sklearn/tree/_tree.pyx

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -268,16 +268,6 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
268268
# check input
269269
X, y, sample_weight = self._check_input(X, y, sample_weight)
270270

271-
# Initial capacity
272-
cdef int init_capacity
273-
274-
if tree.max_depth <= 10:
275-
init_capacity = <int> (2 ** (tree.max_depth + 1)) - 1
276-
else:
277-
init_capacity = 2047
278-
279-
tree._resize(init_capacity)
280-
281271
# Parameters
282272
cdef Splitter splitter = self.splitter
283273
cdef SIZE_t max_depth = self.max_depth
@@ -286,10 +276,19 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
286276
cdef SIZE_t min_samples_split = self.min_samples_split
287277
cdef double min_impurity_decrease = self.min_impurity_decrease
288278

279+
# Initial capacity
280+
cdef int init_capacity
289281
cdef bint first = 0
290282
if self.initial_roots is None:
291283
# Recursive partition (without actual recursion)
292284
splitter.init(X, y, sample_weight, missing_values_in_feature_mask)
285+
286+
if tree.max_depth <= 10:
287+
init_capacity = <int> (2 ** (tree.max_depth + 1)) - 1
288+
else:
289+
init_capacity = 2047
290+
291+
tree._resize(init_capacity)
293292
first = 1
294293

295294
cdef SIZE_t start = 0
@@ -1148,7 +1147,7 @@ cdef class BaseTree:
11481147

11491148
return node_id
11501149

1151-
cdef SIZE_t _update_node(
1150+
cdef inline SIZE_t _update_node(
11521151
self,
11531152
SIZE_t parent,
11541153
bint is_left,

0 commit comments

Comments
 (0)