Skip to content

Commit 3a08d6c

Browse files
committed
use current ctx and dev by default in CUDA prov
1 parent 7727ad1 commit 3a08d6c

File tree

5 files changed

+92
-7
lines changed

5 files changed

+92
-7
lines changed

include/umf/providers/provider_cuda.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ typedef struct umf_cuda_memory_provider_params_t
2020
*umf_cuda_memory_provider_params_handle_t;
2121

2222
/// @brief Create a struct to store parameters of the CUDA Memory Provider.
23-
/// @param hParams [out] handle to the newly created parameters struct.
23+
/// @param hParams [out] handle to the newly created parameters structure,
24+
/// initialized with the default (current) context and device ID.
2425
/// @return UMF_RESULT_SUCCESS on success or appropriate error code on failure.
2526
umf_result_t umfCUDAMemoryProviderParamsCreate(
2627
umf_cuda_memory_provider_params_handle_t *hParams);

src/provider/provider_cuda.c

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ typedef struct cu_ops_t {
139139
CUresult (*cuGetErrorName)(CUresult error, const char **pStr);
140140
CUresult (*cuGetErrorString)(CUresult error, const char **pStr);
141141
CUresult (*cuCtxGetCurrent)(CUcontext *pctx);
142+
CUresult (*cuCtxGetDevice)(CUdevice *device);
142143
CUresult (*cuCtxSetCurrent)(CUcontext ctx);
143144
CUresult (*cuIpcGetMemHandle)(CUipcMemHandle *pHandle, CUdeviceptr dptr);
144145
CUresult (*cuIpcOpenMemHandle)(CUdeviceptr *pdptr, CUipcMemHandle handle,
@@ -221,6 +222,8 @@ static void init_cu_global_state(void) {
221222
utils_get_symbol_addr(lib_handle, "cuGetErrorString", lib_name);
222223
*(void **)&g_cu_ops.cuCtxGetCurrent =
223224
utils_get_symbol_addr(lib_handle, "cuCtxGetCurrent", lib_name);
225+
*(void **)&g_cu_ops.cuCtxGetDevice =
226+
utils_get_symbol_addr(lib_handle, "cuCtxGetDevice", lib_name);
224227
*(void **)&g_cu_ops.cuCtxSetCurrent =
225228
utils_get_symbol_addr(lib_handle, "cuCtxSetCurrent", lib_name);
226229
*(void **)&g_cu_ops.cuIpcGetMemHandle =
@@ -234,9 +237,9 @@ static void init_cu_global_state(void) {
234237
!g_cu_ops.cuMemHostAlloc || !g_cu_ops.cuMemAllocManaged ||
235238
!g_cu_ops.cuMemFree || !g_cu_ops.cuMemFreeHost ||
236239
!g_cu_ops.cuGetErrorName || !g_cu_ops.cuGetErrorString ||
237-
!g_cu_ops.cuCtxGetCurrent || !g_cu_ops.cuCtxSetCurrent ||
238-
!g_cu_ops.cuIpcGetMemHandle || !g_cu_ops.cuIpcOpenMemHandle ||
239-
!g_cu_ops.cuIpcCloseMemHandle) {
240+
!g_cu_ops.cuCtxGetCurrent || !g_cu_ops.cuCtxGetDevice ||
241+
!g_cu_ops.cuCtxSetCurrent || !g_cu_ops.cuIpcGetMemHandle ||
242+
!g_cu_ops.cuIpcOpenMemHandle || !g_cu_ops.cuIpcCloseMemHandle) {
240243
LOG_FATAL("Required CUDA symbols not found.");
241244
Init_cu_global_state_failed = true;
242245
utils_close_library(lib_handle);
@@ -260,8 +263,29 @@ umf_result_t umfCUDAMemoryProviderParamsCreate(
260263
return UMF_RESULT_ERROR_OUT_OF_HOST_MEMORY;
261264
}
262265

263-
params_data->cuda_context_handle = NULL;
264-
params_data->cuda_device_handle = -1;
266+
utils_init_once(&cu_is_initialized, init_cu_global_state);
267+
if (Init_cu_global_state_failed) {
268+
LOG_FATAL("Loading CUDA symbols failed");
269+
return UMF_RESULT_ERROR_DEPENDENCY_UNAVAILABLE;
270+
}
271+
272+
// initialize context and device to the current ones
273+
CUcontext current_ctx = NULL;
274+
CUresult cu_result = g_cu_ops.cuCtxGetCurrent(&current_ctx);
275+
if (cu_result == CUDA_SUCCESS) {
276+
params_data->cuda_context_handle = current_ctx;
277+
} else {
278+
params_data->cuda_context_handle = NULL;
279+
}
280+
281+
CUdevice current_device = -1;
282+
cu_result = g_cu_ops.cuCtxGetDevice(&current_device);
283+
if (cu_result == CUDA_SUCCESS) {
284+
params_data->cuda_device_handle = current_device;
285+
} else {
286+
params_data->cuda_device_handle = -1;
287+
}
288+
265289
params_data->memory_type = UMF_MEMORY_TYPE_UNKNOWN;
266290
params_data->alloc_flags = 0;
267291

@@ -342,6 +366,12 @@ static umf_result_t cu_memory_provider_initialize(void *params,
342366
}
343367

344368
if (cu_params->cuda_context_handle == NULL) {
369+
LOG_ERR("Invalid context handle");
370+
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
371+
}
372+
373+
if (cu_params->cuda_device_handle < 0) {
374+
LOG_ERR("Invalid device handle");
345375
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
346376
}
347377

test/providers/cuda_helpers.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,18 @@ CUcontext get_mem_context(void *ptr) {
412412
return context;
413413
}
414414

415+
int get_mem_device(void *ptr) {
416+
int device;
417+
CUresult res = libcu_ops.cuPointerGetAttribute(
418+
&device, CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL, (CUdeviceptr)ptr);
419+
if (res != CUDA_SUCCESS) {
420+
fprintf(stderr, "cuPointerGetAttribute() failed!\n");
421+
return -1;
422+
}
423+
424+
return device;
425+
}
426+
415427
CUcontext get_current_context() {
416428
CUcontext context;
417429
CUresult res = libcu_ops.cuCtxGetCurrent(&context);

test/providers/cuda_helpers.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ unsigned int get_mem_host_alloc_flags(void *ptr);
4848

4949
CUcontext get_mem_context(void *ptr);
5050

51+
int get_mem_device(void *ptr);
52+
5153
CUcontext get_current_context();
5254

5355
#ifdef __cplusplus

test/providers/provider_cuda.cpp

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,14 +142,15 @@ struct umfCUDAProviderTest
142142

143143
memAccessor = nullptr;
144144
expected_context = cudaTestHelper.get_test_context();
145+
expected_device = cudaTestHelper.get_test_device();
145146
params = create_cuda_prov_params(cudaTestHelper.get_test_context(),
146147
cudaTestHelper.get_test_device(),
147148
memory_type, 0 /* alloc flags */);
148149
ASSERT_NE(expected_context, nullptr);
150+
ASSERT_GE(expected_device, 0);
149151

150152
switch (memory_type) {
151153
case UMF_MEMORY_TYPE_DEVICE:
152-
153154
memAccessor = std::make_unique<CUDAMemoryAccessor>(
154155
cudaTestHelper.get_test_context(),
155156
cudaTestHelper.get_test_device());
@@ -178,6 +179,7 @@ struct umfCUDAProviderTest
178179

179180
std::unique_ptr<MemoryAccessor> memAccessor = nullptr;
180181
CUcontext expected_context = nullptr;
182+
int expected_device = -1;
181183
umf_usm_memory_type_t expected_memory_type;
182184
};
183185

@@ -328,6 +330,44 @@ TEST_P(umfCUDAProviderTest, getPageSizeInvalidArgs) {
328330
umfMemoryProviderDestroy(provider);
329331
}
330332

333+
TEST_P(umfCUDAProviderTest, cudaProviderDefaultParams) {
334+
umf_cuda_memory_provider_params_handle_t defaultParams = nullptr;
335+
umf_result_t umf_result = umfCUDAMemoryProviderParamsCreate(&defaultParams);
336+
ASSERT_EQ(umf_result, UMF_RESULT_SUCCESS);
337+
338+
umf_result = umfCUDAMemoryProviderParamsSetMemoryType(defaultParams,
339+
expected_memory_type);
340+
ASSERT_EQ(umf_result, UMF_RESULT_SUCCESS);
341+
342+
// NOTE: we intentionally do not set any context and device params
343+
344+
umf_memory_provider_handle_t provider = nullptr;
345+
umf_result = umfMemoryProviderCreate(umfCUDAMemoryProviderOps(),
346+
defaultParams, &provider);
347+
ASSERT_EQ(umf_result, UMF_RESULT_SUCCESS);
348+
ASSERT_NE(provider, nullptr);
349+
350+
// do single alloc and check if the context and device id of allocated
351+
// memory are correct
352+
353+
void *ptr = nullptr;
354+
umf_result = umfMemoryProviderAlloc(provider, 128, 0, &ptr);
355+
ASSERT_EQ(umf_result, UMF_RESULT_SUCCESS);
356+
ASSERT_NE(ptr, nullptr);
357+
358+
CUcontext actual_mem_context = get_mem_context(ptr);
359+
ASSERT_EQ(actual_mem_context, expected_context);
360+
361+
int actual_device = get_mem_device(ptr);
362+
ASSERT_EQ(actual_device, expected_device);
363+
364+
umf_result = umfMemoryProviderFree(provider, ptr, 128);
365+
ASSERT_EQ(umf_result, UMF_RESULT_SUCCESS);
366+
367+
umfMemoryProviderDestroy(provider);
368+
umfCUDAMemoryProviderParamsDestroy(defaultParams);
369+
}
370+
331371
TEST_P(umfCUDAProviderTest, cudaProviderNullParams) {
332372
umf_result_t res = umfCUDAMemoryProviderParamsCreate(nullptr);
333373
EXPECT_EQ(res, UMF_RESULT_ERROR_INVALID_ARGUMENT);

0 commit comments

Comments
 (0)