13
13
# See the License for the specific language governing permissions and
14
14
# limitations under the License.
15
15
16
+ import dataclasses
16
17
import logging
17
18
import os
18
19
import subprocess
31
32
POINTER ,
32
33
CDLL ,
33
34
)
35
+ from typing import List , Tuple , Optional , Union
34
36
35
37
from ..utils import parse_readable_size
36
38
49
51
50
52
# nvml constants

# Status codes returned by NVML calls.
NVML_SUCCESS = 0
NVML_ERROR_UNINITIALIZED = 1
NVML_ERROR_INVALID_ARGUMENT = 2
NVML_ERROR_NOT_SUPPORTED = 3
NVML_ERROR_NO_PERMISSION = 4
NVML_ERROR_ALREADY_INITIALIZED = 5
NVML_ERROR_NOT_FOUND = 6
NVML_ERROR_INSUFFICIENT_SIZE = 7
NVML_ERROR_INSUFFICIENT_POWER = 8
NVML_ERROR_DRIVER_NOT_LOADED = 9
NVML_ERROR_TIMEOUT = 10
NVML_ERROR_IRQ_ISSUE = 11
NVML_ERROR_LIBRARY_NOT_FOUND = 12
NVML_ERROR_FUNCTION_NOT_FOUND = 13
NVML_ERROR_CORRUPTED_INFOROM = 14
NVML_ERROR_GPU_IS_LOST = 15
NVML_ERROR_RESET_REQUIRED = 16
NVML_ERROR_OPERATING_SYSTEM = 17
NVML_ERROR_LIB_RM_VERSION_MISMATCH = 18
NVML_ERROR_IN_USE = 19
NVML_ERROR_MEMORY = 20
NVML_ERROR_NO_DATA = 21
NVML_ERROR_VGPU_ECC_NOT_SUPPORTED = 22
NVML_ERROR_INSUFFICIENT_RESOURCES = 23
NVML_ERROR_FREQ_NOT_SUPPORTED = 24
NVML_ERROR_UNKNOWN = 999

# Miscellaneous NVML values used by the wrappers in this module.
NVML_TEMPERATURE_GPU = 0
# Older name kept for existing users; same value as the NVML_ERROR_* spelling.
NVML_DRIVER_NOT_LOADED = NVML_ERROR_DRIVER_NOT_LOADED
# Size of the character buffer handed to nvmlDeviceGetUUID.
NVML_DEVICE_UUID_V2_BUFFER_SIZE = 96
# Sentinel: an unsigned 64-bit field with every bit set.
NVML_VALUE_NOT_AVAILABLE_ulonglong = c_ulonglong(-1)

# MIG mode values as returned by nvmlDeviceGetMigMode.
NVML_DEVICE_MIG_DISABLE = 0x0
NVML_DEVICE_MIG_ENABLE = 0x1
55
85
56
86
57
87
class _CUuuid_t (Structure ):
@@ -80,6 +110,52 @@ class _nvmlBAR1Memory_t(Structure):
80
110
]
81
111
82
112
113
+ class _nvmlProcessInfo_t (Structure ):
114
+ _fields_ = [
115
+ ("pid" , c_uint ),
116
+ ("usedGpuMemory" , c_ulonglong ),
117
+ ("gpuInstanceId" , c_uint ),
118
+ ("computeInstanceId" , c_uint ),
119
+ ]
120
+
121
+
122
class nvmlFriendlyObject:
    """Plain attribute holder used as an alternative to a ctypes Structure.

    Compared with the Structure variant, this object can be printed and its
    attributes accept values of any type -- for example ``None`` where the
    Structure field is declared as ``c_uint``.
    """

    def __init__(self, dictionary):
        # Expose every key of the mapping as an instance attribute.
        for name in dictionary:
            setattr(self, name, dictionary[name])

    def __str__(self):
        return str(self.__dict__)
133
+
134
+
135
def nvmlStructToFriendlyObject(struct):
    """Copy the fields of a ctypes structure into an ``nvmlFriendlyObject``.

    Every name listed in ``struct._fields_`` becomes an attribute of the
    result. ``bytes`` values are decoded to ``str``; all other values are
    copied unchanged.
    """

    def _as_text(raw):
        # Only bytes need converting; everything else passes through.
        return raw.decode() if isinstance(raw, bytes) else raw

    attributes = {
        spec[0]: _as_text(getattr(struct, spec[0])) for spec in struct._fields_
    }
    return nvmlFriendlyObject(attributes)
