Skip to content

Commit 49526f0

Browse files
committed
Fix classes and criterion
Signed-off-by: Adam Li <adam2392@gmail.com>
1 parent 5a2ac9a commit 49526f0

File tree

3 files changed

+141
-5
lines changed

3 files changed

+141
-5
lines changed

sklearn/tree/_classes.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -713,6 +713,73 @@ def feature_importances_(self):
713713

714714
return self.tree_.compute_feature_importances()
715715

716+
def _get_y_for_leaves(
    self, X, sample_weight=None, *, sorter=None, max_samples_leaf=None
):
    """Return the training sample indices that fall in each leaf of each estimator.

    For every estimator in ``self.estimators_`` the training samples are
    grouped by the leaf they land in, and the (1-based) indices of those
    samples are stored in a dense integer array. Index ``0`` is reserved as
    the "empty slot" marker, which is why all stored indices are shifted by
    one.

    NOTE(review): this method reads ``self.estimators_``, ``self.bootstrap``,
    ``self.max_samples`` and ``self.n_estimators`` and indexes the result of
    ``self.apply(X)`` with two axes, i.e. it presumably expects a forest-like
    object -- confirm that the enclosing class provides these.
    NOTE(review): ``_get_n_samples_bootstrap``, ``_generate_sample_indices``
    and ``_group_by_value`` must be importable in this module -- confirm.

    Parameters
    ----------
    X : array-like with a ``shape`` attribute
        Training input samples; only ``X.shape[0]`` is read here and ``X`` is
        forwarded to ``self.apply``.

    sample_weight : array-like of shape (n_samples,), default=None
        If given, training samples whose weight is not strictly positive are
        dropped from the per-leaf index lists.

    sorter : array-like of shape (n_samples,), default=None
        If given, bootstrap indices are remapped through
        ``np.argsort(sorter)`` to account for a prior sorting of the targets.

    max_samples_leaf : int, default=None
        Maximum number of training indices kept per leaf. If ``None``, no
        subsampling is done and the last axis of the result is sized to the
        largest leaf found across estimators.

    Returns
    -------
    y_train_leaves : ndarray of shape \
            (n_estimators, max_node_count, max_samples_leaf), dtype=int64
        1-based training indices per estimator and node; ``0`` means empty.

    Raises
    ------
    ValueError
        If ``max_samples_leaf`` is given and is smaller than 1.
    """
    import random

    n_samples = X.shape[0]

    # Previously `leaf_subsample` / `max_samples_leaf` were read without ever
    # being defined; they are now derived from the keyword argument.
    leaf_subsample = max_samples_leaf is not None
    if leaf_subsample and max_samples_leaf < 1:
        raise ValueError(
            "If max_samples_leaf is given, it must be at least 1, "
            f"got {max_samples_leaf}."
        )

    # Leaf index reached by every sample in every estimator.
    X_leaves = self.apply(X)

    # Previously `shape` was used here before being assigned further down.
    bootstrap_indices = np.empty((n_samples, self.n_estimators), dtype=np.int64)
    for i, estimator in enumerate(self.estimators_):
        # Get bootstrap indices.
        if self.bootstrap:
            n_samples_bootstrap = _get_n_samples_bootstrap(
                n_samples, self.max_samples
            )
            bootstrap_indices[:, i] = _generate_sample_indices(
                estimator.random_state, n_samples, n_samples_bootstrap
            )
        else:
            bootstrap_indices[:, i] = np.arange(n_samples)

        # Get predictions on bootstrap indices.
        X_leaves[:, i] = X_leaves[bootstrap_indices[:, i], i]

    if sorter is not None:
        # Reassign bootstrap indices to account for target sorting.
        bootstrap_indices = np.argsort(sorter)[bootstrap_indices]

    bootstrap_indices += 1  # for sparse matrix (0s as empty)

    # Get the maximum number of nodes (internal + leaves) across trees.
    # Get the maximum number of samples per leaf across trees (if needed).
    max_node_count = 0
    if not leaf_subsample:
        max_samples_leaf = 0
    for i, estimator in enumerate(self.estimators_):
        node_count = estimator.tree_.node_count
        if node_count > max_node_count:
            max_node_count = node_count
        if not leaf_subsample:
            sample_count = np.max(np.bincount(X_leaves[:, i]))
            if sample_count > max_samples_leaf:
                max_samples_leaf = sample_count

    # Initialize NumPy array (more efficient serialization than dict/list).
    shape = (self.n_estimators, max_node_count, max_samples_leaf)
    y_train_leaves = np.zeros(shape, dtype=np.int64)

    for i, estimator in enumerate(self.estimators_):
        # Group training indices by leaf node.
        leaf_indices, leaf_values_list = _group_by_value(X_leaves[:, i])

        # A private generator seeded per estimator yields the same draws as
        # `random.seed(...)` followed by `random.sample(...)`, without
        # clobbering the process-wide RNG state.
        # NOTE(review): assumes `estimator.random_state` is a seed value
        # accepted by `random.Random` (e.g. an int) -- confirm.
        rng = random.Random(estimator.random_state) if leaf_subsample else None

        # Map each leaf node to its list of training indices.
        for leaf_idx, leaf_values in zip(leaf_indices, leaf_values_list):
            y_indices = bootstrap_indices[:, i][leaf_values]

            if sample_weight is not None:
                # Stored indices are 1-based, hence the `- 1`.
                y_indices = y_indices[sample_weight[y_indices - 1] > 0]

            # Subsample leaf training indices (without replacement).
            if leaf_subsample and max_samples_leaf < len(y_indices):
                if not isinstance(y_indices, list):
                    y_indices = list(y_indices)
                y_indices = rng.sample(y_indices, max_samples_leaf)

            y_train_leaves[i, leaf_idx, : len(y_indices)] = y_indices

    return y_train_leaves
782+
716783

717784
# =============================================================================
718785
# Public estimators

sklearn/tree/_criterion.pxd

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111

1212
# See _criterion.pyx for implementation details.
1313

14+
# from libcpp.vector cimport vector
15+
1416
from ._tree cimport DTYPE_t # Type of X
1517
from ._tree cimport DOUBLE_t # Type of y, sample_weight
1618
from ._tree cimport SIZE_t # Type for indices and counters
@@ -19,7 +21,7 @@ from ._tree cimport UINT32_t # Unsigned 32 bit integer
1921

2022

2123
cdef class BaseCriterion:
22-
"""Abstract interface for criterion."""
24+
"""Abstract interface for criterion."""
2325

2426
# Internal structures
2527
cdef const DOUBLE_t[:] sample_weight # Sample weights
@@ -70,13 +72,18 @@ cdef class BaseCriterion:
7072
SIZE_t end
7173
) noexcept nogil
7274

75+
# cdef void node_samples(
76+
# self,
77+
# vector[vector[DOUBLE_t]]* dest
78+
# ) noexcept nogil
79+
7380
cdef class Criterion(BaseCriterion):
7481
"""Abstract interface for supervised impurity criteria."""
7582

7683
cdef const DOUBLE_t[:, ::1] y # Values of y
7784
cdef SIZE_t n_missing # Number of missing values for the feature being evaluated
7885
cdef bint missing_go_to_left # Whether missing values go to the left node
79-
86+
8087
cdef int init(
8188
self,
8289
const DOUBLE_t[:, ::1] y,

0 commit comments

Comments
 (0)