Skip to content

Commit 29c7a28

Browse files
pskiran1GuanLuo
andauthored
Cherry-pick: fix: Add reference count tracking for shared memory regions (#7567) (#7612)
Co-authored-by: GuanLuo <41310872+GuanLuo@users.noreply.github.com>
1 parent 9ae25db commit 29c7a28

File tree

14 files changed

+886
-222
lines changed

14 files changed

+886
-222
lines changed

qa/L0_cuda_shared_memory/cuda_shared_memory_test.py

Lines changed: 240 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -31,18 +31,20 @@
3131
sys.path.append("../common")
3232

3333
import os
34+
import time
3435
import unittest
36+
from functools import partial
3537

3638
import infer_util as iu
3739
import numpy as np
3840
import test_util as tu
3941
import tritonclient.grpc as grpcclient
4042
import tritonclient.http as httpclient
41-
import tritonshmutils.cuda_shared_memory as cshm
43+
import tritonclient.utils.cuda_shared_memory as cshm
4244
from tritonclient.utils import *
4345

4446

45-
class CudaSharedMemoryTest(tu.TestResultCollector):
47+
class CudaSharedMemoryTestBase(tu.TestResultCollector):
4648
DEFAULT_SHM_BYTE_SIZE = 64
4749

4850
def setUp(self):
@@ -61,76 +63,6 @@ def _setup_client(self):
6163
self.url, verbose=True
6264
)
6365

