Skip to content

Commit 9b07f2a

Browse files
committed
Add leaf storage ability
Signed-off-by: Adam Li <adam2392@gmail.com>
1 parent 2105949 commit 9b07f2a

File tree

7 files changed

+122
-123
lines changed

7 files changed

+122
-123
lines changed

sklearn/tree/_classes.py

Lines changed: 0 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -713,73 +713,6 @@ def feature_importances_(self):
713713

714714
return self.tree_.compute_feature_importances()
715715

716-
def _get_y_for_leaves(self, X, sample_weight=None):
717-
n_samples = X.shape[0]
718-
719-
# get the predictions
720-
X_leaves = self.apply(X)
721-
722-
bootstrap_indices = np.empty(shape, dtype=np.int64)
723-
for i, estimator in enumerate(self.estimators_):
724-
# Get bootstrap indices.
725-
if self.bootstrap:
726-
n_samples_bootstrap = _get_n_samples_bootstrap(n_samples, self.max_samples)
727-
bootstrap_indices[:, i] = _generate_sample_indices(
728-
estimator.random_state, n_samples, n_samples_bootstrap
729-
)
730-
else:
731-
bootstrap_indices[:, i] = np.arange(n_samples)
732-
733-
# Get predictions on bootstrap indices.
734-
X_leaves[:, i] = X_leaves[bootstrap_indices[:, i], i]
735-
736-
if sorter is not None:
737-
# Reassign bootstrap indices to account for target sorting.
738-
bootstrap_indices = np.argsort(sorter)[bootstrap_indices]
739-
740-
bootstrap_indices += 1 # for sparse matrix (0s as empty)
741-
742-
# Get the maximum number of nodes (internal + leaves) across trees.
743-
# Get the maximum number of samples per leaf across trees (if needed).
744-
max_node_count = 0
745-
max_samples_leaf = 0 if not leaf_subsample else max_samples_leaf
746-
for i, estimator in enumerate(self.estimators_):
747-
node_count = estimator.tree_.node_count
748-
if node_count > max_node_count:
749-
max_node_count = node_count
750-
if not leaf_subsample:
751-
sample_count = np.max(np.bincount(X_leaves[:, i]))
752-
if sample_count > max_samples_leaf:
753-
max_samples_leaf = sample_count
754-
755-
# Initialize NumPy array (more efficient serialization than dict/list).
756-
shape = (self.n_estimators, max_node_count, max_samples_leaf)
757-
y_train_leaves = np.zeros(shape, dtype=np.int64)
758-
759-
for i, estimator in enumerate(self.estimators_):
760-
# Group training indices by leaf node.
761-
leaf_indices, leaf_values_list = _group_by_value(X_leaves[:, i])
762-
763-
if leaf_subsample:
764-
random.seed(estimator.random_state)
765-
766-
# Map each leaf node to its list of training indices.
767-
for leaf_idx, leaf_values in zip(leaf_indices, leaf_values_list):
768-
y_indices = bootstrap_indices[:, i][leaf_values]
769-
770-
if sample_weight is not None:
771-
y_indices = y_indices[sample_weight[y_indices - 1] > 0]
772-
773-
# Subsample leaf training indices (without replacement).
774-
if leaf_subsample and max_samples_leaf < len(y_indices):
775-
if not isinstance(y_indices, list):
776-
y_indices = list(y_indices)
777-
y_indices = random.sample(y_indices, max_samples_leaf)
778-
779-
y_train_leaves[i, leaf_idx, : len(y_indices)] = y_indices
780-
781-
return y_train_leaves
782-
783716

784717
# =============================================================================
785718
# Public estimators

sklearn/tree/_criterion.pxd

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,6 @@ cdef class BaseCriterion:
7272
SIZE_t end
7373
) noexcept nogil
7474

75-
# cdef void node_samples(
76-
# self,
77-
# vector[vector[DOUBLE_t]]* dest
78-
# ) noexcept nogil
7975

8076
cdef class Criterion(BaseCriterion):
8177
"""Abstract interface for supervised impurity criteria."""
@@ -94,6 +90,11 @@ cdef class Criterion(BaseCriterion):
9490
cdef void init_sum_missing(self)
9591
cdef void init_missing(self, SIZE_t n_missing) noexcept nogil
9692

93+
cdef void node_samples(
94+
self,
95+
vector[vector[DOUBLE_t]]* dest
96+
) noexcept nogil
97+
9798
cdef class ClassificationCriterion(Criterion):
9899
"""Abstract criterion for classification."""
99100