144
+
145
+
146
@dataclasses.dataclass
class CudaDeviceInfo:
    """Identifies a CUDA device and, optionally, a MIG instance on it.

    Every field defaults to ``None``, so the annotations are ``Optional``.
    """

    # Device UUID as bytes, or None when not known.
    uuid: Optional[bytes] = None
    # Index of the device, or None when not known.
    device_index: Optional[int] = None
    # Index of the MIG instance on the device; None when not applicable.
    mig_index: Optional[int] = None
151
+
152
+
153
@dataclasses.dataclass
class CudaContext:
    """Result of checking whether the current process has a CUDA context."""

    # True when a context was found for the current process.
    has_context: bool
    # Device the context lives on; None when no context was found.
    device_info: Optional[CudaDeviceInfo] = None
157
+
158
+
83
159
_is_windows : bool = sys .platform .startswith ("win" )
84
160
_is_wsl : bool = "WSL_DISTRO_NAME" in os .environ
85
161
@@ -247,7 +323,7 @@ def _init():
247
323
_init_pid = os .getpid ()
248
324
249
325
250
- def get_device_count ():
326
+ def get_device_count () -> int :
251
327
global _gpu_count
252
328
253
329
if _gpu_count is not None :
@@ -259,7 +335,7 @@ def get_device_count():
259
335
260
336
if "CUDA_VISIBLE_DEVICES" in os .environ :
261
337
devices = os .environ ["CUDA_VISIBLE_DEVICES" ].strip ()
262
- if not devices :
338
+ if not devices or devices == "-1" :
263
339
_gpu_count = 0
264
340
else :
265
341
_gpu_count = len (devices .split ("," ))
@@ -270,7 +346,17 @@ def get_device_count():
270
346
return _gpu_count
271
347
272
348
273
- def get_driver_info ():
349
def _get_all_device_count() -> Optional[int]:
    """Return the number of devices reported by NVML.

    The count comes straight from ``nvmlDeviceGetCount``; this function does
    not look at the ``CUDA_VISIBLE_DEVICES`` environment variable.

    Returns
    -------
    out : int or None
        The device count, or ``None`` when the NVML library is unavailable.
    """
    _init_nvml()
    if _nvml_lib is None:
        return None

    n_gpus = c_uint()
    # nvmlDeviceGetCount is an NVML entry point, so its status code goes
    # through the NVML checker, like every other NVML wrapper in this module.
    _nvml_check_error(_nvml_lib.nvmlDeviceGetCount(byref(n_gpus)))
    return n_gpus.value
357
+
358
+
359
+ def get_driver_info () -> _nvml_driver_info :
274
360
global _driver_info
275
361
276
362
_init_nvml ()
@@ -294,7 +380,7 @@ def get_driver_info():
294
380
return _driver_info
295
381
296
382
297
- def get_device_info (dev_index ) :
383
+ def get_device_info (dev_index : int ) -> _cu_device_info :
298
384
try :
299
385
return _device_infos [dev_index ]
300
386
except KeyError :
@@ -350,7 +436,7 @@ def get_device_info(dev_index):
350
436
return info
351
437
352
438
353
- def get_device_status (dev_index ) :
439
+ def get_device_status (dev_index : int ) -> _nvml_device_status :
354
440
_init ()
355
441
if _init_pid is None :
356
442
return None
@@ -424,3 +510,205 @@ def get_device_status(dev_index):
424
510
fb_free_mem = fb_free_mem ,
425
511
fb_used_mem = fb_used_mem ,
426
512
)
513
+
514
+
515
def get_handle_by_index(index: int) -> _nvmlDevice_t:
    """Look up the NVML device handle for the device at ``index``.

    Returns ``None`` when the NVML library is unavailable. The status code
    of the NVML call is passed to ``_nvml_check_error``.
    """
    _init_nvml()
    if _nvml_lib is None:
        return None

    position = c_int(index)
    handle = _nvmlDevice_t()
    status = _nvml_lib.nvmlDeviceGetHandleByIndex_v2(position, byref(handle))
    _nvml_check_error(status)
    return handle
