Skip to content

Commit bc19741

Browse files
authored
Add wrapper for CUDA host pinned memory pool. (#11451)
1 parent 2835e49 commit bc19741

File tree

8 files changed

+215
-14
lines changed

8 files changed

+215
-14
lines changed

src/common/cuda_pinned_allocator.cu

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
/**
2+
* Copyright 2025, XGBoost Contributors
3+
*/
4+
#include "cuda_pinned_allocator.h"
5+
6+
#if defined(XGBOOST_USE_CUDA)
7+
8+
#include <cuda_runtime_api.h> // for cudaMemPoolCreate, cudaMemPoolDestroy
9+
10+
#include <array> // for array
11+
#include <cstring> // for memset
12+
#include <memory> // for unique_ptr
13+
14+
#endif // defined(XGBOOST_USE_CUDA)
15+
16+
#include "common.h"
17+
#include "cuda_rt_utils.h" // for CurrentDevice
18+
19+
#if CUDART_VERSION >= 12080
20+
#define CUDA_HW_DECOM_AVAILABLE 1
21+
#endif
22+
23+
namespace xgboost::common::cuda_impl {
24+
[[nodiscard]] MemPoolHdl CreateHostMemPool() {
25+
auto mem_pool = std::unique_ptr<cudaMemPool_t, void (*)(cudaMemPool_t*)>{
26+
[] {
27+
cudaMemPoolProps h_props;
28+
std::memset(&h_props, '\0', sizeof(h_props));
29+
auto numa_id = curt::GetNumaId();
30+
h_props.location.id = numa_id;
31+
h_props.location.type = cudaMemLocationTypeHostNuma;
32+
h_props.allocType = cudaMemAllocationTypePinned;
33+
#if defined(CUDA_HW_DECOM_AVAILABLE)
34+
h_props.usage = cudaMemPoolCreateUsageHwDecompress;
35+
#endif // defined(CUDA_HW_DECOM_AVAILABLE)
36+
h_props.handleTypes = cudaMemHandleTypeNone;
37+
38+
cudaMemPoolProps d_props;
39+
std::memset(&d_props, '\0', sizeof(d_props));
40+
auto device_idx = curt::CurrentDevice();
41+
d_props.location.id = device_idx;
42+
d_props.location.type = cudaMemLocationTypeDevice;
43+
d_props.allocType = cudaMemAllocationTypePinned;
44+
#if defined(CUDA_HW_DECOM_AVAILABLE)
45+
d_props.usage = cudaMemPoolCreateUsageHwDecompress;
46+
#endif // defined(CUDA_HW_DECOM_AVAILABLE)
47+
d_props.handleTypes = cudaMemHandleTypeNone;
48+
49+
std::array<cudaMemPoolProps, 2> vprops{h_props, d_props};
50+
51+
cudaMemPool_t* mem_pool = new cudaMemPool_t;
52+
dh::safe_cuda(cudaMemPoolCreate(mem_pool, vprops.data()));
53+
54+
cudaMemAccessDesc h_desc;
55+
h_desc.location = h_props.location;
56+
h_desc.flags = cudaMemAccessFlagsProtReadWrite;
57+
58+
cudaMemAccessDesc d_desc;
59+
d_desc.location = d_props.location;
60+
d_desc.flags = cudaMemAccessFlagsProtReadWrite;
61+
62+
std::array<cudaMemAccessDesc, 2> descs{h_desc, d_desc};
63+
dh::safe_cuda(cudaMemPoolSetAccess(*mem_pool, descs.data(), descs.size()));
64+
return mem_pool;
65+
}(),
66+
[](cudaMemPool_t* mem_pool) {
67+
if (mem_pool) {
68+
dh::safe_cuda(cudaMemPoolDestroy(*mem_pool));
69+
}
70+
}};
71+
return mem_pool;
72+
}
73+
} // namespace xgboost::common::cuda_impl

src/common/cuda_pinned_allocator.h

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
/**
2-
* Copyright 2022-2024, XGBoost Contributors
2+
* Copyright 2022-2025, XGBoost Contributors
33
*
44
* @brief cuda pinned allocator for usage with thrust containers
55
*/
6-
76
#pragma once
87

98
#include <cuda_runtime.h>
109

1110
#include <cstddef> // for size_t
1211
#include <limits> // for numeric_limits
12+
#include <memory> // for unique_ptr
1313
#include <new> // for bad_array_new_length
1414

1515
#include "common.h"
@@ -103,6 +103,34 @@ struct SamAllocPolicy {
103103
}
104104
};
105105

