Skip to content

Commit bcf0b20

Browse files
authored
Check nvidia smi version before loading decomp API. (#11464)
1 parent ff56568 commit bcf0b20

File tree

4 files changed

+63
-26
lines changed

4 files changed

+63
-26
lines changed

src/common/cuda_dr_utils.cc

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
#include "xgboost/string_view.h" // for StringView
1919

2020
namespace xgboost::cudr {
21-
CuDriverApi::CuDriverApi() {
21+
CuDriverApi::CuDriverApi(std::int32_t cu_major, std::int32_t cu_minor, std::int32_t kdm_major) {
2222
// similar to dlopen, but without the need to release a handle.
2323
auto safe_load = [](xgboost::StringView name, auto **fnptr) {
2424
cudaDriverEntryPointQueryResult status;
@@ -41,7 +41,12 @@ CuDriverApi::CuDriverApi() {
4141
safe_load("cuDeviceGetAttribute", &this->cuDeviceGetAttribute);
4242
safe_load("cuDeviceGet", &this->cuDeviceGet);
4343
#if defined(CUDA_HW_DECOM_AVAILABLE)
44-
safe_load("cuMemBatchDecompressAsync", &this->cuMemBatchDecompressAsync);
44+
// Needs CTK >= 12.8 and kdm_major >= 570; otherwise the pointer is left null.
45+
if (((cu_major == 12 && cu_minor >= 8) || cu_major > 12) && (kdm_major >= 570)) {
46+
safe_load("cuMemBatchDecompressAsync", &this->cuMemBatchDecompressAsync);
47+
} else {
48+
this->cuMemBatchDecompressAsync = nullptr;
49+
}
4550
#endif // defined(CUDA_HW_DECOM_AVAILABLE)
4651
CHECK(this->cuMemGetAllocationGranularity);
4752
}
@@ -76,9 +81,17 @@ void CuDriverApi::ThrowIfError(CUresult status, StringView fn, std::int32_t line
7681
}
7782

7883
// Returns a reference to a single CuDriverApi instance that is constructed on first use
// inside std::call_once.
//
// The driver version and the nvidia-smi version are fetched through the cached *Global
// helpers on every call, but they are only consumed by the call that actually constructs
// the object. The line marked with a `-` gutter is the pre-change construction without
// version arguments; the `+` line is its replacement.
[[nodiscard]] CuDriverApi &GetGlobalCuDriverApi() {
84+
std::int32_t cu_major = -1, cu_minor = -1;
85+
GetDrVersionGlobal(&cu_major, &cu_minor);
86+
87+
std::int32_t kdm_major = -1, kdm_minor = -1;
88+
// A failed nvidia-smi query is passed on to the constructor as major version -1.
if (!GetVersionFromSmiGlobal(&kdm_major, &kdm_minor)) {
89+
kdm_major = -1;
90+
}
91+
7992
static std::once_flag flag;
8093
static std::unique_ptr<CuDriverApi> cu;
81-
std::call_once(flag, [&] { cu = std::make_unique<CuDriverApi>(); });
94+
// The by-reference capture is only used during this call; the lambda is not stored.
std::call_once(flag, [&] { cu = std::make_unique<CuDriverApi>(cu_major, cu_minor, kdm_major); });
8295
return *cu;
8396
}
8497

@@ -154,5 +167,24 @@ void MakeCuMemLocation(CUmemLocationType type, CUmemLocation *loc) {
154167

155168
return Invalid();
156169
}
170+
171+
// Cached wrapper around GetVersionFromSmi.
//
// The underlying query runs under std::call_once and its outputs (major, minor and the
// success flag) are kept in function-local statics, so it is not repeated once it has
// returned. Every call copies the stored values into *p_major and *p_minor and returns
// the stored flag. Both pointers are dereferenced unconditionally, so they must be
// non-null.
[[nodiscard]] bool GetVersionFromSmiGlobal(std::int32_t *p_major, std::int32_t *p_minor) {
172+
static std::once_flag flag;
173+
// -1 is the initial value for both fields; GetVersionFromSmi receives pointers to them.
static std::int32_t major = -1, minor = -1;
174+
static bool result = false;
175+
std::call_once(flag, [&] { result = GetVersionFromSmi(&major, &minor); });
176+
177+
*p_major = major;
178+
*p_minor = minor;
179+
return result;
180+
}
181+
182+
// Cached wrapper around xgboost::curt::DrVersion.
//
// DrVersion is invoked under std::call_once and writes into function-local statics that
// start at 0. Every call copies the stored values into *p_major and *p_minor. Both
// pointers are dereferenced unconditionally, so they must be non-null.
void GetDrVersionGlobal(std::int32_t *p_major, std::int32_t *p_minor) {
183+
static std::once_flag once;
184+
static std::int32_t major{0}, minor{0};
185+
std::call_once(once, [] { xgboost::curt::DrVersion(&major, &minor); });
186+
*p_major = major;
187+
*p_minor = minor;
188+
}
157189
} // namespace xgboost::cudr
158190
#endif

src/common/cuda_dr_utils.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ struct CuDriverApi {
8282

8383
#endif // defined(CUDA_HW_DECOM_AVAILABLE)
8484

85-
CuDriverApi();
85+
CuDriverApi(std::int32_t cu_major, std::int32_t cu_minor, std::int32_t kdm_major);
8686

8787
void ThrowIfError(CUresult status, StringView fn, std::int32_t line, char const *file) const;
8888
};
@@ -124,4 +124,14 @@ void MakeCuMemLocation(CUmemLocationType type, CUmemLocation *loc);
124124
* @return Whether the system call is successful.
125125
*/
126126
[[nodiscard]] bool GetVersionFromSmi(std::int32_t *p_major, std::int32_t *p_minor);
127+
128+
/**
129+
* @brief Cached variant of @ref GetVersionFromSmi: the query runs once and later calls reuse the stored result.
130+
*/
131+
[[nodiscard]] bool GetVersionFromSmiGlobal(std::int32_t *p_major, std::int32_t *p_minor);
132+
133+
/**
134+
* @brief Cached variant of @ref DrVersion: the query runs once and later calls reuse the stored result.
135+
*/
136+
void GetDrVersionGlobal(std::int32_t *p_major, std::int32_t *p_minor);
127137
} // namespace xgboost::cudr

src/common/device_compression.cu

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,8 @@ void DecompressSnappy(dh::CUDAStreamView stream, SnappyDecomprMgr const& mgr,
242242
CHECK(out.empty());
243243
return;
244244
}
245-
if (GetGlobalDeStatus().avail) {
245+
if (GetGlobalDeStatus().avail &&
246+
cudr::GetGlobalCuDriverApi().cuMemBatchDecompressAsync != nullptr) {
246247
// Invoke the DE.
247248
#if defined(CUDA_HW_DECOM_AVAILABLE)
248249
std::size_t error_index;

src/common/device_helpers.cu

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
11
/**
22
* Copyright 2024-2025, XGBoost contributors
33
*/
4-
#include <mutex> // for once_flag, call_once
5-
64
#include "../common/cuda_dr_utils.h" // for GetVersionFromSmi
7-
#include "cuda_rt_utils.h" // for RtVersion
85
#include "device_helpers.cuh"
96
#include "device_vector.cuh" // for GrowOnlyVirtualMemVec
107
#include "xgboost/windefs.h" // for xgboost_IS_WIN
@@ -18,25 +15,22 @@ namespace {
1815
// Check whether cuda virtual memory can be used.
1916
// Host NUMA allocation requires a driver that supports CTK >= 12.5 to be stable.
2017
// Diff rendering of CheckVmAlloc. Lines under a `-` gutter are the removed version, which
// computed vm_flag once behind a local std::once_flag. Lines under a `+` gutter are the
// replacement, which recomputes vm_flag on each call from GetDrVersionGlobal and
// GetVersionFromSmiGlobal.
//
// In both versions the result is true only if IsSupportedDrVer accepts the driver version
// and the nvidia-smi query succeeds with a major version >= 555.
[[nodiscard]] bool CheckVmAlloc() {
21-
static bool vm_flag = true;
22-
static std::once_flag once;
18+
std::int32_t major{0}, minor{0};
19+
xgboost::cudr::GetDrVersionGlobal(&major, &minor);
2320

24-
std::call_once(once, [] {
25-
std::int32_t major{0}, minor{0};
26-
xgboost::curt::DrVersion(&major, &minor);
27-
if (IsSupportedDrVer(major, minor)) {
28-
// The result from the driver api is not reliable. The system driver might not match
29-
// the CUDA driver in some obscure cases.
30-
//
31-
// https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html
32-
// Ver Linux Win
33-
// CUDA 12.5 Update 1 >=555.42.06 >=555.85
34-
// CUDA 12.5 GA >=555.42.02 >=555.85
35-
vm_flag = xgboost::cudr::GetVersionFromSmi(&major, &minor) && major >= 555;
36-
} else {
37-
vm_flag = false;
38-
}
39-
});
21+
bool vm_flag = true;
22+
if (IsSupportedDrVer(major, minor)) {
23+
// The result from the driver api is not reliable. The system driver might not match
24+
// the CUDA driver in some obscure cases.
25+
//
26+
// https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html
27+
// Ver Linux Win
28+
// CUDA 12.5 Update 1 >=555.42.06 >=555.85
29+
// CUDA 12.5 GA >=555.42.02 >=555.85
30+
// major/minor are overwritten here with the version reported through nvidia-smi.
vm_flag = xgboost::cudr::GetVersionFromSmiGlobal(&major, &minor) && major >= 555;
31+
} else {
32+
vm_flag = false;
33+
}
4034
return vm_flag;
4135
}
4236
} // namespace

0 commit comments

Comments
 (0)