524
+
525
+
526
def get_handle_by_uuid(uuid: bytes) -> _nvmlDevice_t:
    """Look up the NVML device handle for the device identified by ``uuid``.

    Returns ``None`` when the NVML library is unavailable. The status code
    of the NVML call is passed to ``_nvml_check_error``.
    """
    _init_nvml()
    if _nvml_lib is None:
        return None

    uuid_arg = c_char_p(uuid)
    handle = _nvmlDevice_t()
    status = _nvml_lib.nvmlDeviceGetHandleByUUID(uuid_arg, byref(handle))
    _nvml_check_error(status)
    return handle
535
+
536
+
537
def get_mig_mode(device: _nvmlDevice_t) -> Tuple[int, int]:
    """Query the MIG mode of ``device``.

    Returns
    -------
    out : tuple of (int, int) or None
        ``(current_mode, pending_mode)`` as filled in by
        ``nvmlDeviceGetMigMode``, or ``None`` when the NVML library is
        unavailable.
    """
    _init_nvml()
    if _nvml_lib is None:
        return None

    current = c_uint()
    pending = c_uint()
    status = _nvml_lib.nvmlDeviceGetMigMode(device, byref(current), byref(pending))
    _nvml_check_error(status)
    return current.value, pending.value
549
+
550
+
551
def get_max_mig_device_count(device: _nvmlDevice_t) -> int:
    """Return the value reported by ``nvmlDeviceGetMaxMigDeviceCount``.

    Returns ``None`` when the NVML library is unavailable.
    """
    _init_nvml()
    if _nvml_lib is None:
        return None

    limit = c_uint()
    status = _nvml_lib.nvmlDeviceGetMaxMigDeviceCount(device, byref(limit))
    _nvml_check_error(status)
    return limit.value
559
+
560
+
561
def get_mig_device_handle_by_index(device: _nvmlDevice_t, index: int) -> _nvmlDevice_t:
    """Return the handle of the MIG device at ``index`` on ``device``.

    Returns ``None`` when the NVML library is unavailable. The status code
    of the NVML call is passed to ``_nvml_check_error``.
    """
    _init_nvml()
    if _nvml_lib is None:
        return None

    slot = c_uint(index)
    mig_handle = _nvmlDevice_t()
    status = _nvml_lib.nvmlDeviceGetMigDeviceHandleByIndex(
        device, slot, byref(mig_handle)
    )
    _nvml_check_error(status)
    return mig_handle
574
+
575
+
576
def get_index(handle: _nvmlDevice_t) -> int:
    """Return the device index NVML reports for ``handle``.

    Returns ``None`` when the NVML library is unavailable.
    """
    _init_nvml()
    if _nvml_lib is None:
        return None

    position = c_uint()
    status = _nvml_lib.nvmlDeviceGetIndex(handle, byref(position))
    _nvml_check_error(status)
    return position.value
584
+
585
+
586
def get_uuid(handle: _nvmlDevice_t) -> bytes:
    """Return the UUID NVML reports for ``handle`` as bytes.

    Returns ``None`` when the NVML library is unavailable.
    """
    _init_nvml()
    if _nvml_lib is None:
        return None

    # NVML writes the UUID into a caller-provided character buffer.
    size = NVML_DEVICE_UUID_V2_BUFFER_SIZE
    buffer = create_string_buffer(size)
    status = _nvml_lib.nvmlDeviceGetUUID(handle, buffer, c_uint(size))
    _nvml_check_error(status)
    return buffer.value
598
+
599
+
600
def get_index_and_uuid(device: Union[int, bytes, str]) -> CudaDeviceInfo:
    """Resolve ``device`` to both its index and its UUID.

    Parameters
    ----------
    device : int, bytes or str
        Anything ``int()`` accepts is treated as a device index. When
        ``int()`` raises ``ValueError`` the value is treated as a UUID.

    Returns
    -------
    out : CudaDeviceInfo or None
        Info with ``uuid`` and ``device_index`` set, or ``None`` when the
        NVML library is unavailable.
    """
    _init_nvml()
    if _nvml_lib is None:
        return None

    try:
        device_index = int(device)
        uuid = get_uuid(get_handle_by_index(device_index))
    except ValueError:
        # Not an integer: treat the argument as a UUID and ask NVML for
        # the matching index and the UUID it reports for that handle.
        raw_uuid = device if isinstance(device, bytes) else device.encode()
        handle = get_handle_by_uuid(raw_uuid)
        device_index = get_index(handle)
        uuid = get_uuid(handle)

    return CudaDeviceInfo(uuid=uuid, device_index=device_index)