106+
/**
107+
* @brief A RAII handle type to the CUDA memory pool.
108+
*/
109+
using MemPoolHdl = std::unique_ptr<cudaMemPool_t, void (*)(cudaMemPool_t*)>;
110+
111+
/**
112+
* @brief Create a CUDA memory pool for allocating host pinned memory.
113+
*/
114+
[[nodiscard]] MemPoolHdl CreateHostMemPool();
115+
116+
/**
117+
* @brief C++ wrapper for the CUDA memory pool.
118+
*/
119+
class HostPinnedMemPool {
120+
MemPoolHdl pool_;
121+
122+
public:
123+
HostPinnedMemPool() : pool_{CreateHostMemPool()} {}
124+
void* AllocateAsync(std::size_t n_bytes, cudaStream_t stream) {
125+
void* ptr = nullptr;
126+
dh::safe_cuda(cudaMallocFromPoolAsync(&ptr, n_bytes, *this->pool_, stream));
127+
return ptr;
128+
}
129+
void DeallocateAsync(void* ptr, cudaStream_t stream) {
130+
dh::safe_cuda(cudaFreeAsync(ptr, stream));
131+
}
132+
};
133+
106134
template <typename T, template <typename> typename Policy>
107135
class CudaHostAllocatorImpl : public Policy<T> {
108136
public:

src/common/cuda_rt_utils.cc

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55

66
#if defined(XGBOOST_USE_CUDA)
77
#include <cuda_runtime_api.h>
8-
#endif // defined(XGBOOST_USE_CUDA)
8+
9+
#include <algorithm> // for max
10+
#endif // defined(XGBOOST_USE_CUDA)
911

1012
#include <cstddef> // for size_t
1113
#include <cstdint> // for int32_t
@@ -102,6 +104,13 @@ void DrVersion(std::int32_t* major, std::int32_t* minor) {
102104
GetVersionImpl([](std::int32_t* ver) { dh::safe_cuda(cudaDriverGetVersion(ver)); }, major, minor);
103105
}
104106

107+
[[nodiscard]] std::int32_t GetNumaId() {
108+
std::int32_t numa_id = -1;
109+
dh::safe_cuda(cudaDeviceGetAttribute(&numa_id, cudaDevAttrNumaId, curt::CurrentDevice()));
110+
numa_id = std::max(numa_id, 0);
111+
return numa_id;
112+
}
113+
105114
#else
106115
std::int32_t AllVisibleGPUs() { return 0; }
107116

@@ -125,5 +134,11 @@ void SetDevice(std::int32_t device) {
125134
common::AssertGPUSupport();
126135
}
127136
}
137+
138+
[[nodiscard]] std::int32_t GetNumaId() {
139+
common::AssertGPUSupport();
140+
return 0;
141+
}
142+
128143
#endif // !defined(XGBOOST_USE_CUDA)
129144
} // namespace xgboost::curt

src/common/cuda_rt_utils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,7 @@ void RtVersion(std::int32_t* major, std::int32_t* minor);
3434

3535
// Returns the latest version of CUDA supported by the driver.
3636
void DrVersion(std::int32_t* major, std::int32_t* minor);
37+
38+
// Get the current device's numa ID.
39+
[[nodiscard]] std::int32_t GetNumaId();
3740
} // namespace xgboost::curt

