@@ -1,6 +1,6 @@
 #!/usr/bin/env python3

-# Copyright 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # Redistribution and use in source and binary forms, with or without
 # modification, are permitted provided that the following conditions
@@ -46,6 +46,34 @@
 _tritonserver_ipaddr = os.environ.get("TRITONSERVER_IPADDR", "localhost")


+def prepare_decoupled_bls_cancel_inputs(input_value, max_sum_value, ignore_cancel):
+    input_data = np.array([input_value], dtype=np.int32)
+    max_sum_data = np.array([max_sum_value], dtype=np.int32)
+    ignore_cancel_data = np.array([ignore_cancel], dtype=np.bool_)
+    inputs = [
+        grpcclient.InferInput(
+            "INPUT",
+            input_data.shape,
+            np_to_triton_dtype(input_data.dtype),
+        ),
+        grpcclient.InferInput(
+            "MAX_SUM",
+            max_sum_data.shape,
+            np_to_triton_dtype(max_sum_data.dtype),
+        ),
+        grpcclient.InferInput(
+            "IGNORE_CANCEL",
+            ignore_cancel_data.shape,
+            np_to_triton_dtype(ignore_cancel_data.dtype),
+        ),
+    ]
+    inputs[0].set_data_from_numpy(input_data)
+    inputs[1].set_data_from_numpy(max_sum_data)
+    inputs[2].set_data_from_numpy(ignore_cancel_data)
+
+    return inputs
+
+
 class UserData:
     def __init__(self):
         self._completed_requests = queue.Queue()
@@ -324,6 +352,171 @@ def test_decoupled_execute_cancel(self):
         self.assertIn("[execute_cancel] Request not cancelled at 1.0 s", log_text)
         self.assertIn("[execute_cancel] Request cancelled at ", log_text)

