Skip to content

Commit ada9947

Browse files
authored
ENH improve init_root of HGBT TreeGrower (scikit-learn#30875)
1 parent e13e280 commit ada9947

File tree

3 files changed

+31 additions, -33 deletions

sklearn/ensemble/_hist_gradient_boosting/grower.py

Lines changed: 30 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,6 @@
1616

1717
from sklearn.utils._openmp_helpers import _openmp_effective_n_threads
1818

19-
from ...utils.arrayfuncs import sum_parallel
2019
from ._bitset import set_raw_bitset_from_binned_bitset
2120
from .common import (
2221
PREDICTOR_RECORD_DTYPE,
@@ -353,7 +352,7 @@ def __init__(
353352
self.total_compute_hist_time = 0.0 # time spent computing histograms
354353
self.total_apply_split_time = 0.0 # time spent splitting nodes
355354
self.n_categorical_splits = 0
356-
self._initialize_root(gradients, hessians)
355+
self._initialize_root()
357356
self.n_nodes = 1
358357

359358
def _validate_parameters(
@@ -401,15 +400,38 @@ def _apply_shrinkage(self):
401400
for leaf in self.finalized_leaves:
402401
leaf.value *= self.shrinkage
403402

404-
def _initialize_root(self, gradients, hessians):
403+
def _initialize_root(self):
405404
"""Initialize root node and finalize it if needed."""
405+
tic = time()
406+
if self.interaction_cst is not None:
407+
allowed_features = set().union(*self.interaction_cst)
408+
allowed_features = np.fromiter(
409+
allowed_features, dtype=np.uint32, count=len(allowed_features)
410+
)
411+
arbitrary_feature = allowed_features[0]
412+
else:
413+
allowed_features = None
414+
arbitrary_feature = 0
415+
416+
# TreeNode init needs the total sum of gradients and hessians. Therefore, we
417+
# first compute the histograms and then compute the total grad/hess on an
418+
# arbitrary feature histogram. This way we replace a loop over n_samples by a
419+
# loop over n_bins.
420+
histograms = self.histogram_builder.compute_histograms_brute(
421+
self.splitter.partition, # =self.root.sample_indices
422+
allowed_features,
423+
)
424+
self.total_compute_hist_time += time() - tic
425+
426+
tic = time()
406427
n_samples = self.X_binned.shape[0]
407428
depth = 0
408-
sum_gradients = sum_parallel(gradients, self.n_threads)
429+
histogram_array = np.asarray(histograms[arbitrary_feature])
430+
sum_gradients = histogram_array["sum_gradients"].sum()
409431
if self.histogram_builder.hessians_are_constant:
410-
sum_hessians = hessians[0] * n_samples
432+
sum_hessians = self.histogram_builder.hessians[0] * n_samples
411433
else:
412-
sum_hessians = sum_parallel(hessians, self.n_threads)
434+
sum_hessians = histogram_array["sum_hessians"].sum()
413435
self.root = TreeNode(
414436
depth=depth,
415437
sample_indices=self.splitter.partition,
@@ -430,18 +452,10 @@ def _initialize_root(self, gradients, hessians):
430452

431453
if self.interaction_cst is not None:
432454
self.root.interaction_cst_indices = range(len(self.interaction_cst))
433-
allowed_features = set().union(*self.interaction_cst)
434-
self.root.allowed_features = np.fromiter(
435-
allowed_features, dtype=np.uint32, count=len(allowed_features)
436-
)
455+
self.root.allowed_features = allowed_features
437456

438-
tic = time()
439-
self.root.histograms = self.histogram_builder.compute_histograms_brute(
440-
self.root.sample_indices, self.root.allowed_features
441-
)
442-
self.total_compute_hist_time += time() - tic
457+
self.root.histograms = histograms
443458

444-
tic = time()
445459
self._compute_best_split_and_push(self.root)
446460
self.total_find_split_time += time() - tic
447461

sklearn/utils/arrayfuncs.pyx

Lines changed: 0 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -1,12 +1,10 @@
11
"""A small collection of auxiliary functions that operate on arrays."""
22

33
from cython cimport floating
4-
from cython.parallel cimport prange
54
from libc.math cimport fabs
65
from libc.float cimport DBL_MAX, FLT_MAX
76

87
from ._cython_blas cimport _copy, _rotg, _rot
9-
from ._typedefs cimport float64_t
108

119

1210
ctypedef fused real_numeric:
@@ -118,17 +116,3 @@ def cholesky_delete(floating[:, :] L, int go_out):
118116
L1 += m
119117

120118
_rot(n - i - 2, L1 + i, m, L1 + i + 1, m, c, s)
121-
122-
123-
def sum_parallel(const floating [:] array, int n_threads):
124-
"""Parallel sum, always using float64 internally."""
125-
cdef:
126-
float64_t out = 0.
127-
int i = 0
128-
129-
for i in prange(
130-
array.shape[0], schedule='static', nogil=True, num_threads=n_threads
131-
):
132-
out += array[i]
133-
134-
return out

sklearn/utils/meson.build

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,7 @@ utils_extension_metadata = {
1818
'sparsefuncs_fast':
1919
{'sources': ['sparsefuncs_fast.pyx']},
2020
'_cython_blas': {'sources': ['_cython_blas.pyx']},
21-
'arrayfuncs': {'sources': ['arrayfuncs.pyx'], 'dependencies': [openmp_dep]},
21+
'arrayfuncs': {'sources': ['arrayfuncs.pyx']},
2222
'murmurhash': {
2323
'sources': ['murmurhash.pyx', 'src' / 'MurmurHash3.cpp'],
2424
},

0 commit comments

Comments
 (0)