Skip to content

Commit cc790b8

Browse files
authored
Uses torch.version.cuda to compile CUDA extensions (#2163)
1 parent 43a65e2 commit cc790b8

File tree

1 file changed

+5
-7
lines changed

1 file changed

+5
-7
lines changed

setup.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -253,23 +253,21 @@ def get_extensions():
253253
if debug_mode:
254254
print("Compiling in debug mode")
255255

256-
if not torch.cuda.is_available():
256+
if not torch.version.cuda:
257257
print(
258258
"PyTorch GPU support is not available. Skipping compilation of CUDA extensions"
259259
)
260-
if CUDA_HOME is None and torch.cuda.is_available() and torch.version.cuda:
260+
if CUDA_HOME is None and torch.version.cuda:
261261
print("CUDA toolkit is not available. Skipping compilation of CUDA extensions")
262262
print(
263263
"If you'd like to compile CUDA extensions locally please install the cudatoolkit from https://anaconda.org/nvidia/cuda-toolkit"
264264
)
265-
if ROCM_HOME is None and torch.cuda.is_available() and torch.version.hip:
265+
if ROCM_HOME is None and torch.version.hip:
266266
print("ROCm is not available. Skipping compilation of ROCm extensions")
267267
print("If you'd like to compile ROCm extensions locally please install ROCm")
268268

269-
use_cuda = (
270-
torch.cuda.is_available() and torch.version.cuda and CUDA_HOME is not None
271-
)
272-
use_hip = torch.cuda.is_available() and torch.version.hip and ROCM_HOME is not None
269+
use_cuda = torch.version.cuda and CUDA_HOME is not None
270+
use_hip = torch.version.hip and ROCM_HOME is not None
273271
extension = CUDAExtension if (use_cuda or use_hip) else CppExtension
274272

275273
nvcc_args = [

0 commit comments

Comments
 (0)