+    def test_decoupled_bls_cancel(self):
+        model_names = ["decoupled_bls_cancel", "decoupled_bls_async_cancel"]
+        input_value = 1
+        max_sum_value = 10
+        ignore_cancel = False
+        user_data = UserData()
+        for model_name in model_names:
+            with self._shm_leak_detector.Probe() as shm_probe:
+                with grpcclient.InferenceServerClient(
+                    f"{_tritonserver_ipaddr}:8001"
+                ) as client:
+                    client.start_stream(callback=partial(callback, user_data))
+                    inputs = prepare_decoupled_bls_cancel_inputs(
+                        input_value=input_value,
+                        max_sum_value=max_sum_value,
+                        ignore_cancel=ignore_cancel,
+                    )
+                    client.async_stream_infer(model_name, inputs)
+
+                    # Check the results of the decoupled model using BLS
+                    def check_result(result):
+                        # Make sure the result is not an exception
+                        self.assertIsNot(type(result), InferenceServerException)
+                        is_cancelled = result.as_numpy("IS_CANCELLED")
+                        self.assertTrue(
+                            is_cancelled[0],
+                            "error: expected the request to be cancelled",
+                        )
+
+                        max_sum_data = np.array([max_sum_value], dtype=np.int32)
+                        sum_data = result.as_numpy("SUM")
+                        self.assertIsNotNone(sum_data, "error: expected 'SUM'")
+                        self.assertTrue(
+                            np.array_equal(sum_data, max_sum_data),
+                            "error: expected output {} to match input {}".format(
+                                sum_data, max_sum_data
+                            ),
+                        )
+
+                    result = user_data._completed_requests.get()
+                    check_result(result)
+
+    def test_decoupled_bls_ignore_cancel(self):
+        model_names = ["decoupled_bls_cancel", "decoupled_bls_async_cancel"]
+        input_value = 1
+        max_sum_value = 10
+        ignore_cancel = True
+        user_data = UserData()
+        for model_name in model_names:
+            with self._shm_leak_detector.Probe() as shm_probe:
+                with grpcclient.InferenceServerClient(
+                    f"{_tritonserver_ipaddr}:8001"
+                ) as client:
+                    client.start_stream(callback=partial(callback, user_data))
+                    inputs = prepare_decoupled_bls_cancel_inputs(
+                        input_value=input_value,
+                        max_sum_value=max_sum_value,
+                        ignore_cancel=ignore_cancel,
+                    )
+                    client.async_stream_infer(model_name, inputs)
+
+                    # Check the results of the decoupled model using BLS
+                    def check_result(result):
+                        # Make sure the result is not an exception
+                        self.assertIsNot(type(result), InferenceServerException)
+                        is_cancelled = result.as_numpy("IS_CANCELLED")
+                        self.assertFalse(
+                            is_cancelled[0],
+                            "error: expected the request not to be cancelled",
+                        )
+
+                        max_sum_data = np.array([max_sum_value], dtype=np.int32)
+                        sum_data = result.as_numpy("SUM")
+                        self.assertIsNotNone(sum_data, "error: expected 'SUM'")
+                        self.assertTrue(
+                            sum_data > max_sum_data,
+                            "error: expected sum_data {} to be greater than max_sum_data {}".format(
+                                sum_data, max_sum_data
+                            ),
+                        )
+
+                    result = user_data._completed_requests.get()
+                    check_result(result)
+
+    def test_decoupled_bls_cancel_after_cancellation(self):
+        model_name = "decoupled_bls_cancel_after_complete"
+        input_value = 1
+        max_sum_value = 10
+        ignore_cancel = False
+        user_data = UserData()
+        with self._shm_leak_detector.Probe() as shm_probe:
+            with grpcclient.InferenceServerClient(
+                f"{_tritonserver_ipaddr}:8001"
+            ) as client:
+                client.start_stream(callback=partial(callback, user_data))
+                inputs = prepare_decoupled_bls_cancel_inputs(
+                    input_value=input_value,
+                    max_sum_value=max_sum_value,
+                    ignore_cancel=ignore_cancel,
+                )
+                client.async_stream_infer(model_name, inputs)
+
+                # Check the results of the decoupled model using BLS
+                def check_result(result):
+                    # Make sure the result is not an exception
+                    self.assertIsNot(type(result), InferenceServerException)
+                    is_cancelled = result.as_numpy("IS_CANCELLED")
+                    self.assertTrue(
+                        is_cancelled[0], "error: expected the request to be cancelled"
+                    )
+
+                    max_sum_data = np.array([max_sum_value], dtype=np.int32)
+                    sum_data = result.as_numpy("SUM")
+                    self.assertIsNotNone(sum_data, "error: expected 'SUM'")
+                    self.assertTrue(
+                        np.array_equal(sum_data, max_sum_data),
+                        "error: expected output {} to match input {}".format(
+                            sum_data, max_sum_data
+                        ),
+                    )
+
+                result = user_data._completed_requests.get()
+                check_result(result)
+
+    def test_decoupled_bls_cancel_after_completion(self):
+        model_name = "decoupled_bls_cancel_after_complete"
+        input_value = 1
+        max_sum_value = 25
+        ignore_cancel = False
+        user_data = UserData()
+        with self._shm_leak_detector.Probe() as shm_probe:
+            with grpcclient.InferenceServerClient(
+                f"{_tritonserver_ipaddr}:8001"
+            ) as client:
+                client.start_stream(callback=partial(callback, user_data))
+                inputs = prepare_decoupled_bls_cancel_inputs(
+                    input_value=input_value,
+                    max_sum_value=max_sum_value,
+                    ignore_cancel=ignore_cancel,
+                )
+                client.async_stream_infer(model_name, inputs)
+
+                # Check the results of the decoupled model using BLS
+                def check_result(result):
+                    # Make sure the result is not an exception
+                    self.assertIsNot(type(result), InferenceServerException)
+                    is_cancelled = result.as_numpy("IS_CANCELLED")
+                    self.assertFalse(
+                        is_cancelled[0],
+                        "error: expected the request not to be cancelled",
+                    )
+
+                    max_sum_data = np.array([max_sum_value], dtype=np.int32)
+                    sum_data = result.as_numpy("SUM")
+                    self.assertIsNotNone(sum_data, "error: expected 'SUM'")
+                    self.assertTrue(
+                        sum_data < max_sum_data,
+                        "error: expected sum_data {} to be less than max_sum_data {}".format(
+                            sum_data, max_sum_data
+                        ),
+                    )
+
+                result = user_data._completed_requests.get()
+                check_result(result)
+
     def test_decoupled_raise_exception(self):
         # The decoupled_raise_exception model raises an exception for the request.
         # This test case is making sure that repeated exceptions are properly handled.