31
31
sys .path .append ("../common" )
32
32
33
33
import os
34
+ import time
34
35
import unittest
36
+ from functools import partial
35
37
36
38
import infer_util as iu
37
39
import numpy as np
38
40
import test_util as tu
39
41
import tritonclient .grpc as grpcclient
40
42
import tritonclient .http as httpclient
41
- import tritonshmutils .cuda_shared_memory as cshm
43
+ import tritonclient . utils .cuda_shared_memory as cshm
42
44
from tritonclient .utils import *
43
45
44
46
45
- class CudaSharedMemoryTest (tu .TestResultCollector ):
47
+ class CudaSharedMemoryTestBase (tu .TestResultCollector ):
46
48
DEFAULT_SHM_BYTE_SIZE = 64
47
49
48
50
def setUp (self ):
@@ -61,76 +63,6 @@ def _setup_client(self):
61
63
self .url , verbose = True
62
64
)
63
65
64
- def test_invalid_create_shm (self ):
65
- # Raises error since tried to create invalid cuda shared memory region
66
- try :
67
- shm_op0_handle = cshm .create_shared_memory_region ("dummy_data" , - 1 , 0 )
68
- cshm .destroy_shared_memory_region (shm_op0_handle )
69
- except Exception as ex :
70
- self .assertEqual (str (ex ), "unable to create cuda shared memory handle" )
71
-
72
- def test_valid_create_set_register (self ):
73
- # Create a valid cuda shared memory region, fill data in it and register
74
- shm_op0_handle = cshm .create_shared_memory_region ("dummy_data" , 8 , 0 )
75
- cshm .set_shared_memory_region (
76
- shm_op0_handle , [np .array ([1 , 2 ], dtype = np .float32 )]
77
- )
78
- self .triton_client .register_cuda_shared_memory (
79
- "dummy_data" , cshm .get_raw_handle (shm_op0_handle ), 0 , 8
80
- )
81
- shm_status = self .triton_client .get_cuda_shared_memory_status ()
82
- if self .protocol == "http" :
83
- self .assertEqual (len (shm_status ), 1 )
84
- else :
85
- self .assertEqual (len (shm_status .regions ), 1 )
86
- cshm .destroy_shared_memory_region (shm_op0_handle )
87
-
88
- def test_unregister_before_register (self ):
89
- # Create a valid cuda shared memory region and unregister before register
90
- shm_op0_handle = cshm .create_shared_memory_region ("dummy_data" , 8 , 0 )
91
- self .triton_client .unregister_cuda_shared_memory ("dummy_data" )
92
- shm_status = self .triton_client .get_cuda_shared_memory_status ()
93
- if self .protocol == "http" :
94
- self .assertEqual (len (shm_status ), 0 )
95
- else :
96
- self .assertEqual (len (shm_status .regions ), 0 )
97
- cshm .destroy_shared_memory_region (shm_op0_handle )
98
-
99
- def test_unregister_after_register (self ):
100
- # Create a valid cuda shared memory region and unregister after register
101
- shm_op0_handle = cshm .create_shared_memory_region ("dummy_data" , 8 , 0 )
102
- self .triton_client .register_cuda_shared_memory (
103
- "dummy_data" , cshm .get_raw_handle (shm_op0_handle ), 0 , 8
104
- )
105
- self .triton_client .unregister_cuda_shared_memory ("dummy_data" )
106
- shm_status = self .triton_client .get_cuda_shared_memory_status ()
107
- if self .protocol == "http" :
108
- self .assertEqual (len (shm_status ), 0 )
109
- else :
110
- self .assertEqual (len (shm_status .regions ), 0 )
111
- cshm .destroy_shared_memory_region (shm_op0_handle )
112
-
113
- def test_reregister_after_register (self ):
114
- # Create a valid cuda shared memory region and unregister after register
115
- shm_op0_handle = cshm .create_shared_memory_region ("dummy_data" , 8 , 0 )
116
- self .triton_client .register_cuda_shared_memory (
117
- "dummy_data" , cshm .get_raw_handle (shm_op0_handle ), 0 , 8
118
- )
119
- try :
120
- self .triton_client .register_cuda_shared_memory (
121
- "dummy_data" , cshm .get_raw_handle (shm_op0_handle ), 0 , 8
122
- )
123
- except Exception as ex :
124
- self .assertIn (
125
- "shared memory region 'dummy_data' already in manager" , str (ex )
126
- )
127
- shm_status = self .triton_client .get_cuda_shared_memory_status ()
128
- if self .protocol == "http" :
129
- self .assertEqual (len (shm_status ), 1 )
130
- else :
131
- self .assertEqual (len (shm_status .regions ), 1 )
132
- cshm .destroy_shared_memory_region (shm_op0_handle )
133
-
134
66
def _configure_server (
135
67
self ,
136
68
create_byte_size = DEFAULT_SHM_BYTE_SIZE ,
@@ -205,6 +137,78 @@ def _cleanup_server(self, shm_handles):
205
137
for shm_handle in shm_handles :
206
138
cshm .destroy_shared_memory_region (shm_handle )
207
139
140
+
141
+ class CudaSharedMemoryTest (CudaSharedMemoryTestBase ):
142
+ def test_invalid_create_shm (self ):
143
+ # Raises error since tried to create invalid cuda shared memory region
144
+ try :
145
+ shm_op0_handle = cshm .create_shared_memory_region ("dummy_data" , - 1 , 0 )
146
+ cshm .destroy_shared_memory_region (shm_op0_handle )
147
+ except Exception as ex :
148
+ self .assertEqual (str (ex ), "unable to create cuda shared memory handle" )
149
+
150
+ def test_valid_create_set_register (self ):
151
+ # Create a valid cuda shared memory region, fill data in it and register
152
+ shm_op0_handle = cshm .create_shared_memory_region ("dummy_data" , 8 , 0 )
153
+ cshm .set_shared_memory_region (
154
+ shm_op0_handle , [np .array ([1 , 2 ], dtype = np .float32 )]
155
+ )
156
+ self .triton_client .register_cuda_shared_memory (
157
+ "dummy_data" , cshm .get_raw_handle (shm_op0_handle ), 0 , 8
158
+ )
159
+ shm_status = self .triton_client .get_cuda_shared_memory_status ()
160
+ if self .protocol == "http" :
161
+ self .assertEqual (len (shm_status ), 1 )
162
+ else :
163
+ self .assertEqual (len (shm_status .regions ), 1 )
164
+ cshm .destroy_shared_memory_region (shm_op0_handle )
165
+
166
+ def test_unregister_before_register (self ):
167
+ # Create a valid cuda shared memory region and unregister before register
168
+ shm_op0_handle = cshm .create_shared_memory_region ("dummy_data" , 8 , 0 )
169
+ self .triton_client .unregister_cuda_shared_memory ("dummy_data" )
170
+ shm_status = self .triton_client .get_cuda_shared_memory_status ()
171
+ if self .protocol == "http" :
172
+ self .assertEqual (len (shm_status ), 0 )
173
+ else :
174
+ self .assertEqual (len (shm_status .regions ), 0 )
175
+ cshm .destroy_shared_memory_region (shm_op0_handle )
176
+
177
+ def test_unregister_after_register (self ):
178
+ # Create a valid cuda shared memory region and unregister after register
179
+ shm_op0_handle = cshm .create_shared_memory_region ("dummy_data" , 8 , 0 )
180
+ self .triton_client .register_cuda_shared_memory (
181
+ "dummy_data" , cshm .get_raw_handle (shm_op0_handle ), 0 , 8
182
+ )
183
+ self .triton_client .unregister_cuda_shared_memory ("dummy_data" )
184
+ shm_status = self .triton_client .get_cuda_shared_memory_status ()
185
+ if self .protocol == "http" :
186
+ self .assertEqual (len (shm_status ), 0 )
187
+ else :
188
+ self .assertEqual (len (shm_status .regions ), 0 )
189
+ cshm .destroy_shared_memory_region (shm_op0_handle )
190
+
191
+ def test_reregister_after_register (self ):
192
+ # Create a valid cuda shared memory region and unregister after register
193
+ shm_op0_handle = cshm .create_shared_memory_region ("dummy_data" , 8 , 0 )
194
+ self .triton_client .register_cuda_shared_memory (
195
+ "dummy_data" , cshm .get_raw_handle (shm_op0_handle ), 0 , 8
196
+ )
197
+ try :
198
+ self .triton_client .register_cuda_shared_memory (
199
+ "dummy_data" , cshm .get_raw_handle (shm_op0_handle ), 0 , 8
200
+ )
201
+ except Exception as ex :
202
+ self .assertIn (
203
+ "shared memory region 'dummy_data' already in manager" , str (ex )
204
+ )
205
+ shm_status = self .triton_client .get_cuda_shared_memory_status ()
206
+ if self .protocol == "http" :
207
+ self .assertEqual (len (shm_status ), 1 )
208
+ else :
209
+ self .assertEqual (len (shm_status .regions ), 1 )
210
+ cshm .destroy_shared_memory_region (shm_op0_handle )
211
+
208
212
def test_unregister_after_inference (self ):
209
213
# Unregister after inference
210
214
error_msg = []
@@ -396,5 +400,169 @@ def test_infer_byte_size_out_of_bound(self):
396
400
self ._cleanup_server (shm_handles )
397
401
398
402
403
+ class TestCudaSharedMemoryUnregister (CudaSharedMemoryTestBase ):
404
+ def _test_unregister_shm_fail (self ):
405
+ second_client = httpclient .InferenceServerClient ("localhost:8000" , verbose = True )
406
+
407
+ with self .assertRaises (InferenceServerException ) as ex :
408
+ second_client .unregister_cuda_shared_memory ()
409
+ self .assertIn (
410
+ "Failed to unregister the following cuda shared memory regions: input0_data ,input1_data ,output0_data ,output1_data" ,
411
+ str (ex .exception ),
412
+ )
413
+
414
+ with self .assertRaises (InferenceServerException ) as ex :
415
+ second_client .unregister_cuda_shared_memory ("input0_data" )
416
+ self .assertIn (
417
+ "Cannot unregister shared memory region 'input0_data', it is currently in use." ,
418
+ str (ex .exception ),
419
+ )
420
+
421
+ with self .assertRaises (InferenceServerException ) as ex :
422
+ second_client .unregister_cuda_shared_memory ("input1_data" )
423
+ self .assertIn (
424
+ "Cannot unregister shared memory region 'input1_data', it is currently in use." ,
425
+ str (ex .exception ),
426
+ )
427
+
428
+ with self .assertRaises (InferenceServerException ) as ex :
429
+ second_client .unregister_cuda_shared_memory ("output0_data" )
430
+ self .assertIn (
431
+ "Cannot unregister shared memory region 'output0_data', it is currently in use." ,
432
+ str (ex .exception ),
433
+ )
434
+
435
+ with self .assertRaises (InferenceServerException ) as ex :
436
+ second_client .unregister_cuda_shared_memory ("output1_data" )
437
+ self .assertIn (
438
+ "Cannot unregister shared memory region 'output1_data', it is currently in use." ,
439
+ str (ex .exception ),
440
+ )
441
+
442
+ def _test_shm_not_found (self ):
443
+ second_client = httpclient .InferenceServerClient ("localhost:8000" , verbose = True )
444
+
445
+ with self .assertRaises (InferenceServerException ) as ex :
446
+ second_client .get_cuda_shared_memory_status ("input0_data" )
447
+ self .assertIn (
448
+ "Unable to find cuda shared memory region: 'input0_data'" ,
449
+ str (ex .exception ),
450
+ )
451
+
452
+ with self .assertRaises (InferenceServerException ) as ex :
453
+ second_client .get_cuda_shared_memory_status ("input1_data" )
454
+ self .assertIn (
455
+ "Unable to find cuda shared memory region: 'input1_data'" ,
456
+ str (ex .exception ),
457
+ )
458
+
459
+ with self .assertRaises (InferenceServerException ) as ex :
460
+ second_client .get_cuda_shared_memory_status ("output0_data" )
461
+ self .assertIn (
462
+ "Unable to find cuda shared memory region: 'output0_data'" ,
463
+ str (ex .exception ),
464
+ )
465
+
466
+ with self .assertRaises (InferenceServerException ) as ex :
467
+ second_client .get_cuda_shared_memory_status ("output1_data" )
468
+ self .assertIn (
469
+ "Unable to find cuda shared memory region: 'output1_data'" ,
470
+ str (ex .exception ),
471
+ )
472
+
473
+ def test_unregister_shm_during_inference_http (self ):
474
+ try :
475
+ self .triton_client .unregister_cuda_shared_memory ()
476
+ shm_handles = self ._configure_server ()
477
+
478
+ inputs = [
479
+ httpclient .InferInput ("INPUT0" , [1 , 16 ], "INT32" ),
480
+ httpclient .InferInput ("INPUT1" , [1 , 16 ], "INT32" ),
481
+ ]
482
+ outputs = [
483
+ httpclient .InferRequestedOutput ("OUTPUT0" , binary_data = True ),
484
+ httpclient .InferRequestedOutput ("OUTPUT1" , binary_data = False ),
485
+ ]
486
+
487
+ inputs [0 ].set_shared_memory ("input0_data" , self .DEFAULT_SHM_BYTE_SIZE )
488
+ inputs [1 ].set_shared_memory ("input1_data" , self .DEFAULT_SHM_BYTE_SIZE )
489
+ outputs [0 ].set_shared_memory ("output0_data" , self .DEFAULT_SHM_BYTE_SIZE )
490
+ outputs [1 ].set_shared_memory ("output1_data" , self .DEFAULT_SHM_BYTE_SIZE )
491
+
492
+ async_request = self .triton_client .async_infer (
493
+ model_name = "simple" , inputs = inputs , outputs = outputs
494
+ )
495
+
496
+ # Ensure inference started
497
+ time .sleep (2 )
498
+
499
+ # Try unregister shm regions during inference
500
+ self ._test_unregister_shm_fail ()
501
+
502
+ # Blocking call
503
+ async_request .get_result ()
504
+
505
+ # Try unregister shm regions after inference
506
+ self .triton_client .unregister_cuda_shared_memory ()
507
+ self ._test_shm_not_found ()
508
+
509
+ finally :
510
+ self ._cleanup_server (shm_handles )
511
+
512
+ def test_unregister_shm_during_inference_grpc (self ):
513
+ try :
514
+ self .triton_client .unregister_cuda_shared_memory ()
515
+ shm_handles = self ._configure_server ()
516
+
517
+ inputs = [
518
+ grpcclient .InferInput ("INPUT0" , [1 , 16 ], "INT32" ),
519
+ grpcclient .InferInput ("INPUT1" , [1 , 16 ], "INT32" ),
520
+ ]
521
+ outputs = [
522
+ grpcclient .InferRequestedOutput ("OUTPUT0" ),
523
+ grpcclient .InferRequestedOutput ("OUTPUT1" ),
524
+ ]
525
+
526
+ inputs [0 ].set_shared_memory ("input0_data" , self .DEFAULT_SHM_BYTE_SIZE )
527
+ inputs [1 ].set_shared_memory ("input1_data" , self .DEFAULT_SHM_BYTE_SIZE )
528
+ outputs [0 ].set_shared_memory ("output0_data" , self .DEFAULT_SHM_BYTE_SIZE )
529
+ outputs [1 ].set_shared_memory ("output1_data" , self .DEFAULT_SHM_BYTE_SIZE )
530
+
531
+ def callback (user_data , result , error ):
532
+ if error :
533
+ user_data .append (error )
534
+ else :
535
+ user_data .append (result )
536
+
537
+ user_data = []
538
+
539
+ self .triton_client .async_infer (
540
+ model_name = "simple" ,
541
+ inputs = inputs ,
542
+ outputs = outputs ,
543
+ callback = partial (callback , user_data ),
544
+ )
545
+
546
+ # Ensure inference started
547
+ time .sleep (2 )
548
+
549
+ # Try unregister shm regions during inference
550
+ self ._test_unregister_shm_fail ()
551
+
552
+ # Wait until the results are available in user_data
553
+ time_out = 20
554
+ while (len (user_data ) == 0 ) and time_out > 0 :
555
+ time_out = time_out - 1
556
+ time .sleep (1 )
557
+ time .sleep (2 )
558
+
559
+ # Try unregister shm regions after inference
560
+ self .triton_client .unregister_cuda_shared_memory ()
561
+ self ._test_shm_not_found ()
562
+
563
+ finally :
564
+ self ._cleanup_server (shm_handles )
565
+
566
+
399
567
if __name__ == "__main__" :
400
568
unittest .main ()
0 commit comments