616
+
617
+
618
def get_compute_running_processes(handle: _nvmlDevice_t) -> List[nvmlFriendlyObject]:
    """List the compute processes NVML reports for ``handle``.

    Returns
    -------
    out : list of nvmlFriendlyObject or None
        One object per process, with ``usedGpuMemory`` set to ``None`` when
        NVML reports the "not available" sentinel. ``None`` is returned when
        the NVML library is unavailable.
    """
    _init_nvml()
    if _nvml_lib is None:
        return None

    count = c_uint(0)
    # Prefer the v3 entry point and fall back to v2 when it is missing.
    query = getattr(_nvml_lib, "nvmlDeviceGetComputeRunningProcesses_v3", None)
    if query is None:
        query = getattr(_nvml_lib, "nvmlDeviceGetComputeRunningProcesses_v2")

    # First call with no buffer: NVML reports how many entries are needed.
    status = query(handle, byref(count), None)

    if status == NVML_SUCCESS:
        # special case, no running processes
        return []
    if status != NVML_ERROR_INSUFFICIENT_SIZE:
        # error case
        _nvml_check_error(status)
        return None

    # typical case
    # oversize the array in case more processes are created
    count.value = count.value * 2 + 5
    entries = (_nvmlProcessInfo_t * count.value)()

    _nvml_check_error(query(handle, byref(count), entries))

    not_available = NVML_VALUE_NOT_AVAILABLE_ulonglong.value
    processes = []
    for slot in range(count.value):
        # use an alternative struct for this object
        info = nvmlStructToFriendlyObject(entries[slot])
        if info.usedGpuMemory == not_available:
            # special case for WDDM on Windows: no memory usage is reported
            info.usedGpuMemory = None
        processes.append(info)

    return processes
654
+
655
+
656
def _running_process_matches(handle: _nvmlDevice_t) -> bool:
    """Check whether the current process is among those running on ``handle``.

    Parameters
    ----------
    handle : _nvmlDevice_t
        NVML handle to CUDA device

    Returns
    -------
    out : bool
        Whether the device handle has a CUDA context on the running process.
    """
    running = get_compute_running_processes(handle)
    for proc in running:
        if os.getpid() == proc.pid:
            return True
    return False
668
+
669
+
670
def get_cuda_context() -> CudaContext:
    """Check whether the current process already has a CUDA context created.

    Every device NVML reports is inspected. For a device with MIG enabled,
    each MIG instance is checked; otherwise the device itself is checked.

    Returns
    -------
    out : CudaContext
        ``has_context=True`` with the matching device info for the first
        match, otherwise ``has_context=False``.
    """

    _init()
    if _init_pid is None:
        return CudaContext(has_context=False)

    for index in range(_get_all_device_count()):
        handle = get_handle_by_index(index)

        try:
            mig_mode, _pending_mode = get_mig_mode(handle)
        except NVMLAPIError as err:
            if err.errno != NVML_ERROR_NOT_SUPPORTED:
                raise
            # A device that does not support the query counts as MIG-disabled.
            mig_mode = NVML_DEVICE_MIG_DISABLE

        if mig_mode != NVML_DEVICE_MIG_ENABLE:
            if _running_process_matches(handle):
                return CudaContext(
                    has_context=True,
                    device_info=CudaDeviceInfo(
                        uuid=get_uuid(handle), device_index=index
                    ),
                )
            continue

        for mig_index in range(get_max_mig_device_count(handle)):
            try:
                mig_handle = get_mig_device_handle_by_index(handle, mig_index)
            except NVMLAPIError as err:
                if err.errno != NVML_ERROR_NOT_FOUND:
                    raise
                # No MIG device with that index
                continue

            if _running_process_matches(mig_handle):
                # The UUID reported is that of the parent device handle.
                return CudaContext(
                    has_context=True,
                    device_info=CudaDeviceInfo(
                        uuid=get_uuid(handle),
                        device_index=index,
                        mig_index=mig_index,
                    ),
                )

    return CudaContext(has_context=False)
0 commit comments