32
32
llvm::errs () << " '" << #expr << " ' failed with '" << name << " '\n " ; \
33
33
}(expr)
34
34
35
- // Static initialization of CUDA context for device ordinal 0.
36
- static auto InitializeCtx = [] {
35
+ // Static reference to CUDA primary context for device ordinal 0.
36
+ static CUcontext Context = [] {
37
37
CUDA_REPORT_IF_ERROR (cuInit (/* flags=*/ 0 ));
38
38
CUdevice device;
39
39
CUDA_REPORT_IF_ERROR (cuDeviceGet (&device, /* ordinal=*/ 0 ));
40
40
CUcontext context;
41
- CUDA_REPORT_IF_ERROR (cuCtxCreate (&context, /* flags= */ 0 , device));
42
- return 0 ;
41
+ CUDA_REPORT_IF_ERROR (cuDevicePrimaryCtxRetain (&context, device));
42
+ return context ;
43
43
}();
44
44
45
+ // Sets the `Context` for the duration of the instance and restores the previous
46
+ // context on destruction.
47
+ class ScopedContext {
48
+ public:
49
+ ScopedContext () {
50
+ CUDA_REPORT_IF_ERROR (cuCtxGetCurrent (&previous));
51
+ CUDA_REPORT_IF_ERROR (cuCtxSetCurrent (Context));
52
+ }
53
+
54
+ ~ScopedContext () { CUDA_REPORT_IF_ERROR (cuCtxSetCurrent (previous)); }
55
+
56
+ private:
57
+ CUcontext previous;
58
+ };
59
+
45
60
extern " C" CUmodule mgpuModuleLoad (void *data) {
61
+ ScopedContext scopedContext;
46
62
CUmodule module = nullptr ;
47
63
CUDA_REPORT_IF_ERROR (cuModuleLoadData (&module , data));
48
64
return module ;
@@ -66,12 +82,14 @@ extern "C" void mgpuLaunchKernel(CUfunction function, intptr_t gridX,
66
82
intptr_t blockX, intptr_t blockY,
67
83
intptr_t blockZ, int32_t smem, CUstream stream,
68
84
void **params, void **extra) {
85
+ ScopedContext scopedContext;
69
86
CUDA_REPORT_IF_ERROR (cuLaunchKernel (function, gridX, gridY, gridZ, blockX,
70
87
blockY, blockZ, smem, stream, params,
71
88
extra));
72
89
}
73
90
74
91
extern " C" CUstream mgpuStreamCreate () {
92
+ ScopedContext scopedContext;
75
93
CUstream stream = nullptr ;
76
94
CUDA_REPORT_IF_ERROR (cuStreamCreate (&stream, CU_STREAM_NON_BLOCKING));
77
95
return stream;
@@ -90,6 +108,7 @@ extern "C" void mgpuStreamWaitEvent(CUstream stream, CUevent event) {
90
108
}
91
109
92
110
extern " C" CUevent mgpuEventCreate () {
111
+ ScopedContext scopedContext;
93
112
CUevent event = nullptr ;
94
113
CUDA_REPORT_IF_ERROR (cuEventCreate (&event, CU_EVENT_DISABLE_TIMING));
95
114
return event;
@@ -108,6 +127,7 @@ extern "C" void mgpuEventRecord(CUevent event, CUstream stream) {
108
127
}
109
128
110
129
extern " C" void *mgpuMemAlloc (uint64_t sizeBytes, CUstream /* stream*/ ) {
130
+ ScopedContext scopedContext;
111
131
CUdeviceptr ptr;
112
132
CUDA_REPORT_IF_ERROR (cuMemAlloc (&ptr, sizeBytes));
113
133
return reinterpret_cast <void *>(ptr);
0 commit comments