Skip to content

Commit 29d1134

Browse files
chsigg authored and memfrob committed
[mlir] Set CUDA/ROCm context before creating resources.
The current context is thread-local state, and in preparation for GPU async execution (on multiple threads) we need to set the context before calling APIs that create resources.

Reviewed By: herhut

Differential Revision: https://reviews.llvm.org/D94495
1 parent 59a7ae9 commit 29d1134

File tree

2 files changed

+48
-8
lines changed

2 files changed

+48
-8
lines changed

mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp

Lines changed: 24 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -32,17 +32,33 @@
3232
llvm::errs() << "'" << #expr << "' failed with '" << name << "'\n"; \
3333
}(expr)
3434

35-
// Static initialization of CUDA context for device ordinal 0.
36-
static auto InitializeCtx = [] {
35+
// Static reference to CUDA primary context for device ordinal 0.
36+
static CUcontext Context = [] {
3737
CUDA_REPORT_IF_ERROR(cuInit(/*flags=*/0));
3838
CUdevice device;
3939
CUDA_REPORT_IF_ERROR(cuDeviceGet(&device, /*ordinal=*/0));
4040
CUcontext context;
41-
CUDA_REPORT_IF_ERROR(cuCtxCreate(&context, /*flags=*/0, device));
42-
return 0;
41+
CUDA_REPORT_IF_ERROR(cuDevicePrimaryCtxRetain(&context, device));
42+
return context;
4343
}();
4444

45+
// Sets the `Context` for the duration of the instance and restores the previous
46+
// context on destruction.
47+
class ScopedContext {
48+
public:
49+
ScopedContext() {
50+
CUDA_REPORT_IF_ERROR(cuCtxGetCurrent(&previous));
51+
CUDA_REPORT_IF_ERROR(cuCtxSetCurrent(Context));
52+
}
53+
54+
~ScopedContext() { CUDA_REPORT_IF_ERROR(cuCtxSetCurrent(previous)); }
55+
56+
private:
57+
CUcontext previous;
58+
};
59+
4560
extern "C" CUmodule mgpuModuleLoad(void *data) {
61+
ScopedContext scopedContext;
4662
CUmodule module = nullptr;
4763
CUDA_REPORT_IF_ERROR(cuModuleLoadData(&module, data));
4864
return module;
@@ -66,12 +82,14 @@ extern "C" void mgpuLaunchKernel(CUfunction function, intptr_t gridX,
6682
intptr_t blockX, intptr_t blockY,
6783
intptr_t blockZ, int32_t smem, CUstream stream,
6884
void **params, void **extra) {
85+
ScopedContext scopedContext;
6986
CUDA_REPORT_IF_ERROR(cuLaunchKernel(function, gridX, gridY, gridZ, blockX,
7087
blockY, blockZ, smem, stream, params,
7188
extra));
7289
}
7390

7491
extern "C" CUstream mgpuStreamCreate() {
92+
ScopedContext scopedContext;
7593
CUstream stream = nullptr;
7694
CUDA_REPORT_IF_ERROR(cuStreamCreate(&stream, CU_STREAM_NON_BLOCKING));
7795
return stream;
@@ -90,6 +108,7 @@ extern "C" void mgpuStreamWaitEvent(CUstream stream, CUevent event) {
90108
}
91109

92110
extern "C" CUevent mgpuEventCreate() {
111+
ScopedContext scopedContext;
93112
CUevent event = nullptr;
94113
CUDA_REPORT_IF_ERROR(cuEventCreate(&event, CU_EVENT_DISABLE_TIMING));
95114
return event;
@@ -108,6 +127,7 @@ extern "C" void mgpuEventRecord(CUevent event, CUstream stream) {
108127
}
109128

110129
extern "C" void *mgpuMemAlloc(uint64_t sizeBytes, CUstream /*stream*/) {
130+
ScopedContext scopedContext;
111131
CUdeviceptr ptr;
112132
CUDA_REPORT_IF_ERROR(cuMemAlloc(&ptr, sizeBytes));
113133
return reinterpret_cast<void *>(ptr);

mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp

Lines changed: 24 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -31,17 +31,33 @@
3131
llvm::errs() << "'" << #expr << "' failed with '" << name << "'\n"; \
3232
}(expr)
3333

34-
// Static initialization of HIP context for device ordinal 0.
35-
static auto InitializeCtx = [] {
34+
// Static reference to HIP primary context for device ordinal 0.
35+
static hipCtx_t Context = [] {
3636
HIP_REPORT_IF_ERROR(hipInit(/*flags=*/0));
3737
hipDevice_t device;
3838
HIP_REPORT_IF_ERROR(hipDeviceGet(&device, /*ordinal=*/0));
3939
hipCtx_t context;
40-
HIP_REPORT_IF_ERROR(hipCtxCreate(&context, /*flags=*/0, device));
41-
return 0;
40+
HIP_REPORT_IF_ERROR(hipDevicePrimaryCtxRetain(&context, device));
41+
return context;
4242
}();
4343

44+
// Sets the `Context` for the duration of the instance and restores the previous
45+
// context on destruction.
46+
class ScopedContext {
47+
public:
48+
ScopedContext() {
49+
HIP_REPORT_IF_ERROR(hipCtxGetCurrent(&previous));
50+
HIP_REPORT_IF_ERROR(hipCtxSetCurrent(Context));
51+
}
52+
53+
~ScopedContext() { HIP_REPORT_IF_ERROR(hipCtxSetCurrent(previous)); }
54+
55+
private:
56+
hipCtx_t previous;
57+
};
58+
4459
extern "C" hipModule_t mgpuModuleLoad(void *data) {
60+
ScopedContext scopedContext;
4561
hipModule_t module = nullptr;
4662
HIP_REPORT_IF_ERROR(hipModuleLoadData(&module, data));
4763
return module;
@@ -67,12 +83,14 @@ extern "C" void mgpuLaunchKernel(hipFunction_t function, intptr_t gridX,
6783
intptr_t blockZ, int32_t smem,
6884
hipStream_t stream, void **params,
6985
void **extra) {
86+
ScopedContext scopedContext;
7087
HIP_REPORT_IF_ERROR(hipModuleLaunchKernel(function, gridX, gridY, gridZ,
7188
blockX, blockY, blockZ, smem,
7289
stream, params, extra));
7390
}
7491

7592
extern "C" hipStream_t mgpuStreamCreate() {
93+
ScopedContext scopedContext;
7694
hipStream_t stream = nullptr;
7795
HIP_REPORT_IF_ERROR(hipStreamCreate(&stream));
7896
return stream;
@@ -91,6 +109,7 @@ extern "C" void mgpuStreamWaitEvent(hipStream_t stream, hipEvent_t event) {
91109
}
92110

93111
extern "C" hipEvent_t mgpuEventCreate() {
112+
ScopedContext scopedContext;
94113
hipEvent_t event = nullptr;
95114
HIP_REPORT_IF_ERROR(hipEventCreateWithFlags(&event, hipEventDisableTiming));
96115
return event;
@@ -109,6 +128,7 @@ extern "C" void mgpuEventRecord(hipEvent_t event, hipStream_t stream) {
109128
}
110129

111130
extern "C" void *mgpuMemAlloc(uint64_t sizeBytes, hipStream_t /*stream*/) {
131+
ScopedContext scopedContext;
112132
void *ptr;
113133
HIP_REPORT_IF_ERROR(hipMalloc(&ptr, sizeBytes));
114134
return ptr;

0 commit comments

Comments
 (0)