Skip to content

Commit afe556a

Browse files
authored
Add CUDA stream pool. (#11458)
1 parent f786d37 commit afe556a

File tree

4 files changed

+73
-4
lines changed

4 files changed

+73
-4
lines changed

src/common/cuda_stream_pool.cuh

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
/**
2+
* Copyright 2025, XGBoost contributors
3+
*/
4+
#pragma once
5+
#include <atomic> // for atomic
6+
#include <cstddef> // for size_t
7+
#include <vector> // for vector
8+
9+
#include "device_helpers.cuh" // for CUDAStreamView, CUDAStream
10+
11+
namespace xgboost::curt {
12+
// rmm cuda_stream_pool
13+
class StreamPool {
14+
mutable std::atomic<std::size_t> next_{0};
15+
std::vector<dh::CUDAStream> stream_;
16+
17+
public:
18+
explicit StreamPool(std::size_t n) : stream_(n) {}
19+
~StreamPool() = default;
20+
StreamPool(StreamPool const& that) = delete;
21+
StreamPool& operator=(StreamPool const& that) = delete;
22+
23+
[[nodiscard]] dh::CUDAStreamView operator[](std::size_t i) const { return stream_[i].View(); }
24+
[[nodiscard]] dh::CUDAStreamView Next() const {
25+
return stream_[(next_++) % stream_.size()].View();
26+
}
27+
[[nodiscard]] std::size_t Size() const { return stream_.size(); }
28+
};
29+
} // namespace xgboost::curt

src/data/ellpack_page_source.cu

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include "../common/common.h" // for HumanMemUnit, safe_cuda
1212
#include "../common/cuda_rt_utils.h" // for SetDevice
13+
#include "../common/cuda_stream_pool.cuh" // for StreamPool
1314
#include "../common/device_helpers.cuh" // for CUDAStreamView, DefaultStream
1415
#include "../common/ref_resource_view.cuh" // for MakeFixedVecWithCudaMalloc
1516
#include "../common/resource.cuh" // for PrivateCudaMmapConstStream
@@ -25,11 +26,12 @@ namespace xgboost::data {
2526
/**
2627
* Cache
2728
*/
28-
EllpackMemCache::EllpackMemCache(EllpackCacheInfo cinfo)
29+
EllpackMemCache::EllpackMemCache(EllpackCacheInfo cinfo, std::int32_t n_workers)
2930
: cache_mapping{std::move(cinfo.cache_mapping)},
3031
buffer_bytes{std::move(cinfo.buffer_bytes)},
3132
buffer_rows{std::move(cinfo.buffer_rows)},
32-
cache_host_ratio{cinfo.cache_host_ratio} {
33+
cache_host_ratio{cinfo.cache_host_ratio},
34+
streams{std::make_unique<curt::StreamPool>(n_workers)} {
3335
CHECK_EQ(buffer_bytes.size(), buffer_rows.size());
3436
CHECK(!detail::HostRatioIsAuto(this->cache_host_ratio));
3537
CHECK_GE(this->cache_host_ratio, 0) << error::CacheHostRatioInvalid();
@@ -289,7 +291,8 @@ EllpackCacheStreamPolicy<S, F>::CreateWriter(StringView, std::uint32_t iter) {
289291
CHECK(!detail::HostRatioIsAuto(this->CacheInfo().cache_host_ratio));
290292
CHECK_GE(this->CacheInfo().cache_host_ratio, 0.0);
291293
CHECK_LE(this->CacheInfo().cache_host_ratio, 1.0);
292-
this->p_cache_ = std::make_unique<EllpackMemCache>(this->CacheInfo());
294+
constexpr std::int32_t kMaxGpuExtMemWorkers = 4;
295+
this->p_cache_ = std::make_unique<EllpackMemCache>(this->CacheInfo(), kMaxGpuExtMemWorkers);
293296
}
294297
auto fo = std::make_unique<EllpackHostCacheStream>(this->p_cache_);
295298
if (iter == 0) {

src/data/ellpack_page_source.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@
2323
#include "xgboost/data.h" // for BatchParam
2424
#include "xgboost/span.h" // for Span
2525

26+
// Forward declaration only: this header refers to curt::StreamPool through a
// std::unique_ptr member, while the class itself is defined in
// src/common/cuda_stream_pool.cuh.
namespace xgboost::curt {
27+
class StreamPool;
28+
}  // namespace xgboost::curt
29+
2630
namespace xgboost::data {
2731
struct EllpackCacheInfo {
2832
BatchParam param;
@@ -66,7 +70,9 @@ struct EllpackMemCache {
6670
std::vector<bst_idx_t> const buffer_rows;
6771
float const cache_host_ratio;
6872

69-
explicit EllpackMemCache(EllpackCacheInfo cinfo);
73+
std::unique_ptr<curt::StreamPool> streams;
74+
75+
explicit EllpackMemCache(EllpackCacheInfo cinfo, std::int32_t n_workers);
7076
~EllpackMemCache();
7177

7278
// The number of bytes of the entire cache.
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
/**
2+
* Copyright 2025, XGBoost contributors
3+
*/
4+
5+
#include <gtest/gtest.h>
6+
7+
#include <cstdint> // for int32_t
8+
#include <set> // for set
9+
10+
#include "../../../src/common/cuda_stream_pool.cuh"
11+
12+
namespace xgboost::curt {
13+
// Checks that a pool of N streams hands out N distinct handles, and that further calls
// to Next() wrap around to the same handles instead of producing new ones.
TEST(RtUtils, StreamPool) {
  // Use size_t throughout so comparisons with set::size() and StreamPool::Size() do not
  // mix signed and unsigned operands.
  std::size_t const n_streams = 16;
  StreamPool pool{n_streams};
  std::set<cudaStream_t> hdls;

  // First pass: every call must yield a handle that has not been seen before.
  for (std::size_t i = 0; i < n_streams; ++i) {
    hdls.insert(cudaStream_t{pool.Next()});
  }

  ASSERT_EQ(hdls.size(), n_streams);
  ASSERT_EQ(hdls.size(), pool.Size());

  // Second pass: the pool wraps around, so the set of handles must not grow.
  for (std::size_t i = 0; i < n_streams; ++i) {
    hdls.insert(cudaStream_t{pool.Next()});
  }
  ASSERT_EQ(hdls.size(), n_streams);
  ASSERT_EQ(hdls.size(), pool.Size());
}
31+
} // namespace xgboost::curt

0 commit comments

Comments
 (0)