|
15 | 15 |
|
16 | 16 | #include "xgboost/string_view.h" // for StringView
|
17 | 17 |
|
| 18 | +#if CUDART_VERSION >= 12080 |
| 19 | +#define CUDA_HW_DECOM_AVAILABLE 1 |
| 20 | +#endif |
| 21 | + |
18 | 22 | namespace xgboost::cudr {
|
19 | 23 | /**
|
20 | 24 | * @brief A struct for retrieving CUDA driver API from the runtime API.
|
@@ -44,28 +48,39 @@ struct CuDriverApi {
|
44 | 48 | using DeviceGetAttribute = CUresult(int *pi, CUdevice_attribute attrib, CUdevice dev);
|
45 | 49 | using DeviceGet = CUresult(CUdevice *device, int ordinal);
|
46 | 50 |
|
| 51 | +#if defined(CUDA_HW_DECOM_AVAILABLE) |
| 52 | + using BatchDecompressAsync = CUresult(CUmemDecompressParams *paramsArray, size_t count, |
| 53 | + unsigned int flags, size_t *errorIndex, CUstream stream); |
| 54 | +#endif // defined(CUDA_HW_DECOM_AVAILABLE) |
| 55 | + |
47 | 56 | MemGetAllocationGranularityFn *cuMemGetAllocationGranularity{nullptr}; // NOLINT
|
48 | 57 | MemCreateFn *cuMemCreate{nullptr}; // NOLINT
|
49 | 58 | /**
|
50 | 59 | * @param[in] offset - Must be zero.
|
51 | 60 | */
|
52 |
| - MemMapFn *cuMemMap{nullptr}; // NOLINT |
| 61 | + MemMapFn *cuMemMap{nullptr}; // NOLINT |
53 | 62 | /**
|
54 | 63 | * @param[out] ptr - Resulting pointer to start of virtual address range allocated
|
55 | 64 | * @param[in] size - Size of the reserved virtual address range requested
|
56 | 65 | * @param[in] alignment - Alignment of the reserved virtual address range requested
|
57 | 66 | * @param[in] addr - Fixed starting address range requested
|
58 | 67 | * @param[in] flags - Currently unused, must be zero
|
59 | 68 | */
|
60 |
| - MemAddressReserveFn *cuMemAddressReserve{nullptr}; // NOLINT |
61 |
| - MemSetAccessFn *cuMemSetAccess{nullptr}; // NOLINT |
62 |
| - MemUnmapFn *cuMemUnmap{nullptr}; // NOLINT |
63 |
| - MemReleaseFn *cuMemRelease{nullptr}; // NOLINT |
64 |
| - MemAddressFreeFn *cuMemAddressFree{nullptr}; // NOLINT |
65 |
| - GetErrorString *cuGetErrorString{nullptr}; // NOLINT |
66 |
| - GetErrorName *cuGetErrorName{nullptr}; // NOLINT |
67 |
| - DeviceGetAttribute *cuDeviceGetAttribute{nullptr}; // NOLINT |
68 |
| - DeviceGet *cuDeviceGet{nullptr}; // NOLINT |
| 69 | + MemAddressReserveFn *cuMemAddressReserve{nullptr}; // NOLINT |
| 70 | + MemSetAccessFn *cuMemSetAccess{nullptr}; // NOLINT |
| 71 | + MemUnmapFn *cuMemUnmap{nullptr}; // NOLINT |
| 72 | + MemReleaseFn *cuMemRelease{nullptr}; // NOLINT |
| 73 | + MemAddressFreeFn *cuMemAddressFree{nullptr}; // NOLINT |
| 74 | + GetErrorString *cuGetErrorString{nullptr}; // NOLINT |
| 75 | + GetErrorName *cuGetErrorName{nullptr}; // NOLINT |
| 76 | + DeviceGetAttribute *cuDeviceGetAttribute{nullptr}; // NOLINT |
| 77 | + DeviceGet *cuDeviceGet{nullptr}; // NOLINT |
| 78 | + |
| 79 | +#if defined(CUDA_HW_DECOM_AVAILABLE) |
| 80 | + |
| 81 | + BatchDecompressAsync *cuMemBatchDecompressAsync{nullptr}; // NOLINT |
| 82 | + |
| 83 | +#endif // defined(CUDA_HW_DECOM_AVAILABLE) |
69 | 84 |
|
70 | 85 | CuDriverApi();
|
71 | 86 |
|
@@ -96,7 +111,7 @@ inline auto GetAllocGranularity(CUmemAllocationProp const *prop) {
|
96 | 111 | /**
|
97 | 112 | * @brief Obtain appropriate device ordinal for `CUmemLocation`.
|
98 | 113 | */
|
99 |
| -void MakeCuMemLocation(CUmemLocationType type, CUmemLocation* loc); |
| 114 | +void MakeCuMemLocation(CUmemLocationType type, CUmemLocation *loc); |
100 | 115 |
|
101 | 116 | /**
|
102 | 117 | * @brief Construct a `CUmemAllocationProp`.
|
|
0 commit comments