16
16
17
17
from sklearn .utils ._openmp_helpers import _openmp_effective_n_threads
18
18
19
- from ...utils .arrayfuncs import sum_parallel
20
19
from ._bitset import set_raw_bitset_from_binned_bitset
21
20
from .common import (
22
21
PREDICTOR_RECORD_DTYPE ,
@@ -353,7 +352,7 @@ def __init__(
353
352
self .total_compute_hist_time = 0.0 # time spent computing histograms
354
353
self .total_apply_split_time = 0.0 # time spent splitting nodes
355
354
self .n_categorical_splits = 0
356
- self ._initialize_root (gradients , hessians )
355
+ self ._initialize_root ()
357
356
self .n_nodes = 1
358
357
359
358
def _validate_parameters (
@@ -401,15 +400,38 @@ def _apply_shrinkage(self):
401
400
for leaf in self .finalized_leaves :
402
401
leaf .value *= self .shrinkage
403
402
404
- def _initialize_root (self , gradients , hessians ):
403
+ def _initialize_root (self ):
405
404
"""Initialize root node and finalize it if needed."""
405
+ tic = time ()
406
+ if self .interaction_cst is not None :
407
+ allowed_features = set ().union (* self .interaction_cst )
408
+ allowed_features = np .fromiter (
409
+ allowed_features , dtype = np .uint32 , count = len (allowed_features )
410
+ )
411
+ arbitrary_feature = allowed_features [0 ]
412
+ else :
413
+ allowed_features = None
414
+ arbitrary_feature = 0
415
+
416
+ # TreeNode init needs the total sum of gradients and hessians. Therefore, we
417
+ # first compute the histograms and then obtain the totals by summing over the
418
+ # bins of one arbitrary (allowed) feature's histogram. This way we replace a
419
+ # loop over n_samples by a loop over n_bins.
420
+ histograms = self .histogram_builder .compute_histograms_brute (
421
+ self .splitter .partition , # =self.root.sample_indices
422
+ allowed_features ,
423
+ )
424
+ self .total_compute_hist_time += time () - tic
425
+
426
+ tic = time ()
406
427
n_samples = self .X_binned .shape [0 ]
407
428
depth = 0
408
- sum_gradients = sum_parallel (gradients , self .n_threads )
429
+ histogram_array = np .asarray (histograms [arbitrary_feature ])
430
+ sum_gradients = histogram_array ["sum_gradients" ].sum ()
409
431
if self .histogram_builder .hessians_are_constant :
410
- sum_hessians = hessians [0 ] * n_samples
432
+ sum_hessians = self . histogram_builder . hessians [0 ] * n_samples
411
433
else :
412
- sum_hessians = sum_parallel ( hessians , self . n_threads )
434
+ sum_hessians = histogram_array [ "sum_hessians" ]. sum ( )
413
435
self .root = TreeNode (
414
436
depth = depth ,
415
437
sample_indices = self .splitter .partition ,
@@ -430,18 +452,10 @@ def _initialize_root(self, gradients, hessians):
430
452
431
453
if self .interaction_cst is not None :
432
454
self .root .interaction_cst_indices = range (len (self .interaction_cst ))
433
- allowed_features = set ().union (* self .interaction_cst )
434
- self .root .allowed_features = np .fromiter (
435
- allowed_features , dtype = np .uint32 , count = len (allowed_features )
436
- )
455
+ self .root .allowed_features = allowed_features
437
456
438
- tic = time ()
439
- self .root .histograms = self .histogram_builder .compute_histograms_brute (
440
- self .root .sample_indices , self .root .allowed_features
441
- )
442
- self .total_compute_hist_time += time () - tic
457
+ self .root .histograms = histograms
443
458
444
- tic = time ()
445
459
self ._compute_best_split_and_push (self .root )
446
460
self .total_find_split_time += time () - tic
447
461
0 commit comments