1
- # Copyright (c) 2020-2024 , NVIDIA CORPORATION.
1
+ # Copyright (c) 2020-2025 , NVIDIA CORPORATION.
2
2
3
3
4
4
def validate_setup ():
@@ -15,7 +15,10 @@ def validate_setup():
15
15
16
16
import warnings
17
17
18
- from cuda .bindings .runtime import cudaDeviceAttr , cudaError_t
18
+ from cuda .bindings .runtime import (
19
+ cudaDeviceAttr ,
20
+ cudaError_t ,
21
+ )
19
22
20
23
from rmm ._cuda .gpu import (
21
24
CUDARuntimeError ,
@@ -30,7 +33,6 @@ def validate_setup():
30
33
31
34
notify_caller_errors = {
32
35
cudaError_t .cudaErrorInitializationError ,
33
- cudaError_t .cudaErrorInsufficientDriver ,
34
36
cudaError_t .cudaErrorInvalidDeviceFunction ,
35
37
cudaError_t .cudaErrorInvalidDevice ,
36
38
cudaError_t .cudaErrorStartupFailure ,
@@ -53,12 +55,23 @@ def validate_setup():
53
55
except CUDARuntimeError as e :
54
56
if e .status in notify_caller_errors :
55
57
raise e
58
+
59
+ # We must distinguish between "CPU only" and "the driver is
60
+ # insufficient for the runtime".
61
+ if e .status == cudaError_t .cudaErrorInsufficientDriver :
62
+ # cudaDriverGetVersion() returns 0 when ``libcuda.so`` is
63
+ # missing. Otherwise there is a CUDA driver but it is
64
+ # insufficient for the runtime, so we re-raise the original
65
+ # exception
66
+ if driverGetVersion () != 0 :
67
+ raise e
68
+
56
69
# If there is no GPU detected, set `gpus_count` to -1
57
70
gpus_count = - 1
58
71
except RuntimeError as e :
59
- # getDeviceCount() can raise a RuntimeError
60
- # when ``libcuda.so`` is missing.
61
- # We don't want this to propagate up to the user.
72
+ # When using cuda-python < 12.9, getDeviceCount() can raise a
73
+ # RuntimeError if ``libcuda.so`` is missing. We don't want this to
74
+ # propagate up to the user.
62
75
warnings .warn (str (e ))
63
76
return
64
77
0 commit comments