64-
def test_invalid_create_shm(self):
65-
# Raises error since tried to create invalid cuda shared memory region
66-
try:
67-
shm_op0_handle = cshm.create_shared_memory_region("dummy_data", -1, 0)
68-
cshm.destroy_shared_memory_region(shm_op0_handle)
69-
except Exception as ex:
70-
self.assertEqual(str(ex), "unable to create cuda shared memory handle")
71-
72-
def test_valid_create_set_register(self):
73-
# Create a valid cuda shared memory region, fill data in it and register
74-
shm_op0_handle = cshm.create_shared_memory_region("dummy_data", 8, 0)
75-
cshm.set_shared_memory_region(
76-
shm_op0_handle, [np.array([1, 2], dtype=np.float32)]
77-
)
78-
self.triton_client.register_cuda_shared_memory(
79-
"dummy_data", cshm.get_raw_handle(shm_op0_handle), 0, 8
80-
)
81-
shm_status = self.triton_client.get_cuda_shared_memory_status()
82-
if self.protocol == "http":
83-
self.assertEqual(len(shm_status), 1)
84-
else:
85-
self.assertEqual(len(shm_status.regions), 1)
86-
cshm.destroy_shared_memory_region(shm_op0_handle)
87-
88-
def test_unregister_before_register(self):
89-
# Create a valid cuda shared memory region and unregister before register
90-
shm_op0_handle = cshm.create_shared_memory_region("dummy_data", 8, 0)
91-
self.triton_client.unregister_cuda_shared_memory("dummy_data")
92-
shm_status = self.triton_client.get_cuda_shared_memory_status()
93-
if self.protocol == "http":
94-
self.assertEqual(len(shm_status), 0)
95-
else:
96-
self.assertEqual(len(shm_status.regions), 0)
97-
cshm.destroy_shared_memory_region(shm_op0_handle)
98-
99-
def test_unregister_after_register(self):
100-
# Create a valid cuda shared memory region and unregister after register
101-
shm_op0_handle = cshm.create_shared_memory_region("dummy_data", 8, 0)
102-
self.triton_client.register_cuda_shared_memory(
103-
"dummy_data", cshm.get_raw_handle(shm_op0_handle), 0, 8
104-
)
105-
self.triton_client.unregister_cuda_shared_memory("dummy_data")
106-
shm_status = self.triton_client.get_cuda_shared_memory_status()
107-
if self.protocol == "http":
108-
self.assertEqual(len(shm_status), 0)
109-
else:
110-
self.assertEqual(len(shm_status.regions), 0)
111-
cshm.destroy_shared_memory_region(shm_op0_handle)
112-
113-
def test_reregister_after_register(self):
114-
# Create a valid cuda shared memory region and unregister after register
115-
shm_op0_handle = cshm.create_shared_memory_region("dummy_data", 8, 0)
116-
self.triton_client.register_cuda_shared_memory(
117-
"dummy_data", cshm.get_raw_handle(shm_op0_handle), 0, 8
118-
)
119-
try:
120-
self.triton_client.register_cuda_shared_memory(
121-
"dummy_data", cshm.get_raw_handle(shm_op0_handle), 0, 8
122-
)
123-
except Exception as ex:
124-
self.assertIn(
125-
"shared memory region 'dummy_data' already in manager", str(ex)
126-
)
127-
shm_status = self.triton_client.get_cuda_shared_memory_status()
128-
if self.protocol == "http":
129-
self.assertEqual(len(shm_status), 1)
130-
else:
131-
self.assertEqual(len(shm_status.regions), 1)
132-
cshm.destroy_shared_memory_region(shm_op0_handle)
133-
13466
def _configure_server(
13567
self,
13668
create_byte_size=DEFAULT_SHM_BYTE_SIZE,
@@ -205,6 +137,78 @@ def _cleanup_server(self, shm_handles):
205137
for shm_handle in shm_handles:
206138
cshm.destroy_shared_memory_region(shm_handle)
207139

140+
141+
class CudaSharedMemoryTest(CudaSharedMemoryTestBase):
142+
def test_invalid_create_shm(self):
143+
# Raises error since tried to create invalid cuda shared memory region
144+
try:
145+
shm_op0_handle = cshm.create_shared_memory_region("dummy_data", -1, 0)
146+
cshm.destroy_shared_memory_region(shm_op0_handle)
147+
except Exception as ex:
148+
self.assertEqual(str(ex), "unable to create cuda shared memory handle")
149+
150+
def test_valid_create_set_register(self):
151+
# Create a valid cuda shared memory region, fill data in it and register
152+
shm_op0_handle = cshm.create_shared_memory_region("dummy_data", 8, 0)
153+
cshm.set_shared_memory_region(
154+
shm_op0_handle, [np.array([1, 2], dtype=np.float32)]
155+
)
156+
self.triton_client.register_cuda_shared_memory(
157+
"dummy_data", cshm.get_raw_handle(shm_op0_handle), 0, 8
158+
)
159+
shm_status = self.triton_client.get_cuda_shared_memory_status()
160+
if self.protocol == "http":
161+
self.assertEqual(len(shm_status), 1)
162+
else:
163+
self.assertEqual(len(shm_status.regions), 1)
164+
cshm.destroy_shared_memory_region(shm_op0_handle)
165+
166+
def test_unregister_before_register(self):
167+
# Create a valid cuda shared memory region and unregister before register
168+
shm_op0_handle = cshm.create_shared_memory_region("dummy_data", 8, 0)
169+
self.triton_client.unregister_cuda_shared_memory("dummy_data")
170+
shm_status = self.triton_client.get_cuda_shared_memory_status()
171+
if self.protocol == "http":
172+
self.assertEqual(len(shm_status), 0)
173+
else:
174+
self.assertEqual(len(shm_status.regions), 0)
175+
cshm.destroy_shared_memory_region(shm_op0_handle)
176+
177+
def test_unregister_after_register(self):
178+
# Create a valid cuda shared memory region and unregister after register
179+
shm_op0_handle = cshm.create_shared_memory_region("dummy_data", 8, 0)
180+
self.triton_client.register_cuda_shared_memory(
181+
"dummy_data", cshm.get_raw_handle(shm_op0_handle), 0, 8
182+
)
183+
self.triton_client.unregister_cuda_shared_memory("dummy_data")
184+
shm_status = self.triton_client.get_cuda_shared_memory_status()
185+
if self.protocol == "http":
186+
self.assertEqual(len(shm_status), 0)
187+
else:
188+
self.assertEqual(len(shm_status.regions), 0)
189+
cshm.destroy_shared_memory_region(shm_op0_handle)
190+
191+
def test_reregister_after_register(self):
192+
# Create a valid cuda shared memory region and unregister after register
193+
shm_op0_handle = cshm.create_shared_memory_region("dummy_data", 8, 0)
194+
self.triton_client.register_cuda_shared_memory(
195+
"dummy_data", cshm.get_raw_handle(shm_op0_handle), 0, 8
196+
)
197+
try:
198+
self.triton_client.register_cuda_shared_memory(
199+
"dummy_data", cshm.get_raw_handle(shm_op0_handle), 0, 8
200+
)
201+
except Exception as ex:
202+
self.assertIn(
203+
"shared memory region 'dummy_data' already in manager", str(ex)
204+
)
205+
shm_status = self.triton_client.get_cuda_shared_memory_status()
206+
if self.protocol == "http":
207+
self.assertEqual(len(shm_status), 1)
208+
else:
209+
self.assertEqual(len(shm_status.regions), 1)
210+
cshm.destroy_shared_memory_region(shm_op0_handle)
211+
208212
def test_unregister_after_inference(self):
209213
# Unregister after inference
210214
error_msg = []
@@ -396,5 +400,169 @@ def test_infer_byte_size_out_of_bound(self):
396400
self._cleanup_server(shm_handles)
397401

398402

403+
class TestCudaSharedMemoryUnregister(CudaSharedMemoryTestBase):
404+
def _test_unregister_shm_fail(self):
405+
second_client = httpclient.InferenceServerClient("localhost:8000", verbose=True)
406+
407+
with self.assertRaises(InferenceServerException) as ex:
408+
second_client.unregister_cuda_shared_memory()
409+
self.assertIn(
410+
"Failed to unregister the following cuda shared memory regions: input0_data ,input1_data ,output0_data ,output1_data",
411+
str(ex.exception),
412+
)
413+
414+
with self.assertRaises(InferenceServerException) as ex:
415+
second_client.unregister_cuda_shared_memory("input0_data")
416+
self.assertIn(
417+
"Cannot unregister shared memory region 'input0_data', it is currently in use.",
418+
str(ex.exception),
419+
)
420+
421+
with self.assertRaises(InferenceServerException) as ex:
422+
second_client.unregister_cuda_shared_memory("input1_data")
423+
self.assertIn(
424+
"Cannot unregister shared memory region 'input1_data', it is currently in use.",
425+
str(ex.exception),
426+
)
427+
428+
with self.assertRaises(InferenceServerException) as ex:
429+
second_client.unregister_cuda_shared_memory("output0_data")
430+
self.assertIn(
431+
"Cannot unregister shared memory region 'output0_data', it is currently in use.",
432+
str(ex.exception),
433+
)
434+
435+
with self.assertRaises(InferenceServerException) as ex:
436+
second_client.unregister_cuda_shared_memory("output1_data")
437+
self.assertIn(
438+
"Cannot unregister shared memory region 'output1_data', it is currently in use.",
439+
str(ex.exception),
440+
)
441+
442+
def _test_shm_not_found(self):
443+
second_client = httpclient.InferenceServerClient("localhost:8000", verbose=True)
444+
445+
with self.assertRaises(InferenceServerException) as ex:
446+
second_client.get_cuda_shared_memory_status("input0_data")
447+
self.assertIn(
448+
"Unable to find cuda shared memory region: 'input0_data'",
449+
str(ex.exception),
450+
)
451+
452+
with self.assertRaises(InferenceServerException) as ex:
453+
second_client.get_cuda_shared_memory_status("input1_data")
454+
self.assertIn(
455+
"Unable to find cuda shared memory region: 'input1_data'",
456+
str(ex.exception),
457+
)
458+
459+
with self.assertRaises(InferenceServerException) as ex:
460+
second_client.get_cuda_shared_memory_status("output0_data")
461+
self.assertIn(
462+
"Unable to find cuda shared memory region: 'output0_data'",
463+
str(ex.exception),
464+
)
465+
466+
with self.assertRaises(InferenceServerException) as ex:
467+
second_client.get_cuda_shared_memory_status("output1_data")
468+
self.assertIn(
469+
"Unable to find cuda shared memory region: 'output1_data'",
470+
str(ex.exception),
471+
)
472+
473+
def test_unregister_shm_during_inference_http(self):
474+
try:
475+
self.triton_client.unregister_cuda_shared_memory()
476+
shm_handles = self._configure_server()
477+
478+
inputs = [
479+
httpclient.InferInput("INPUT0", [1, 16], "INT32"),
480+
httpclient.InferInput("INPUT1", [1, 16], "INT32"),
481+
]
482+
outputs = [
483+
httpclient.InferRequestedOutput("OUTPUT0", binary_data=True),
484+
httpclient.InferRequestedOutput("OUTPUT1", binary_data=False),
485+
]
486+
487+
inputs[0].set_shared_memory("input0_data", self.DEFAULT_SHM_BYTE_SIZE)
488+
inputs[1].set_shared_memory("input1_data", self.DEFAULT_SHM_BYTE_SIZE)
489+
outputs[0].set_shared_memory("output0_data", self.DEFAULT_SHM_BYTE_SIZE)
490+
outputs[1].set_shared_memory("output1_data", self.DEFAULT_SHM_BYTE_SIZE)
491+
492+
async_request = self.triton_client.async_infer(
493+
model_name="simple", inputs=inputs, outputs=outputs
494+
)
495+
496+
# Ensure inference started
497+
time.sleep(2)
498+
499+
# Try unregister shm regions during inference
500+
self._test_unregister_shm_fail()
501+
502+
# Blocking call
503+
async_request.get_result()
504+
505+
# Try unregister shm regions after inference
506+
self.triton_client.unregister_cuda_shared_memory()
507+
self._test_shm_not_found()
508+
509+
finally:
510+
self._cleanup_server(shm_handles)
511+
512+
def test_unregister_shm_during_inference_grpc(self):
513+
try:
514+
self.triton_client.unregister_cuda_shared_memory()
515+
shm_handles = self._configure_server()
516+
517+
inputs = [
518+
grpcclient.InferInput("INPUT0", [1, 16], "INT32"),
519+
grpcclient.InferInput("INPUT1", [1, 16], "INT32"),
520+
]
521+
outputs = [
522+
grpcclient.InferRequestedOutput("OUTPUT0"),
523+
grpcclient.InferRequestedOutput("OUTPUT1"),
524+
]
525+
526+
inputs[0].set_shared_memory("input0_data", self.DEFAULT_SHM_BYTE_SIZE)
527+
inputs[1].set_shared_memory("input1_data", self.DEFAULT_SHM_BYTE_SIZE)
528+
outputs[0].set_shared_memory("output0_data", self.DEFAULT_SHM_BYTE_SIZE)
529+
outputs[1].set_shared_memory("output1_data", self.DEFAULT_SHM_BYTE_SIZE)
530+
531+
def callback(user_data, result, error):
532+
if error:
533+
user_data.append(error)
534+
else:
535+
user_data.append(result)
536+
537+
user_data = []
538+
539+
self.triton_client.async_infer(
540+
model_name="simple",
541+
inputs=inputs,
542+
outputs=outputs,
543+
callback=partial(callback, user_data),
544+
)
545+
546+
# Ensure inference started
547+
time.sleep(2)
548+
549+
# Try unregister shm regions during inference
550+
self._test_unregister_shm_fail()
551+
552+
# Wait until the results are available in user_data
553+
time_out = 20
554+
while (len(user_data) == 0) and time_out > 0:
555+
time_out = time_out - 1
556+
time.sleep(1)
557+
time.sleep(2)
558+
559+
# Try unregister shm regions after inference
560+
self.triton_client.unregister_cuda_shared_memory()
561+
self._test_shm_not_found()
562+
563+
finally:
564+
self._cleanup_server(shm_handles)
565+
566+
399567
if __name__ == "__main__":
400568
unittest.main()

qa/L0_cuda_shared_memory/test.sh

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,47 @@ for i in \
8484
done
8585
done
8686

87+
mkdir -p python_models/simple/1/
88+
cp ../python_models/execute_delayed_model/model.py ./python_models/simple/1/
89+
cp ../python_models/execute_delayed_model/config.pbtxt ./python_models/simple/
90+
sed -i 's/KIND_CPU/KIND_GPU/g' ./python_models/simple/config.pbtxt
91+
92+
for client_type in http grpc; do
93+
SERVER_ARGS="--model-repository=`pwd`/python_models --log-verbose=1 ${SERVER_ARGS_EXTRA}"
94+
SERVER_LOG="./unregister_shm.$client_type.server.log"
95+
run_server
96+
if [ "$SERVER_PID" == "0" ]; then
97+
echo -e "\n***\n*** Failed to start $SERVER\n***"
98+
cat $SERVER_LOG
99+
exit 1
100+
fi
101+
102+
export CLIENT_TYPE=$client_type
103+
CLIENT_LOG="./unregister_shm.$client_type.client.log"
104+
set +e
105+
python3 $SHM_TEST TestCudaSharedMemoryUnregister.test_unregister_shm_during_inference_$client_type >>$CLIENT_LOG 2>&1
106+
if [ $? -ne 0 ]; then
107+
cat $CLIENT_LOG
108+
echo -e "\n***\n*** Test Failed\n***"
109+
RET=1
110+
else
111+
check_test_results $TEST_RESULT_FILE 1
112+
if [ $? -ne 0 ]; then
113+
cat $TEST_RESULT_FILE
114+
echo -e "\n***\n*** Test Result Verification Failed\n***"
115+
RET=1
116+
fi
117+
fi
118+
119+
kill $SERVER_PID
120+
wait $SERVER_PID
121+
if [ $? -ne 0 ]; then
122+
echo -e "\n***\n*** Test Server shut down non-gracefully\n***"
123+
RET=1
124+
fi
125+
set -e
126+
done
127+
87128
if [ $RET -eq 0 ]; then
88129
echo -e "\n***\n*** Test Passed\n***"
89130
else

0 commit comments

Comments
 (0)