Skip to content

Commit ccc8a4d

Browse files
PSSF23 and adam2392 committed
WIP correct CI errors relating to partial_fit (#58)
<!-- Thanks for contributing a pull request! Please ensure you have taken a look at the contribution guidelines: https://github.com/scikit-learn/scikit-learn/blob/main/CONTRIBUTING.md -->

#### Reference Issues/PRs

<!-- Example: Fixes scikit-learn#1234. See also scikit-learn#3456. Please use keywords (e.g., Fixes) to create link to the issues or pull requests you resolved, so that they will automatically be closed when your pull request is merged. See https://github.com/blog/1506-closing-issues-via-pull-requests -->

Fixes #57

#### What does this implement/fix? Explain your changes.

#### Any other comments?

<!-- Please be aware that we are a loose team of volunteers so patience is necessary; assistance handling other issues is very welcome. We value all user contributions, no matter how minor they are. If we are slow to review, either the pull request needs some benchmarking, tinkering, convincing, etc. or more likely the reviewers are simply busy. In either case, we ask for your understanding during the review process. For more information, see our FAQ on this topic: http://scikit-learn.org/dev/faq.html#why-is-my-pull-request-not-getting-any-attention. Thanks for contributing! -->

---------

Co-authored-by: Adam Li <adam2392@gmail.com>
1 parent d27de49 commit ccc8a4d

File tree

2 files changed

+32
-20
lines changed

2 files changed

+32
-20
lines changed

sklearn/tree/_tree.pxd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ cdef class TreeBuilder:
163163
cdef double min_weight_leaf # Minimum weight in a leaf
164164
cdef SIZE_t max_depth # Maximal tree depth
165165
cdef double min_impurity_decrease # Impurity threshold for early stopping
166-
cdef object initial_roots # Leaf nodes for streaming updates
166+
cdef cnp.ndarray initial_roots # Leaf nodes for streaming updates
167167

168168
cdef unsigned char store_leaf_values # Whether to store leaf values
169169

sklearn/tree/_tree.pyx

Lines changed: 31 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -220,19 +220,24 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
220220
X_copy = {}
221221
y_copy = {}
222222
for i in range(X.shape[0]):
223+
# collect depths from the node paths
223224
depth_i = paths[i].indices.shape[0] - 1
224225
PARENT = depth_i - 1
225226
CHILD = depth_i
226227

228+
# find the IDs of each sample's leaf node and of that leaf's parent node
227229
if PARENT < 0:
228230
parent_i = 0
229231
else:
230232
parent_i = paths[i].indices[PARENT]
231233
child_i = paths[i].indices[CHILD]
232234
left = 0
233235
if tree.children_left[parent_i] == child_i:
234-
left = 1
236+
left = 1 # leaf node is left child
235237

238+
# organize samples by the leaf they fall into (false root)
239+
# leaf nodes are marked by parent node and
240+
# their relative position (left or right child)
236241
if (parent_i, left) in false_roots:
237242
false_roots[(parent_i, left)][0] += 1
238243
X_copy[(parent_i, left)].append(X[i])
@@ -244,16 +249,20 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
244249

245250
X_list = []
246251
y_list = []
252+
253+
# reorder the samples according to parent node IDs
247254
for key, value in reversed(sorted(X_copy.items())):
248255
X_list = X_list + value
249256
y_list = y_list + y_copy[key]
250257
cdef object X_new = np.array(X_list)
251258
cdef cnp.ndarray y_new = np.array(y_list)
252259

260+
# initialize the splitter using sorted samples
253261
cdef Splitter splitter = self.splitter
254262
splitter.init(X_new, y_new, sample_weight, missing_values_in_feature_mask)
255263

256-
self.initial_roots = false_roots
264+
# convert dict to numpy array and store value
265+
self.initial_roots = np.array(list(false_roots.items()))
257266

258267
cpdef build(
259268
self,
@@ -275,11 +284,13 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
275284
cdef double min_weight_leaf = self.min_weight_leaf
276285
cdef SIZE_t min_samples_split = self.min_samples_split
277286
cdef double min_impurity_decrease = self.min_impurity_decrease
287+
cdef unsigned char store_leaf_values = self.store_leaf_values
288+
cdef cnp.ndarray initial_roots = self.initial_roots
278289

279290
# Initial capacity
280291
cdef int init_capacity
281292
cdef bint first = 0
282-
if self.initial_roots is None:
293+
if initial_roots is None:
283294
# Recursive partition (without actual recursion)
284295
splitter.init(X, y, sample_weight, missing_values_in_feature_mask)
285296

@@ -290,6 +301,14 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
290301

291302
tree._resize(init_capacity)
292303
first = 1
304+
else:
305+
# convert numpy array back to dict
306+
false_roots = {}
307+
for key_value_pair in initial_roots:
308+
false_roots[tuple(key_value_pair[0])] = key_value_pair[1]
309+
310+
# reset the root array
311+
self.initial_roots = None
293312

294313
cdef SIZE_t start = 0
295314
cdef SIZE_t end = 0
@@ -318,7 +337,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
318337

319338
if not first:
320339
# push reached leaf nodes onto stack
321-
for key, value in reversed(sorted(self.initial_roots.items())):
340+
for key, value in reversed(sorted(false_roots.items())):
322341
end += value[0]
323342
update_stack.push({
324343
"start": start,
@@ -332,9 +351,6 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
332351
"upper_bound": INFINITY,
333352
})
334353
start += value[0]
335-
if rc == -1:
336-
# got return code -1 - out-of-memory
337-
raise MemoryError()
338354
else:
339355
# push root node onto stack
340356
builder_stack.push({
@@ -348,9 +364,6 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
348364
"lower_bound": -INFINITY,
349365
"upper_bound": INFINITY,
350366
})
351-
if rc == -1:
352-
# got return code -1 - out-of-memory
353-
raise MemoryError()
354367

355368
with nogil:
356369
while not update_stack.empty():
@@ -398,10 +411,10 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
398411
(split.improvement + EPSILON <
399412
min_impurity_decrease))
400413

401-
node_id = tree._update_node(parent, is_left, is_leaf,
402-
split_ptr, impurity, n_node_samples,
403-
weighted_n_node_samples,
404-
split.missing_go_to_left)
414+
node_id = tree._update_node(parent, is_left, is_leaf,
415+
split_ptr, impurity, n_node_samples,
416+
weighted_n_node_samples,
417+
split.missing_go_to_left)
405418

406419
if node_id == INTPTR_MAX:
407420
rc = -1
@@ -471,7 +484,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
471484
"lower_bound": left_child_min,
472485
"upper_bound": left_child_max,
473486
})
474-
elif self.store_leaf_values and is_leaf:
487+
elif store_leaf_values and is_leaf:
475488
# copy leaf values to leaf_values array
476489
splitter.node_samples(tree.value_samples[node_id])
477490

