@@ -132,13 +132,11 @@ XGBOOST_DEV_INLINE void AtomicAddGpairGlobal(xgboost::GradientPairInt64* dest,
   auto g = gpair.GetQuantisedGrad();
   auto h = gpair.GetQuantisedHess();
 
-  atomicAdd(dst_ptr,
-            *reinterpret_cast<uint64_t*>(&g));
-  atomicAdd(dst_ptr + 1,
-            *reinterpret_cast<uint64_t*>(&h));
+  atomicAdd(dst_ptr, *reinterpret_cast<uint64_t*>(&g));
+  atomicAdd(dst_ptr + 1, *reinterpret_cast<uint64_t*>(&h));
 }
 
-template <bool kCompressed, int kBlockThreads, int kItemsPerThread>
+template <bool kCompressed, bool kDense, int kBlockThreads, int kItemsPerThread>
 class HistogramAgent {
   int constexpr static kItemsPerTile = kBlockThreads * kItemsPerThread;
 
@@ -154,6 +152,8 @@ class HistogramAgent {
   const bst_idx_t n_elements_;
   const GradientQuantiser& rounding_;
 
+  static_assert(kCompressed >= kDense);
+
  public:
   __device__ HistogramAgent(GradientPairInt64* smem_arr,
                             GradientPairInt64* __restrict__ d_node_hist, const FeatureGroup& group,
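
The added `static_assert(kCompressed >= kDense)` leans on bool-to-int promotion to express "dense implies compressed": every kernel instantiation below that sets `kDense` also sets `kCompressed`. A minimal standalone sketch of what the comparison accepts and rejects (illustrative only, not part of the patch):

    // bool promotes to int (true -> 1, false -> 0), so `kCompressed >= kDense`
    // rejects exactly one combination: dense without compression.
    static_assert(int(true) >= int(true));      // compressed, dense Ellpack
    static_assert(int(true) >= int(false));     // compressed, not dense
    static_assert(int(false) >= int(false));    // neither
    static_assert(!(int(false) >= int(true)));  // dense but uncompressed: rejected
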
@@ -176,7 +176,7 @@ class HistogramAgent {
     Idx ridx = d_ridx_[idx / feature_stride_];
     auto fidx = FeatIdx(group_, idx, feature_stride_);
     bst_bin_t compressed_bin = matrix_.gidx_iter[IterIdx(matrix_, ridx, fidx)];
-    if (compressed_bin != matrix_.NullValue()) {
+    if (kDense || compressed_bin != matrix_.NullValue()) {
       // The matrix is compressed with feature-local bins.
       if (kCompressed) {
         compressed_bin += this->matrix_.feature_segments[fidx];
@@ -211,18 +211,20 @@ class HistogramAgent {
       gpair[i] = d_gpair_[ridx[i]];
       auto fidx = FeatIdx(group_, idx[i], feature_stride_);
       gidx[i] = matrix_.gidx_iter[IterIdx(matrix_, ridx[i], fidx)];
-      if (gidx[i] != matrix_.NullValue()) {
-        if (kCompressed) {
+      if (kDense || gidx[i] != matrix_.NullValue()) {
+        if constexpr (kCompressed) {
           gidx[i] += matrix_.feature_segments[fidx];
         }
       } else {
-        gidx[i] = -1;  // missing
+        // Use -1 to denote missing. Since we need to add the beginning bin to gidx, the
+        // result might be equal to the `NullValue`.
+        gidx[i] = -1;
       }
     }
 #pragma unroll
     for (int i = 0; i < kItemsPerThread; i++) {
       // Avoid atomic add if it's a null value.
-      if (gidx[i] != -1) {
+      if (kDense || gidx[i] != -1) {
         auto adjusted = rounding_.ToFixedPoint(gpair[i]);
         AtomicAddGpairShared(smem_arr_ + gidx[i] - group_.start_bin, adjusted);
       }
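
The new comment explains why `-1` replaces `NullValue()` as the missing-value sentinel once the feature-local bin has been shifted by `feature_segments[fidx]`: after the shift, a perfectly valid global bin can coincide with the raw null marker, whereas `-1` can never be produced by `local_bin + segment_start`. A tiny standalone illustration with made-up bin counts (the concrete numbers are assumptions, not values from the source):

    #include <cassert>
    #include <cstdint>

    int main() {
      const std::int32_t kNull = 4;                    // assumed feature-local null marker
      const std::int32_t feature_segments[] = {0, 4};  // two features, 4 bins each

      std::int32_t local_bin = 0;  // a real, non-missing value in feature 1
      std::int32_t global_bin = local_bin + feature_segments[1];

      assert(global_bin == kNull);  // collides with the raw null marker
      assert(global_bin != -1);     // the -1 sentinel stays unambiguous
      return 0;
    }
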
@@ -262,7 +264,8 @@ class HistogramAgent {
   }
 };
 
-template <bool kIsDense, bool use_shared_memory_histograms, int kBlockThreads, int kItemsPerThread>
+template <bool kCompressed, bool kDense, bool use_shared_memory_histograms, int kBlockThreads,
+          int kItemsPerThread>
 __global__ void __launch_bounds__(kBlockThreads)
     SharedMemHistKernel(const EllpackDeviceAccessor matrix,
                         const FeatureGroupsAccessor feature_groups,
@@ -273,7 +276,7 @@ __global__ void __launch_bounds__(kBlockThreads)
   extern __shared__ char smem[];
   const FeatureGroup group = feature_groups[blockIdx.y];
   auto smem_arr = reinterpret_cast<GradientPairInt64*>(smem);
-  auto agent = HistogramAgent<kIsDense, kBlockThreads, kItemsPerThread>(
+  auto agent = HistogramAgent<kCompressed, kDense, kBlockThreads, kItemsPerThread>(
       smem_arr, d_node_hist, group, matrix, d_ridx, rounding, d_gpair);
   if (use_shared_memory_histograms) {
     agent.BuildHistogramWithShared();
@@ -289,30 +292,41 @@ constexpr std::int32_t ItemsPerTile() { return kBlockThreads * kItemsPerThread;
 }  // namespace
 
 // Use auto deduction guide to workaround compiler error.
-template <auto GlobalDense = SharedMemHistKernel<true, false, kBlockThreads, kItemsPerThread>,
-          auto Global = SharedMemHistKernel<false, false, kBlockThreads, kItemsPerThread>,
-          auto SharedDense = SharedMemHistKernel<true, true, kBlockThreads, kItemsPerThread>,
-          auto Shared = SharedMemHistKernel<false, true, kBlockThreads, kItemsPerThread>>
+template <auto GlobalCompr =
+              SharedMemHistKernel<true, false, false, kBlockThreads, kItemsPerThread>,
+          auto Global = SharedMemHistKernel<false, false, false, kBlockThreads, kItemsPerThread>,
+          auto SharedCompr = SharedMemHistKernel<true, false, true, kBlockThreads, kItemsPerThread>,
+          auto Shared = SharedMemHistKernel<false, false, true, kBlockThreads, kItemsPerThread>,
+          auto GlobalDense = SharedMemHistKernel<true, true, false, kBlockThreads, kItemsPerThread>,
+          auto SharedDense = SharedMemHistKernel<true, true, true, kBlockThreads, kItemsPerThread>>
 struct HistogramKernel {
   enum KernelType : std::size_t {
-    kGlobalDense = 0,
+    kGlobalCompr = 0,
     kGlobal = 1,
-    kSharedDense = 2,
+    kSharedCompr = 2,
     kShared = 3,
+    kGlobalDense = 4,
+    kSharedDense = 5,
   };
   // Kernel for working with dense Ellpack using the global memory.
-  decltype(GlobalDense) global_dense_kernel{
-      SharedMemHistKernel<true, false, kBlockThreads, kItemsPerThread>};
+  decltype(GlobalCompr) global_compr_kernel{
+      SharedMemHistKernel<true, false, false, kBlockThreads, kItemsPerThread>};
   // Kernel for working with sparse Ellpack using the global memory.
-  decltype(Global) global_kernel{SharedMemHistKernel<false, false, kBlockThreads, kItemsPerThread>};
+  decltype(Global) global_kernel{
+      SharedMemHistKernel<false, false, false, kBlockThreads, kItemsPerThread>};
   // Kernel for working with dense Ellpack using the shared memory.
-  decltype(SharedDense) shared_dense_kernel{
-      SharedMemHistKernel<true, true, kBlockThreads, kItemsPerThread>};
+  decltype(SharedCompr) shared_compr_kernel{
+      SharedMemHistKernel<true, false, true, kBlockThreads, kItemsPerThread>};
   // Kernel for working with sparse Ellpack using the shared memory.
-  decltype(Shared) shared_kernel{SharedMemHistKernel<false, true, kBlockThreads, kItemsPerThread>};
+  decltype(Shared) shared_kernel{
+      SharedMemHistKernel<false, false, true, kBlockThreads, kItemsPerThread>};
+  decltype(GlobalDense) global_dense_kernel{
+      SharedMemHistKernel<true, true, false, kBlockThreads, kItemsPerThread>};
+  decltype(SharedDense) shared_dense_kernel{
+      SharedMemHistKernel<true, true, true, kBlockThreads, kItemsPerThread>};
 
   bool shared{false};
-  std::array<std::uint32_t, 4> grid_sizes{0, 0, 0, 0};
+  std::array<std::uint32_t, 6> grid_sizes{0, 0, 0, 0, 0, 0};
   std::size_t smem_size{0};
   bool const force_global;
 
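
With the extra `kDense` flag the kernel table grows from four to six members, which is why `grid_sizes` is widened to six entries. A recap of the `<kCompressed, kDense, use_shared_memory_histograms>` arguments behind each member, written as a standalone sketch rather than code from the patch:

    struct KernelFlags {
      bool compressed;
      bool dense;
      bool shared;
    };

    // Order matches the KernelType enum above.
    constexpr KernelFlags kKernelFlags[] = {
        /* kGlobalCompr */ {true, false, false},
        /* kGlobal      */ {false, false, false},
        /* kSharedCompr */ {true, false, true},
        /* kShared      */ {false, false, true},
        /* kGlobalDense */ {true, true, false},
        /* kSharedDense */ {true, true, true},
    };
    static_assert(sizeof(kKernelFlags) / sizeof(kKernelFlags[0]) == 6,
                  "one grid size per kernel: std::array<std::uint32_t, 6>");
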
@@ -347,9 +361,11 @@ struct HistogramKernel {
       this->grid_sizes[static_cast<std::size_t>(k)] = n_blocks_per_mp * n_mps;
     };
     // Initialize all kernel instantiations
-    std::array kernel_types{kGlobalDense, kGlobal, kSharedDense, kShared};
+    std::array kernel_types{kGlobalCompr, kGlobal,      kSharedCompr,
+                            kShared,      kGlobalDense, kSharedDense};
     std::int32_t k = 0;
-    for (auto& kernel : {global_dense_kernel, global_kernel, shared_dense_kernel, shared_kernel}) {
+    for (auto& kernel : {global_compr_kernel, global_kernel, shared_compr_kernel, shared_kernel,
+                         global_dense_kernel, shared_dense_kernel}) {
       init(kernel, kernel_types[k]);
       ++k;
     }
@@ -397,19 +413,24 @@ class DeviceHistogramBuilderImpl {
     using K = HistogramKernel<>::KernelType;
     if (!this->kernel_->shared) {  // Use global memory
       CHECK_EQ(this->kernel_->smem_size, 0);
-      if (matrix.IsDenseCompressed()) {
-        // Dense must use shared memory except for testing.
+      if (matrix.IsDense()) {
         CHECK(this->kernel_->force_global);
         launcher(this->kernel_->global_dense_kernel, this->kernel_->grid_sizes[K::kGlobalDense]);
+      } else if (matrix.IsDenseCompressed()) {
+        // Dense must use shared memory except for testing.
+        CHECK(this->kernel_->force_global);
+        launcher(this->kernel_->global_compr_kernel, this->kernel_->grid_sizes[K::kGlobalCompr]);
       } else {
         // Sparse
         launcher(this->kernel_->global_kernel, this->kernel_->grid_sizes[K::kGlobal]);
       }
     } else {  // Use shared memory
       CHECK_NE(this->kernel_->smem_size, 0);
-      if (matrix.IsDenseCompressed()) {
-        // Dense
+      if (matrix.IsDense()) {
        launcher(this->kernel_->shared_dense_kernel, this->kernel_->grid_sizes[K::kSharedDense]);
+      } else if (matrix.IsDenseCompressed()) {
+        // Dense
+        launcher(this->kernel_->shared_compr_kernel, this->kernel_->grid_sizes[K::kSharedCompr]);
       } else {
         // Sparse
         launcher(this->kernel_->shared_kernel, this->kernel_->grid_sizes[K::kShared]);
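
Putting the two branches together, the dispatch now distinguishes three Ellpack layouts (fully dense, dense-compressed, sparse) times two memory strategies. A hypothetical standalone helper, not present in the source, that mirrors the selection order used above:

    #include <cstddef>

    enum KernelType : std::size_t {
      kGlobalCompr = 0,
      kGlobal = 1,
      kSharedCompr = 2,
      kShared = 3,
      kGlobalDense = 4,
      kSharedDense = 5,
    };

    // Mirrors DeviceHistogramBuilderImpl: fully dense wins over dense-compressed,
    // everything else falls back to the sparse kernels.
    constexpr KernelType SelectKernel(bool is_dense, bool is_dense_compressed, bool use_shared) {
      if (is_dense) {
        return use_shared ? kSharedDense : kGlobalDense;
      }
      if (is_dense_compressed) {
        return use_shared ? kSharedCompr : kGlobalCompr;
      }
      return use_shared ? kShared : kGlobal;
    }

    static_assert(SelectKernel(true, true, true) == kSharedDense);
    static_assert(SelectKernel(false, true, false) == kGlobalCompr);
    static_assert(SelectKernel(false, false, true) == kShared);
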