
Commit c0fb4a1

Force device synchronization before CUDA module loading
This commit forces device synchronization before loading the CUDA module. Synchronizing at this point helps prevent a subtle issue that can occur with multi-threading and autotuning. The issue appeared in the PyTorch integration during the backward phase only, but may appear with other frameworks too.

The issue comes from the way PyTorch switches CPU threads to compute the backward pass without immediately initializing the CUDA context. In such situations the tuner may kick in, and cuModuleLoadDataEx gets called on a CPU thread on which the CUDA context was never initialized, resulting in a hard, unrecoverable error.

Forcing synchronization calls a CUDA runtime API function (cudaDeviceSynchronize()), which has the side effect of initializing the CUDA context. Granted, the implicit nature of this is not ideal, but it is a CUDA-ism; in the same way, the PyTorch-ism of switching threads without initializing the CUDA context requires lazy, on-demand initialization. Putting this initialization inside cuda_rtc.cc is future-proof and will not require further changes if the problem reappears with other frameworks.
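For context, here is a minimal standalone sketch of the failure mode and the workaround. It is illustrative only and not part of this commit; the helper name ensureContextOnThisThread and the use of cudaFree(0) to initialize the runtime on the main thread are assumptions made for the example. The CUDA driver API keeps a per-thread current context, so a freshly spawned CPU thread has no context bound until some runtime API call lazily creates and binds one; driver API calls such as cuModuleLoadDataEx made from that thread before that point fail.

// Illustrative sketch (not from the commit): a newly spawned CPU thread has no
// current CUDA driver context until a runtime API call lazily binds one.
#include <cuda.h>
#include <cuda_runtime.h>
#include <cstdio>
#include <thread>

static void ensureContextOnThisThread() {
  CUcontext ctx = nullptr;
  cuCtxGetCurrent(&ctx);  // leaves ctx as NULL if no context is current on this thread
  if (ctx == nullptr) {
    // Any CUDA runtime API call initializes the primary context and binds it
    // to the calling thread as a side effect.
    cudaDeviceSynchronize();
  }
}

int main() {
  cudaFree(0);  // initialize the CUDA runtime / primary context on the main thread

  std::thread backwardLikeThread([] {
    CUcontext before = nullptr;
    cuCtxGetCurrent(&before);  // typically NULL: this thread never touched CUDA
    ensureContextOnThisThread();
    CUcontext after = nullptr;
    cuCtxGetCurrent(&after);   // now non-NULL; driver calls such as cuModuleLoadDataEx are safe
    std::printf("context before: %p, after: %p\n", (void*)before, (void*)after);
  });
  backwardLikeThread.join();
  return 0;
}

In the patch below, the same check is wrapped in checkOrCreateContext(), which caches the result in a thread_local flag so the check runs at most once per CPU thread.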
1 parent 35e08bb commit c0fb4a1

tc/core/cuda/cuda_rtc.cc

Lines changed: 19 additions & 0 deletions
@@ -48,6 +48,18 @@ void CudaRTCFunction::clear() {
   }
 }
 
+void checkOrCreateContext() {
+  static thread_local bool created = false;
+  if (!created) {
+    created = true;
+    CUcontext ctx;
+    TC_CUDA_DRIVERAPI_ENFORCE(cuCtxGetCurrent(&ctx));
+    if (!ctx) {
+      TC_CUDA_RUNTIMEAPI_ENFORCE(cudaDeviceSynchronize());
+    }
+  }
+}
+
 std::unique_ptr<CudaRTCFunction> CudaRTCFunction::Compile(
     const std::string& name,
     const std::string& source) {
@@ -143,6 +155,13 @@ Duration CudaRTCFunction::Launch(
   if (perGpuModule_.count(dev) == 0) {
     CUmodule module;
     CUfunction function;
+    // Checking that a CUDA context exists for the current thread is necessary
+    // when benchmarking the backward of a PyTorch gradient operator:
+    // the backward is called on a different thread whose context may not have
+    // been initialized explicitly.
+    // This call to cudaDeviceSynchronize implicitly creates a new context if
+    // one is not bound to the current CPU.
+    checkOrCreateContext();
     TC_CUDA_DRIVERAPI_ENFORCE(
         cuModuleLoadDataEx(&module, nvrtc_ptx.data(), 0, 0, 0));
     perGpuModule_.emplace(dev, module);
