
Commit ae8917f

enh: Enrich nvutils (#72) (#3336)
Co-authored-by: qianduoduo0904 <109654808+qianduoduo0904@users.noreply.github.com>
1 parent: 277bae7

2 files changed: 332 additions & 6 deletions

mars/lib/nvutils.py

Lines changed: 294 additions & 6 deletions
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import dataclasses
 import logging
 import os
 import subprocess
@@ -31,6 +32,7 @@
     POINTER,
     CDLL,
 )
+from typing import List, Tuple, Optional, Union
 
 from ..utils import parse_readable_size
 
@@ -49,9 +51,37 @@
 
 # nvml constants
 NVML_SUCCESS = 0
+NVML_ERROR_UNINITIALIZED = 1
+NVML_ERROR_INVALID_ARGUMENT = 2
+NVML_ERROR_NOT_SUPPORTED = 3
+NVML_ERROR_NO_PERMISSION = 4
+NVML_ERROR_ALREADY_INITIALIZED = 5
+NVML_ERROR_NOT_FOUND = 6
+NVML_ERROR_INSUFFICIENT_SIZE = 7
+NVML_ERROR_INSUFFICIENT_POWER = 8
+NVML_ERROR_DRIVER_NOT_LOADED = 9
+NVML_ERROR_TIMEOUT = 10
+NVML_ERROR_IRQ_ISSUE = 11
+NVML_ERROR_LIBRARY_NOT_FOUND = 12
+NVML_ERROR_FUNCTION_NOT_FOUND = 13
+NVML_ERROR_CORRUPTED_INFOROM = 14
+NVML_ERROR_GPU_IS_LOST = 15
+NVML_ERROR_RESET_REQUIRED = 16
+NVML_ERROR_OPERATING_SYSTEM = 17
+NVML_ERROR_LIB_RM_VERSION_MISMATCH = 18
+NVML_ERROR_IN_USE = 19
+NVML_ERROR_MEMORY = 20
+NVML_ERROR_NO_DATA = 21
+NVML_ERROR_VGPU_ECC_NOT_SUPPORTED = 22
+NVML_ERROR_INSUFFICIENT_RESOURCES = 23
+NVML_ERROR_FREQ_NOT_SUPPORTED = 24
+NVML_ERROR_UNKNOWN = 999
 NVML_TEMPERATURE_GPU = 0
-
 NVML_DRIVER_NOT_LOADED = 9
+NVML_DEVICE_UUID_V2_BUFFER_SIZE = 96
+NVML_VALUE_NOT_AVAILABLE_ulonglong = c_ulonglong(-1)
+NVML_DEVICE_MIG_DISABLE = 0x0
+NVML_DEVICE_MIG_ENABLE = 0x1
 
 
 class _CUuuid_t(Structure):
@@ -80,6 +110,52 @@ class _nvmlBAR1Memory_t(Structure):
     ]
 
 
+class _nvmlProcessInfo_t(Structure):
+    _fields_ = [
+        ("pid", c_uint),
+        ("usedGpuMemory", c_ulonglong),
+        ("gpuInstanceId", c_uint),
+        ("computeInstanceId", c_uint),
+    ]
+
+
+## Alternative object
+# Allows the object to be printed
+# Allows mismatched types to be assigned
+# - like None when the Structure variant requires c_uint
+class nvmlFriendlyObject:
+    def __init__(self, dictionary):
+        for x in dictionary:
+            setattr(self, x, dictionary[x])
+
+    def __str__(self):
+        return self.__dict__.__str__()
+
+
+def nvmlStructToFriendlyObject(struct):
+    d = {}
+    for x in struct._fields_:
+        key = x[0]
+        value = getattr(struct, key)
+        # only need to convert from bytes if bytes, no need to check python version.
+        d[key] = value.decode() if isinstance(value, bytes) else value
+    obj = nvmlFriendlyObject(d)
+    return obj
+
+
+@dataclasses.dataclass
+class CudaDeviceInfo:
+    uuid: bytes = None
+    device_index: int = None
+    mig_index: int = None
+
+
+@dataclasses.dataclass
+class CudaContext:
+    has_context: bool
+    device_info: CudaDeviceInfo = None
+
+
 _is_windows: bool = sys.platform.startswith("win")
 _is_wsl: bool = "WSL_DISTRO_NAME" in os.environ
 
@@ -247,7 +323,7 @@ def _init():
     _init_pid = os.getpid()
 
 
-def get_device_count():
+def get_device_count() -> int:
     global _gpu_count
 
     if _gpu_count is not None:
@@ -259,7 +335,7 @@ def get_device_count():
 
     if "CUDA_VISIBLE_DEVICES" in os.environ:
         devices = os.environ["CUDA_VISIBLE_DEVICES"].strip()
-        if not devices:
+        if not devices or devices == "-1":
            _gpu_count = 0
        else:
            _gpu_count = len(devices.split(","))
@@ -270,7 +346,17 @@
     return _gpu_count
 
 
-def get_driver_info():
+def _get_all_device_count() -> int:
+    _init_nvml()
+    if _nvml_lib is None:
+        return None
+
+    n_gpus = c_uint()
+    _cu_check_error(_nvml_lib.nvmlDeviceGetCount(byref(n_gpus)))
+    return n_gpus.value
+
+
+def get_driver_info() -> _nvml_driver_info:
     global _driver_info
 
     _init_nvml()
@@ -294,7 +380,7 @@ def get_driver_info():
     return _driver_info
 
 
-def get_device_info(dev_index):
+def get_device_info(dev_index: int) -> _cu_device_info:
     try:
         return _device_infos[dev_index]
     except KeyError:
@@ -350,7 +436,7 @@ def get_device_info(dev_index):
     return info
 
 
-def get_device_status(dev_index):
+def get_device_status(dev_index: int) -> _nvml_device_status:
     _init()
     if _init_pid is None:
         return None
@@ -424,3 +510,205 @@ def get_device_status(dev_index):
         fb_free_mem=fb_free_mem,
         fb_used_mem=fb_used_mem,
     )
+
+
+def get_handle_by_index(index: int) -> _nvmlDevice_t:
+    _init_nvml()
+    if _nvml_lib is None:
+        return None
+
+    c_index = c_int(index)
+    device = _nvmlDevice_t()
+    _nvml_check_error(_nvml_lib.nvmlDeviceGetHandleByIndex_v2(c_index, byref(device)))
+    return device
+
+
+def get_handle_by_uuid(uuid: bytes) -> _nvmlDevice_t:
+    _init_nvml()
+    if _nvml_lib is None:
+        return None
+
+    c_uuid = c_char_p(uuid)
+    device = _nvmlDevice_t()
+    _nvml_check_error(_nvml_lib.nvmlDeviceGetHandleByUUID(c_uuid, byref(device)))
+    return device
+
+
+def get_mig_mode(device: _nvmlDevice_t) -> Tuple[int, int]:
+    _init_nvml()
+    if _nvml_lib is None:
+        return None
+
+    c_current_mode, c_pending_mode = c_uint(), c_uint()
+    _nvml_check_error(
+        _nvml_lib.nvmlDeviceGetMigMode(
+            device, byref(c_current_mode), byref(c_pending_mode)
+        )
+    )
+    return c_current_mode.value, c_pending_mode.value
+
+
+def get_max_mig_device_count(device: _nvmlDevice_t) -> int:
+    _init_nvml()
+    if _nvml_lib is None:
+        return None
+
+    c_count = c_uint()
+    _nvml_check_error(_nvml_lib.nvmlDeviceGetMaxMigDeviceCount(device, byref(c_count)))
+    return c_count.value
+
+
+def get_mig_device_handle_by_index(device: _nvmlDevice_t, index: int) -> _nvmlDevice_t:
+    _init_nvml()
+    if _nvml_lib is None:
+        return None
+
+    c_index = c_uint(index)
+    mig_device = _nvmlDevice_t()
+    _nvml_check_error(
+        _nvml_lib.nvmlDeviceGetMigDeviceHandleByIndex(
+            device, c_index, byref(mig_device)
+        )
+    )
+    return mig_device
+
+
+def get_index(handle: _nvmlDevice_t) -> int:
+    _init_nvml()
+    if _nvml_lib is None:
+        return None
+
+    c_index = c_uint()
+    _nvml_check_error(_nvml_lib.nvmlDeviceGetIndex(handle, byref(c_index)))
+    return c_index.value
+
+
+def get_uuid(handle: _nvmlDevice_t) -> bytes:
+    _init_nvml()
+    if _nvml_lib is None:
+        return None
+
+    c_uuid = create_string_buffer(NVML_DEVICE_UUID_V2_BUFFER_SIZE)
+    _nvml_check_error(
+        _nvml_lib.nvmlDeviceGetUUID(
+            handle, c_uuid, c_uint(NVML_DEVICE_UUID_V2_BUFFER_SIZE)
+        )
+    )
+    return c_uuid.value
+
+
+def get_index_and_uuid(device: Union[int, bytes, str]) -> CudaDeviceInfo:
+    _init_nvml()
+    if _nvml_lib is None:
+        return None
+
+    try:
+        device_index = int(device)
+        device_handle = get_handle_by_index(device_index)
+        uuid = get_uuid(device_handle)
+    except ValueError:
+        uuid = device if isinstance(device, bytes) else device.encode()
+        uuid_handle = get_handle_by_uuid(uuid)
+        device_index = get_index(uuid_handle)
+        uuid = get_uuid(uuid_handle)
+
+    return CudaDeviceInfo(uuid=uuid, device_index=device_index)
+
+
+def get_compute_running_processes(handle: _nvmlDevice_t) -> List[nvmlFriendlyObject]:
+    _init_nvml()
+    if _nvml_lib is None:
+        return None
+
+    c_count = c_uint(0)
+    func = getattr(_nvml_lib, "nvmlDeviceGetComputeRunningProcesses_v3", None)
+    if func is None:
+        func = getattr(_nvml_lib, "nvmlDeviceGetComputeRunningProcesses_v2")
+    ret = func(handle, byref(c_count), None)
+
+    if ret == NVML_SUCCESS:
+        # special case, no running processes
+        return []
+    elif ret == NVML_ERROR_INSUFFICIENT_SIZE:
+        # typical case
+        # oversize the array incase more processes are created
+        c_count.value = c_count.value * 2 + 5
+        proc_array = _nvmlProcessInfo_t * c_count.value
+        c_procs = proc_array()
+
+        _nvml_check_error(func(handle, byref(c_count), c_procs))
+
+        procs = []
+        for i in range(c_count.value):
+            # use an alternative struct for this object
+            obj = nvmlStructToFriendlyObject(c_procs[i])
+            if obj.usedGpuMemory == NVML_VALUE_NOT_AVAILABLE_ulonglong.value:
+                # special case for WDDM on Windows, see comment above
+                obj.usedGpuMemory = None
+            procs.append(obj)
+
+        return procs
+    else:
+        # error case
+        _nvml_check_error(ret)
+
+
+def _running_process_matches(handle: _nvmlDevice_t) -> bool:
+    """Check whether the current process is same as that of handle
+    Parameters
+    ----------
+    handle : _nvmlDevice_t
+        NVML handle to CUDA device
+    Returns
+    -------
+    out : bool
+        Whether the device handle has a CUDA context on the running process.
+    """
+    return any(os.getpid() == o.pid for o in get_compute_running_processes(handle))
+
+
+def get_cuda_context() -> CudaContext:
+    """Check whether the current process already has a CUDA context created."""
+
+    _init()
+    if _init_pid is None:
+        return CudaContext(has_context=False)
+
+    for index in range(_get_all_device_count()):
+        handle = get_handle_by_index(index)
+        try:
+            mig_current_mode, mig_pending_mode = get_mig_mode(handle)
+        except NVMLAPIError as e:
+            if e.errno == NVML_ERROR_NOT_SUPPORTED:
+                mig_current_mode = NVML_DEVICE_MIG_DISABLE
+            else:
+                raise
+        if mig_current_mode == NVML_DEVICE_MIG_ENABLE:
+            for mig_index in range(get_max_mig_device_count(handle)):
+                try:
+                    mig_handle = get_mig_device_handle_by_index(handle, mig_index)
+                except NVMLAPIError as e:
+                    if e.errno == NVML_ERROR_NOT_FOUND:
+                        # No MIG device with that index
+                        continue
+                    else:
+                        raise
+                if _running_process_matches(mig_handle):
+                    return CudaContext(
+                        has_context=True,
+                        device_info=CudaDeviceInfo(
+                            uuid=get_uuid(handle),
+                            device_index=index,
+                            mig_index=mig_index,
+                        ),
+                    )
+        else:
+            if _running_process_matches(handle):
+                return CudaContext(
+                    has_context=True,
+                    device_info=CudaDeviceInfo(
+                        uuid=get_uuid(handle), device_index=index
+                    ),
+                )
+
+    return CudaContext(has_context=False)
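
Usage note (not part of the commit): the sketch below shows one way the helpers added in this diff could be exercised, assuming Mars is installed so that mars.lib.nvutils is importable. On hosts without an NVML library most helpers return None and get_cuda_context() reports has_context=False, so every result is checked before use.

# Hedged usage sketch for the new nvutils helpers; not part of the diff above.
from mars.lib import nvutils


def report_gpu_state() -> None:
    # Count of CUDA devices visible to this process (honors CUDA_VISIBLE_DEVICES).
    count = nvutils.get_device_count()
    print(f"visible CUDA devices: {count}")
    if not count:
        return

    # Resolve device 0 to its UUID; get_index_and_uuid accepts an index,
    # a bytes UUID or a str UUID.
    info = nvutils.get_index_and_uuid(0)
    if info is not None:
        print(f"device {info.device_index} uuid={info.uuid}")

    # List compute processes currently running on device 0.
    handle = nvutils.get_handle_by_index(0)
    if handle is not None:
        for proc in nvutils.get_compute_running_processes(handle) or []:
            print(f"pid={proc.pid} usedGpuMemory={proc.usedGpuMemory}")

    # Check whether this process already owns a CUDA context (MIG-aware).
    ctx = nvutils.get_cuda_context()
    if ctx.has_context and ctx.device_info is not None:
        print(f"CUDA context on device {ctx.device_info.device_index}")
    else:
        print("no CUDA context in this process")


if __name__ == "__main__":
    report_gpu_state()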
