|
1 | 1 | #!/usr/bin/env python3
|
2 | 2 |
|
3 |
| -# Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 3 | +# Copyright 2020-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
4 | 4 | #
|
5 | 5 | # Redistribution and use in source and binary forms, with or without
|
6 | 6 | # modification, are permitted provided that the following conditions
|
|
38 | 38 | import numpy as np
|
39 | 39 | import sequence_util as su
|
40 | 40 | import test_util as tu
|
| 41 | +import tritonclient.http as httpclient |
| 42 | +from tritonclient.utils import InferenceServerException, np_to_triton_dtype |
41 | 43 |
|
42 | 44 | _test_system_shared_memory = bool(int(os.environ.get("TEST_SYSTEM_SHARED_MEMORY", 0)))
|
43 | 45 | _test_cuda_shared_memory = bool(int(os.environ.get("TEST_CUDA_SHARED_MEMORY", 0)))
|
@@ -77,6 +79,12 @@ def get_expected_result(self, expected_result, corrid, value, trial, flag_str=No
|
77 | 79 | expected_result += corrid
|
78 | 80 | return expected_result
|
79 | 81 |
|
def data_type_to_string(self, dtype):
    """Convert a model-config data type name to its short string form.

    "TYPE_STRING" is reported as "BYTES". For every other name, each
    occurrence of "TYPE_" is removed (e.g. "TYPE_UINT64" -> "UINT64").
    """
    return "BYTES" if dtype == "TYPE_STRING" else dtype.replace("TYPE_", "")
| 87 | + |
80 | 88 | def test_skip_batch(self):
|
81 | 89 | # Test model instances together are configured with
|
82 | 90 | # total-batch-size 4. Send four sequences in parallel where
|
@@ -221,6 +229,78 @@ def test_skip_batch(self):
|
221 | 229 | self.cleanup_shm_regions(precreated_shm2_handles)
|
222 | 230 | self.cleanup_shm_regions(precreated_shm3_handles)
|
223 | 231 |
|
def test_corrid_data_type(self):
    """Check that CORRID data-type mismatches are rejected by the server.

    One request with a string correlation ID and one with an integer
    correlation ID are sent to the "add_sub" model. The CORRID data type
    the model config expects is read from the
    TRITONSERVER_CORRID_DATA_TYPE environment variable. A request whose
    correlation ID type is incompatible with it must fail with a
    descriptive InferenceServerException; a compatible request must
    return the element-wise sum and difference of the two inputs.

    Each correlation ID is run as its own subTest so that a failure for
    one ID type does not prevent the other from being exercised and
    reported.

    Raises:
        KeyError: if TRITONSERVER_CORRID_DATA_TYPE is not set.
    """
    model_name = "add_sub"
    expected_corrid_dtype = os.environ["TRITONSERVER_CORRID_DATA_TYPE"]

    # For each data type a correlation ID is sent as, the model-config
    # CORRID data types that are treated as compatible with it.
    compatible_dtypes = {
        "TYPE_STRING": ("TYPE_STRING",),
        "TYPE_UINT64": ("TYPE_UINT32", "TYPE_INT32", "TYPE_UINT64", "TYPE_INT64"),
    }

    for corrid, corrid_dtype in [("corrid", "TYPE_STRING"), (123, "TYPE_UINT64")]:
        with self.subTest(corrid=corrid, corrid_dtype=corrid_dtype):
            # Check if the corrid data type matches the expected corrid
            # data type specified in the model config
            dtypes_match = expected_corrid_dtype in compatible_dtypes[corrid_dtype]

            with httpclient.InferenceServerClient("localhost:8000") as client:
                input0_data = np.random.rand(16).astype(np.float32)
                input1_data = np.random.rand(16).astype(np.float32)

                inputs = []
                for name, data in (("INPUT0", input0_data), ("INPUT1", input1_data)):
                    tensor = httpclient.InferInput(
                        name,
                        data.shape,
                        np_to_triton_dtype(data.dtype),
                    )
                    tensor.set_data_from_numpy(data)
                    inputs.append(tensor)

                # Both branches send the same sequence-start request.
                sequence_args = {
                    "sequence_id": corrid,
                    "sequence_start": True,
                    "sequence_end": False,
                }

                if not dtypes_match:
                    with self.assertRaises(InferenceServerException) as e:
                        client.infer(model_name, inputs, **sequence_args)
                    err_str = str(e.exception)
                    self.assertIn(
                        f"sequence batching control 'CORRID' data-type is '{self.data_type_to_string(corrid_dtype)}', but model '{model_name}' expects '{self.data_type_to_string(expected_corrid_dtype)}'",
                        err_str,
                    )
                else:
                    response = client.infer(model_name, inputs, **sequence_args)
                    # Return value is not used here.
                    # NOTE(review): looks like a plain accessor - confirm
                    # whether this call is still needed.
                    response.get_response()
                    output0_data = response.as_numpy("OUTPUT0")
                    output1_data = response.as_numpy("OUTPUT1")

                    self.assertTrue(
                        np.allclose(input0_data + input1_data, output0_data),
                        "add_sub example error: incorrect sum",
                    )

                    self.assertTrue(
                        np.allclose(input0_data - input1_data, output1_data),
                        "add_sub example error: incorrect difference",
                    )
| 303 | + |
224 | 304 |
|
225 | 305 | if __name__ == "__main__":
|
226 | 306 | unittest.main()
|
0 commit comments