sklearn/tree/_criterion.pyx

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ cdef class BaseCriterion:
4646
in current node and in children nodes.
4747
4848
This object stores methods on how to calculate how good a split is using
49-
a set API.
49+
a set API.
5050
5151
Samples in the "current" node are stored in `samples[start:end]` which is
5252
partitioned around `pos` (an index in `start:end`) so that:
@@ -186,9 +186,9 @@ cdef class BaseCriterion:
186186
) noexcept nogil:
187187
"""Abstract method which will set sample pointers in the criterion.
188188
189-
The dataset array that we compute criteria on is assumed to consist of 'N'
190-
ordered samples or rows (i.e. sorted). Since we pass this by reference, we
191-
use sample pointers to move the start and end around to consider only a subset of data.
189+
The dataset array that we compute criteria on is assumed to consist of 'N'
190+
ordered samples or rows (i.e. sorted). Since we pass this by reference, we
191+
use sample pointers to move the start and end around to consider only a subset of data.
192192
This function should also update relevant statistics that the class uses to compute the final criterion.
193193
194194
Parameters
@@ -252,10 +252,28 @@ cdef class Criterion(BaseCriterion):
252252
Number of missing values for specific feature.
253253
"""
254254
pass
255-
255+
256256
cdef void init_sum_missing(self):
257257
"""Init sum_missing to hold sums for missing values."""
258258

259+
cdef void node_samples(
260+
self,
261+
vector[vector[DOUBLE_t]]* dest
262+
) noexcept nogil:
263+
cdef SIZE_t i, j
264+
265+
# Resize the destination vector of vectors
266+
dest.resize(self.n_node_samples)
267+
268+
# Loop over the samples
269+
for i in range(self.n_node_samples):
270+
# Get the index of the current sample
271+
j = self.sample_indices[self.start + i]
272+
273+
# Get the sample values for each output
274+
for k in range(self.n_outputs):
275+
dest[i][k].push_back(self.y[j, k])
276+
259277
cdef inline void _move_sums_classification(
260278
ClassificationCriterion criterion,
261279
double[:, ::1] sum_1,

sklearn/tree/_splitter.pxd

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
# License: BSD 3 clause
1111

1212
# See _splitter.pyx for details.
13+
from libcpp.vector cimport vector
1314

1415
from ._criterion cimport BaseCriterion, Criterion
1516

@@ -106,6 +107,8 @@ cdef class Splitter(BaseSplitter):
106107
const unsigned char[::1] feature_has_missing,
107108
) except -1
108109