src/common/io.h

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -282,12 +282,13 @@ class ResourceHandler {
282282
public:
283283
// RTTI
284284
enum Kind : std::uint8_t {
285-
kMalloc = 0, // System memory.
286-
kMmap = 1, // Memory mapp.
287-
kCudaMalloc = 2, // CUDA device memory.
288-
kCudaMmap = 3, // CUDA with mmap.
289-
kCudaHostCache = 4, // CUDA pinned host memory.
290-
kCudaGrowOnly = 5, // CUDA virtual memory allocator.
285+
kMalloc = 0, // System memory.
286+
kMmap = 1, // Memory mapp.
287+
kCudaMalloc = 2, // CUDA device memory.
288+
kCudaMmap = 3, // CUDA with mmap.
289+
kCudaHostCache = 4, // CUDA pinned host memory.
290+
kCudaGrowOnly = 5, // CUDA virtual memory allocator.
291+
kCudaPinnedMemPool = 6, // CUDA memory pool for pinned host memory.
291292
};
292293

293294
private:
@@ -316,6 +317,8 @@ class ResourceHandler {
316317
return "CudaHostCache";
317318
case kCudaGrowOnly:
318319
return "CudaGrowOnly";
320+
case kCudaPinnedMemPool:
321+
return "CudaPinnedMemPool";
319322
}
320323
LOG(FATAL) << "Unreachable.";
321324
return {};

src/common/ref_resource_view.cuh

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/**
2-
* Copyright 2024, XGBoost Contributors
2+
* Copyright 2024-2025, XGBoost Contributors
33
*/
44
#pragma once
55

@@ -43,4 +43,14 @@ template <typename T>
4343
auto ref = RefResourceView{resource->DataAs<T>(), n_elements, resource};
4444
return ref;
4545
}
46+
47+
template <typename T>
48+
[[nodiscard]] RefResourceView<T> MakeFixedVecWithPinnedMemPool(
49+
std::shared_ptr<cuda_impl::HostPinnedMemPool> pool, std::size_t n_elements,
50+
dh::CUDAStreamView stream) {
51+
auto resource = std::make_shared<common::HostPinnedMemPoolResource>(
52+
std::move(pool), n_elements * sizeof(T), stream);
53+
auto ref = RefResourceView{resource->DataAs<T>(), n_elements, resource};
54+
return ref;
55+
}
4656
} // namespace xgboost::common

src/common/resource.cuh

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
/**
2-
* Copyright 2024, XGBoost Contributors
2+
* Copyright 2024-2025, XGBoost Contributors
33
*/
44
#pragma once
55
#include <cstddef> // for size_t
66
#include <functional> // for function
7+
#include <utility> // for move
78

8-
#include "cuda_pinned_allocator.h" // for SamAllocator
9+
#include "cuda_pinned_allocator.h" // for SamAllocator, HostPinnedMemPool
10+
#include "device_helpers.cuh" // for CUDAStreamView
911
#include "device_vector.cuh" // for DeviceUVector, GrowOnlyVirtualMemVec
1012
#include "io.h" // for ResourceHandler, MMAPFile
1113
#include "xgboost/string_view.h" // for StringView
@@ -75,6 +77,30 @@ class CudaPinnedResource : public ResourceHandler {
7577
void Resize(std::size_t n_bytes) { this->storage_.resize(n_bytes); }
7678
};
7779