@@ -599,7 +612,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
599612
"lower_bound": left_child_min,
600613
"upper_bound": left_child_max,
601614
})
602-
elif self.store_leaf_values and is_leaf:
615+
elif store_leaf_values and is_leaf:
603616
# copy leaf values to leaf_values array
604617
splitter.node_samples(tree.value_samples[node_id])
605618

@@ -618,8 +631,6 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
618631
if rc == -1:
619632
raise MemoryError()
620633

621-
self.initial_roots = None
622-
623634
# Best first builder ----------------------------------------------------------
624635
cdef struct FrontierRecord:
625636
# Record of information of a Node, the frontier for a split. Those records are
@@ -712,6 +723,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder):
712723
# Parameters
713724
cdef Splitter splitter = self.splitter
714725
cdef SIZE_t max_leaf_nodes = self.max_leaf_nodes
726+
cdef unsigned char store_leaf_values = self.store_leaf_values
715727

716728
# Recursive partition (without actual recursion)
717729
splitter.init(X, y, sample_weight, missing_values_in_feature_mask)
@@ -770,7 +782,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder):
770782
node.feature = _TREE_UNDEFINED
771783
node.threshold = _TREE_UNDEFINED
772784

773-
if self.store_leaf_values:
785+
if store_leaf_values:
774786
# copy leaf values to leaf_values array
775787
splitter.node_samples(tree.value_samples[record.node_id])
776788
else:

0 commit comments

Comments
 (0)