110+
cdef void node_samples(self, vector[vector[DOUBLE_t]]* dest) noexcept nogil
111+
109112
# Methods that allow modifications to stopping conditions
110113
cdef bint check_presplit_conditions(
111114
self,

sklearn/tree/_splitter.pyx

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,12 @@ cdef inline void _init_split(SplitRecord* self, SIZE_t start_pos) noexcept nogil
5353
self.n_missing = 0
5454

5555
cdef class BaseSplitter:
56-
"""This is an abstract interface for splitters.
56+
"""This is an abstract interface for splitters.
5757
5858
For example, a tree model could be either supervisedly, or unsupervisedly computing splits on samples of
5959
covariates, labels, or both. Although scikit-learn currently only contains
6060
supervised tree methods, this class enables 3rd party packages to leverage
61-
scikit-learn's Cython code for splitting.
61+
scikit-learn's Cython code for splitting.
6262
6363
A splitter is usually used in conjunction with a criterion class, which explicitly handles
6464
computing the criteria, which we split on. The setting of that criterion class is handled
@@ -112,7 +112,7 @@ cdef class BaseSplitter:
112112

113113
cdef int pointer_size(self) noexcept nogil:
114114
"""Size of the pointer for split records.
115-
115+
116116
Overriding this function allows one to use different subclasses of
117117
`SplitRecord`.
118118
"""
@@ -156,7 +156,6 @@ cdef class Splitter(BaseSplitter):
156156
self.min_weight_leaf = min_weight_leaf
157157
self.random_state = random_state
158158

159-
160159
def __reduce__(self):
161160
return (type(self), (self.criterion,
162161
self.max_features,
@@ -281,6 +280,10 @@ cdef class Splitter(BaseSplitter):
281280

282281
self.criterion.node_value(dest)
283282

283+
cdef void node_samples(self, vector[vector[DOUBLE_t]]* dest) noexcept nogil:
284+
"""Copy the samples[start:end] into dest."""
285+
self.criterion.node_samples(dest)
286+
284287
cdef double node_impurity(self) noexcept nogil:
285288
"""Return the impurity of the current node."""
286289

@@ -293,15 +296,15 @@ cdef class Splitter(BaseSplitter):
293296
bint missing_go_to_left,
294297
) noexcept nogil:
295298
"""Check stopping conditions pre-split.
296-
299+
297300
This is typically a metric that is cheaply computed given the
298301
current proposed split, which is stored as a the `current_split`
299302
argument.
300303
"""
301304
cdef SIZE_t min_samples_leaf = self.min_samples_leaf
302305
cdef SIZE_t end_non_missing = self.end - n_missing
303306
cdef SIZE_t n_left, n_right
304-
307+
305308
if missing_go_to_left:
306309
n_left = current_split.pos - self.start + n_missing
307310
n_right = end_non_missing - current_split.pos
@@ -312,14 +315,14 @@ cdef class Splitter(BaseSplitter):
312315
# Reject if min_samples_leaf is not guaranteed
313316
if n_left < min_samples_leaf or n_right < min_samples_leaf:
314317
return 1
315-
318+
316319
return 0
317320

318321
cdef bint check_postsplit_conditions(
319322
self
320323
) noexcept nogil:
321324
"""Check stopping conditions after evaluating the split.
322-
325+
323326
This takes some metric that is stored in the Criterion
324327
object and checks against internal stop metrics.
325328
"""
@@ -329,10 +332,10 @@ cdef class Splitter(BaseSplitter):
329332
if ((self.criterion.weighted_n_left < min_weight_leaf) or
330333
(self.criterion.weighted_n_right < min_weight_leaf)):
331334
return 1
332-
335+
333336
return 0
334337

335-
338+
336339
cdef inline void shift_missing_values_to_left_if_required(
337340
SplitRecord* best,
338341
SIZE_t[::1] samples,
@@ -360,7 +363,7 @@ cdef inline void shift_missing_values_to_left_if_required(
360363
ctypedef fused Partitioner:
361364
DensePartitioner
362365
SparsePartitioner
363-
366+
364367
cdef inline int node_split_best(
365368
Splitter splitter,
366369
Partitioner partitioner,
@@ -504,9 +507,9 @@ cdef inline int node_split_best(
504507

505508
if p >= end_non_missing:
506509
continue
507-
510+
508511
current_split.pos = p
509-
512+
510513
# Reject if min_samples_leaf is not guaranteed
511514
if splitter.check_presplit_conditions(current_split, n_missing, missing_go_to_left) == 1:
512515
continue
@@ -740,8 +743,6 @@ cdef inline int node_split_random(
740743
cdef SIZE_t n_features = splitter.n_features
741744

742745
cdef SIZE_t max_features = splitter.max_features
743-
cdef SIZE_t min_samples_leaf = splitter.min_samples_leaf
744-
cdef double min_weight_leaf = splitter.min_weight_leaf
745746
cdef UINT32_t* random_state = &splitter.rand_r_state
746747

747748
cdef SplitRecord best_split, current_split

sklearn/tree/_tree.pxd

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import numpy as np
1414
cimport numpy as cnp
1515

1616
from libcpp.vector cimport vector
17+
from libcpp.unordered_map cimport unordered_map
1718

1819
ctypedef cnp.npy_float32 DTYPE_t # Type of X
1920
ctypedef cnp.npy_float64 DOUBLE_t # Type of y, sample_weight
@@ -36,6 +37,7 @@ cdef struct Node:
3637
DOUBLE_t weighted_n_node_samples # Weighted number of samples at the node
3738
unsigned char missing_go_to_left # Whether features have missing values
3839

40+
3941
cdef class BaseTree:
4042
# Inner structures: values are stored separately from node structure,
4143
# since size is determined at runtime.
@@ -45,7 +47,14 @@ cdef class BaseTree:
4547
cdef Node* nodes # Array of nodes
4648

4749
cdef SIZE_t value_stride # The dimensionality of a vectorized output per sample
48-
cdef double* value # Array of values prediction values for each node
50+
cdef double* value # Array of values prediction values for each node
51+
52+
# Enables the use of tree to store distributions of the output to allow
53+
# arbitrary usage of the the leaves. This is used in the quantile
54+
# estimators for example.
55+
# for storing samples at each leaf node with leaf's node ID as the key and
56+
# the sample values as the value
57+
cdef unordered_map[SIZE_t, vector[vector[DOUBLE_t]]] value_samples
4958

5059
# Generic Methods: These are generic methods used by any tree.
5160
cdef int _resize(self, SIZE_t capacity) except -1 nogil
@@ -61,7 +70,7 @@ cdef class BaseTree:
6170
double weighted_n_node_samples,
6271
unsigned char missing_go_to_left
6372
) except -1 nogil
64-
73+
6574
# Python API methods: These are methods exposed to Python
6675
cpdef cnp.ndarray apply(self, object X)
6776
cdef cnp.ndarray _apply_dense(self, object X)
@@ -101,10 +110,10 @@ cdef class Tree(BaseTree):
101110
# The Supervised Tree object is a binary tree structure constructed by the
102111
# TreeBuilder. The tree structure is used for predictions and
103112
# feature importances.
104-
#
113+
#
105114
# Value of upstream properties:
106115
# - value_stride = n_outputs * max_n_classes
107-
# - value = (capacity, n_outputs, max_n_classes) array of values
116+
# - value = (capacity, n_outputs, max_n_classes) array of values
108117

109118
# Input/Output layout for supervised tree
110119
cdef public SIZE_t n_features # Number of features in X
@@ -137,6 +146,8 @@ cdef class TreeBuilder:
137146
cdef SIZE_t max_depth # Maximal tree depth
138147
cdef double min_impurity_decrease # Impurity threshold for early stopping
139148

149+
cdef unsigned char store_leaf_values # Whether to store leaf values
150+
140151
cpdef build(
141152
self,
142153
Tree tree,

0 commit comments

Comments
 (0)