80+
/**
81+
* @brief Resource for fixed-size memory allocated by @ref HostPinnedMemPool.
82+
*
83+
* This container shares the pool but owns the memory.
84+
*/
85+
class HostPinnedMemPoolResource : public ResourceHandler {
86+
std::shared_ptr<cuda_impl::HostPinnedMemPool> pool_;
87+
std::size_t n_bytes_;
88+
dh::CUDAStreamView stream_;
89+
void* ptr_;
90+
91+
public:
92+
explicit HostPinnedMemPoolResource(std::shared_ptr<cuda_impl::HostPinnedMemPool> pool,
93+
std::size_t n_bytes, dh::CUDAStreamView stream)
94+
: ResourceHandler{kCudaPinnedMemPool},
95+
pool_{std::move(pool)},
96+
n_bytes_{n_bytes},
97+
stream_{stream},
98+
ptr_{this->pool_->AllocateAsync(n_bytes, stream)} {}
99+
~HostPinnedMemPoolResource() override { this->pool_->DeallocateAsync(this->ptr_, this->stream_); }
100+
[[nodiscard]] std::size_t Size() const override { return this->n_bytes_; }
101+
[[nodiscard]] void* Data() override { return this->ptr_; }
102+
};
103+
78104
class CudaMmapResource : public ResourceHandler {
79105
std::unique_ptr<MMAPFile, std::function<void(MMAPFile*)>> handle_;
80106
std::size_t n_;

tests/cpp/common/test_ref_resource_view.cu

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/**
2-
* Copyright 2024, XGBoost Contributors
2+
* Copyright 2024-2025, XGBoost Contributors
33
*/
44
#if defined(__linux__)
55

@@ -10,7 +10,8 @@
1010
#include <thrust/sequence.h> // for sequence
1111

1212
#include "../../../src/common/ref_resource_view.cuh"
13-
#include "../helpers.h" // for MakeCUDACtx
13+
#include "../../../src/common/threadpool.h" // for ThreadPool
14+
#include "../helpers.h" // for MakeCUDACtx
1415

1516
namespace xgboost::common {
1617
class TestCudaGrowOnly : public ::testing::TestWithParam<std::size_t> {
@@ -44,6 +45,48 @@ class TestCudaGrowOnly : public ::testing::TestWithParam<std::size_t> {
4445
TEST_P(TestCudaGrowOnly, Resize) { this->Run(this->GetParam()); }
4546

4647
INSTANTIATE_TEST_SUITE_P(RefResourceView, TestCudaGrowOnly, ::testing::Values(1 << 20, 1 << 21));
48+
49+
TEST(HostPinnedMemPool, Alloc) {
50+
std::vector<RefResourceView<double>> refs;
51+
52+
{
53+
// pool goes out of scope before refs does. Test memory safety.
54+
auto pool = std::make_shared<cuda_impl::HostPinnedMemPool>();
55+
for (std::size_t i = 0; i < 4; ++i) {
56+
auto ref = MakeFixedVecWithPinnedMemPool<double>(pool, 128 + i, dh::DefaultStream());
57+
refs.emplace_back(std::move(ref));
58+
}
59+
for (std::size_t i = 0; i < 4; ++i) {
60+
auto const& ref = refs[i];
61+
ASSERT_EQ(ref.size(), 128 + i);
62+
ASSERT_EQ(ref.size_bytes(), ref.size() * sizeof(double));
63+
}
64+
65+
// Thread safety.
66+
auto n_threads = static_cast<std::int32_t>(std::thread::hardware_concurrency());
67+
common::ThreadPool workers{"tmempool", n_threads, [] {
68+
}};
69+
std::vector<std::future<RefResourceView<double>>> alloc_futs;
70+
for (std::int32_t i = 0, n = n_threads * 4; i < n; ++i) {
71+
auto fut = workers.Submit([i, pool] {
72+
auto ref = MakeFixedVecWithPinnedMemPool<double>(pool, 128 + i, dh::DefaultStream());
73+
return ref;
74+
});
75+
alloc_futs.emplace_back(std::move(fut));
76+
}
77+
std::vector<std::future<void>> free_futs(alloc_futs.size());
78+
for (std::int32_t i = 0, n = n_threads * 4; i < n; ++i) {
79+
auto fut = workers.Submit([i, pool, &alloc_futs, &free_futs] {
80+
auto ref = alloc_futs[i].get();
81+
ASSERT_EQ(ref.size(), 128 + i);
82+
});
83+
free_futs[i] = std::move(fut);
84+
}
85+
for (std::int32_t i = 0, n = n_threads * 4; i < n; ++i) {
86+
free_futs[i].get();
87+
}
88+
}
89+
}
4790
} // namespace xgboost::common
4891

4992
#endif // defined(__linux__)

0 commit comments

Comments
 (0)