Skip to content

Commit 48e615d

Browse files
authored
test: Validate request correlation ID data type (#7919)
1 parent 8ee021e commit 48e615d

File tree

2 files changed

+135
-2
lines changed

2 files changed

+135
-2
lines changed

qa/L0_sequence_corrid_batcher/sequence_corrid_batcher_test.py

Lines changed: 81 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
#!/usr/bin/env python3
22

3-
# Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
# Copyright 2020-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
44
#
55
# Redistribution and use in source and binary forms, with or without
66
# modification, are permitted provided that the following conditions
@@ -38,6 +38,8 @@
3838
import numpy as np
3939
import sequence_util as su
4040
import test_util as tu
41+
import tritonclient.http as httpclient
42+
from tritonclient.utils import InferenceServerException, np_to_triton_dtype
4143

4244
# Shared-memory test toggles read from the environment: each variable is parsed
# as an integer (defaulting to 0 when unset) and then converted to a bool.
_test_system_shared_memory = bool(int(os.environ.get("TEST_SYSTEM_SHARED_MEMORY", 0)))
4345
_test_cuda_shared_memory = bool(int(os.environ.get("TEST_CUDA_SHARED_MEMORY", 0)))
@@ -77,6 +79,12 @@ def get_expected_result(self, expected_result, corrid, value, trial, flag_str=No
7779
expected_result += corrid
7880
return expected_result
7981

82+
def data_type_to_string(self, dtype):
    """Convert a config-style data type name to its short form.

    The short form is what the expected server error message in
    test_corrid_data_type is built from.

    Args:
        dtype: Data type name as written in the model config,
            e.g. "TYPE_STRING" or "TYPE_UINT64".

    Returns:
        "BYTES" for "TYPE_STRING"; otherwise ``dtype`` with every
        "TYPE_" occurrence removed (e.g. "TYPE_UINT64" -> "UINT64").
    """
    return "BYTES" if dtype == "TYPE_STRING" else dtype.replace("TYPE_", "")
87+
8088
def test_skip_batch(self):
8189
# Test model instances together are configured with
8290
# total-batch-size 4. Send four sequences in parallel where
@@ -221,6 +229,78 @@ def test_skip_batch(self):
221229
self.cleanup_shm_regions(precreated_shm2_handles)
222230
self.cleanup_shm_regions(precreated_shm3_handles)
223231

232+
def test_corrid_data_type(self):
    """Check that the server enforces the configured CORRID data type.

    The data type the model config declares for the CORRID control is read
    from the TRITONSERVER_CORRID_DATA_TYPE environment variable (a KeyError
    is raised if it is unset). Two "add_sub" requests are then sent over
    HTTP, one with a string correlation ID and one with an integer
    correlation ID:

    * when the ID type is compatible with the configured type, the request
      must succeed and return the element-wise sum and difference;
    * otherwise the request must raise InferenceServerException with the
      CORRID data-type mismatch message.
    """
    model_name = "add_sub"
    expected_corrid_dtype = os.environ["TRITONSERVER_CORRID_DATA_TYPE"]

    # For each kind of correlation ID sent below, the configured data types
    # that are expected to accept it.
    accepted_config_dtypes = {
        "TYPE_STRING": ("TYPE_STRING",),
        "TYPE_UINT64": ("TYPE_UINT32", "TYPE_INT32", "TYPE_UINT64", "TYPE_INT64"),
    }

    for corrid, corrid_dtype in (("corrid", "TYPE_STRING"), (123, "TYPE_UINT64")):
        # Does the ID we are about to send fit the type declared in the
        # model config?
        dtypes_match = expected_corrid_dtype in accepted_config_dtypes[corrid_dtype]

        # Both branches issue the same sequence-start request.
        sequence_kwargs = {
            "sequence_id": corrid,
            "sequence_start": True,
            "sequence_end": False,
        }

        with httpclient.InferenceServerClient("localhost:8000") as client:
            input0_data = np.random.rand(16).astype(np.float32)
            input1_data = np.random.rand(16).astype(np.float32)

            named_tensors = (("INPUT0", input0_data), ("INPUT1", input1_data))
            inputs = [
                httpclient.InferInput(
                    name, tensor.shape, np_to_triton_dtype(tensor.dtype)
                )
                for name, tensor in named_tensors
            ]
            for infer_input, (_, tensor) in zip(inputs, named_tensors):
                infer_input.set_data_from_numpy(tensor)

            if dtypes_match:
                response = client.infer(model_name, inputs, **sequence_kwargs)
                response.get_response()
                output0_data = response.as_numpy("OUTPUT0")
                output1_data = response.as_numpy("OUTPUT1")

                sum_is_correct = np.allclose(input0_data + input1_data, output0_data)
                self.assertTrue(
                    sum_is_correct, "add_sub example error: incorrect sum"
                )

                difference_is_correct = np.allclose(
                    input0_data - input1_data, output1_data
                )
                self.assertTrue(
                    difference_is_correct,
                    "add_sub example error: incorrect difference",
                )
            else:
                with self.assertRaises(InferenceServerException) as e:
                    client.infer(model_name, inputs, **sequence_kwargs)

                # The exception is only available once the assertRaises
                # context has exited.
                expected_message = f"sequence batching control 'CORRID' data-type is '{self.data_type_to_string(corrid_dtype)}', but model '{model_name}' expects '{self.data_type_to_string(expected_corrid_dtype)}'"
                self.assertIn(expected_message, str(e.exception))
303+
224304

225305
# Run the unittest command-line runner when this file is executed as a script.
if __name__ == "__main__":
226306
unittest.main()

qa/L0_sequence_corrid_batcher/test.sh

Lines changed: 54 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
#!/bin/bash
2-
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
2+
# Copyright (c) 2020-2025, NVIDIA CORPORATION. All rights reserved.
33
#
44
# Redistribution and use in source and binary forms, with or without
55
# modification, are permitted provided that the following conditions
@@ -121,6 +121,59 @@ for model_trial in 4; do
121121
done
122122
done
123123

124+
# Test correlation ID data type
#
# For every supported CORRID control data type, build an add_sub model whose
# config declares that type, start the server on it and run
# test_corrid_data_type, which reads the configured type from
# TRITONSERVER_CORRID_DATA_TYPE.
MODEL_DIR=corrid_data_type
mkdir -p "$MODEL_DIR/add_sub/1"
cp ../python_models/add_sub/model.py "$MODEL_DIR/add_sub/1"

for corrid_data_type in TYPE_STRING TYPE_UINT32 TYPE_INT32 TYPE_UINT64 TYPE_INT64; do
    # Start from a fresh copy of the base config on every iteration, then
    # append a sequence_batching section declaring the CORRID data type.
    (cd "$MODEL_DIR/add_sub" && \
        cp ../../../python_models/add_sub/config.pbtxt . && \
        echo "sequence_batching { \
            control_input [{ \
                name: \"CORRID\" \
                control [{ \
                    kind: CONTROL_SEQUENCE_CORRID \
                    data_type: $corrid_data_type \
                }] \
            }] \
        }" >> config.pbtxt)

    for i in test_corrid_data_type ; do
        export TRITONSERVER_CORRID_DATA_TYPE=$corrid_data_type
        SERVER_ARGS="--model-repository=$(pwd)/$MODEL_DIR"
        # Include the data type in the log name so the server log of one
        # iteration is not overwritten by the next.
        SERVER_LOG="./$i.$MODEL_DIR.$corrid_data_type.server.log"
        run_server
        if [ "$SERVER_PID" == "0" ]; then
            echo -e "\n***\n*** Failed to start $SERVER\n***"
            cat "$SERVER_LOG"
            exit 1
        fi

        echo "Test: $i, repository $MODEL_DIR, data type $corrid_data_type" >>"$CLIENT_LOG"

        # A failing test must not abort the script: record it in RET and
        # continue so the server is still shut down below.
        set +e
        python $BATCHER_TEST SequenceCorrIDBatcherTest.$i >>"$CLIENT_LOG" 2>&1
        if [ $? -ne 0 ]; then
            echo -e "\n***\n*** Test $i Failed\n***" >>"$CLIENT_LOG"
            echo -e "\n***\n*** Test $i Failed\n***"
            RET=1
        else
            check_test_results $TEST_RESULT_FILE 1
            if [ $? -ne 0 ]; then
                cat "$CLIENT_LOG"
                echo -e "\n***\n*** Test Result Verification Failed\n***"
                RET=1
            fi
        fi
        set -e

        unset TRITONSERVER_CORRID_DATA_TYPE
        kill $SERVER_PID
        wait $SERVER_PID
    done
done
176+
124177
if [ $RET -eq 0 ]; then
125178
echo -e "\n***\n*** Test Passed\n***"
126179
else

0 commit comments

Comments
 (0)