Skip to content

Commit 42811e0

Browse files
authored
fix: Fix gRPC cancellation race condition (#8078)
1 parent 5fd6bc4 commit 42811e0

File tree

6 files changed

+158
-42
lines changed

6 files changed

+158
-42
lines changed

qa/L0_request_cancellation/grpc_cancellation_test.py

Lines changed: 52 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#!/usr/bin/env python3
22

3-
# Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
# Copyright 2023-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
44
#
55
# Redistribution and use in source and binary forms, with or without
66
# modification, are permitted provided that the following conditions
@@ -202,12 +202,34 @@ def test_grpc_async_infer_response_complete_during_cancellation(self):
202202
) # ensure the cancellation is processed
203203
self._assert_callback_cancelled()
204204

205-
def test_grpc_async_infer_cancellation_during_response_complete(self):
205+
def test_grpc_async_infer_cancellation_before_finish_0(self):
206+
# First version of test_grpc_async_infer_cancellation_before_finish
207+
# Cancellation notification is processed before the final response state.
206208
# long test
207-
self.test_duration_delta = 2.5
209+
self.test_duration_delta = 2
208210
delay_notification_sec = (
209211
int(os.getenv("TRITONSERVER_DELAY_GRPC_NOTIFICATION")) / 1000
210212
)
213+
future = self._client.async_infer(
214+
model_name=self._model_name,
215+
inputs=self._inputs,
216+
callback=self._callback,
217+
outputs=self._outputs,
218+
)
219+
# ensure the cancellation is received between InferResponseComplete checking cancellation and Finish
220+
time.sleep(self._model_delay + 2)
221+
future.cancel()
222+
time.sleep(delay_notification_sec + 1) # ensure the cancellation is processed
223+
self._assert_callback_cancelled()
224+
225+
def test_grpc_async_infer_cancellation_before_finish_1(self):
226+
# Second version of test_grpc_async_infer_cancellation_before_finish
227+
# Cancellation notification is processed after the final response state.
228+
# long test
229+
self.test_duration_delta = 2
230+
delay_process_entry_sec = (
231+
int(os.getenv("TRITONSERVER_DELAY_GRPC_PROCESS_ENTRY")) / 1000
232+
)
211233
delay_response_completion_sec = (
212234
int(os.getenv("TRITONSERVER_DELAY_RESPONSE_COMPLETION")) / 1000
213235
)
@@ -218,13 +240,38 @@ def test_grpc_async_infer_cancellation_during_response_complete(self):
218240
outputs=self._outputs,
219241
)
220242
# ensure the cancellation is received between InferResponseComplete checking cancellation and Finish
221-
time.sleep(self._model_delay + 2)
243+
time.sleep(self._model_delay + delay_process_entry_sec + 2)
222244
future.cancel()
223245
time.sleep(
224-
delay_notification_sec + delay_response_completion_sec
246+
delay_response_completion_sec
225247
) # ensure the cancellation is processed
226248
self._assert_callback_cancelled()
227249

250+
def test_grpc_async_infer_cancellation_before_response_complete_and_process_after_final_response(
251+
self,
252+
):
253+
# Received cancellation before InferResponseComplete and the notification
254+
# state is processed after processing final response state.
255+
# long test
256+
self.test_duration_delta = 2
257+
delay_notification_sec = (
258+
int(os.getenv("TRITONSERVER_DELAY_GRPC_NOTIFICATION")) / 1000
259+
)
260+
delay_response_complete_exec_sec = (
261+
int(os.getenv("TRITONSERVER_DELAY_RESPONSE_COMPLETE_EXEC")) / 1000
262+
)
263+
future = self._client.async_infer(
264+
model_name=self._model_name,
265+
inputs=self._inputs,
266+
callback=self._callback,
267+
outputs=self._outputs,
268+
)
269+
# ensure the cancellation is received before InferResponseComplete checking cancellation
270+
time.sleep(self._model_delay + 2)
271+
future.cancel()
272+
time.sleep(delay_notification_sec + 1) # ensure the cancellation is processed
273+
self._assert_callback_cancelled()
274+
228275

229276
if __name__ == "__main__":
230277
unittest.main()

qa/L0_request_cancellation/test.sh

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#!/bin/bash
2-
# Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# Copyright 2023-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
33
#
44
# Redistribution and use in source and binary forms, with or without
55
# modification, are permitted provided that the following conditions
@@ -84,20 +84,28 @@ for TEST_CASE in "test_grpc_async_infer" \
8484
"test_aio_grpc_stream_infer" \
8585
"test_grpc_async_infer_cancellation_at_step_start" \
8686
"test_grpc_async_infer_response_complete_during_cancellation" \
87-
"test_grpc_async_infer_cancellation_during_response_complete"; do
87+
"test_grpc_async_infer_cancellation_before_finish_0" \
88+
"test_grpc_async_infer_cancellation_before_finish_1" \
89+
"test_grpc_async_infer_cancellation_before_response_complete_and_process_after_final_response"; do
8890
TEST_LOG="./grpc_cancellation_test.$TEST_CASE.log"
8991
SERVER_LOG="grpc_cancellation_test.$TEST_CASE.server.log"
9092
if [ "$TEST_CASE" == "test_grpc_async_infer_cancellation_at_step_start" ]; then
9193
export TRITONSERVER_DELAY_GRPC_PROCESS=5000
9294
elif [ "$TEST_CASE" == "test_grpc_async_infer_response_complete_during_cancellation" ]; then
9395
export TRITONSERVER_DELAY_GRPC_NOTIFICATION=5000
9496
export TRITONSERVER_DELAY_GRPC_ENQUEUE=5000
95-
elif [ "$TEST_CASE" == "test_grpc_async_infer_cancellation_during_response_complete" ]; then
97+
elif [ "$TEST_CASE" == "test_grpc_async_infer_cancellation_before_finish_0" ]; then
9698
export TRITONSERVER_DELAY_GRPC_NOTIFICATION=5000
9799
export TRITONSERVER_DELAY_RESPONSE_COMPLETION=5000
100+
elif [ "$TEST_CASE" == "test_grpc_async_infer_cancellation_before_finish_1" ]; then
101+
export TRITONSERVER_DELAY_GRPC_PROCESS_ENTRY=1000
102+
export TRITONSERVER_DELAY_RESPONSE_COMPLETION=5000
103+
elif [ "$TEST_CASE" == "test_grpc_async_infer_cancellation_before_response_complete_and_process_after_final_response" ]; then
104+
export TRITONSERVER_DELAY_GRPC_NOTIFICATION=5000
105+
export TRITONSERVER_DELAY_RESPONSE_COMPLETE_EXEC=5000
98106
fi
99107

100-
SERVER_ARGS="--model-repository=`pwd`/models --log-verbose=1"
108+
SERVER_ARGS="--model-repository=`pwd`/models --log-verbose=2"
101109
run_server
102110
if [ "$SERVER_PID" == "0" ]; then
103111
echo -e "\n***\n*** Failed to start $SERVER\n***"
@@ -123,6 +131,23 @@ for TEST_CASE in "test_grpc_async_infer" \
123131
cat $SERVER_LOG
124132
RET=1
125133
fi
134+
135+
# Tests "test_grpc_async_infer" and "test_aio_grpc_async_infer" ends
136+
# prematurely before state is released.
137+
if [[ "$TEST_CASE" != "test_grpc_async_infer" && "$TEST_CASE" != "test_aio_grpc_async_infer" ]]; then
138+
count=$(grep -o "StateRelease" $SERVER_LOG | wc -l)
139+
state_released=${state_released:=1}
140+
if [ $count == 0 ]; then
141+
echo -e "\n***\n*** State not released by server on $TEST_CASE\n***"
142+
cat $SERVER_LOG
143+
RET=1
144+
elif [ $count -ne $state_released ]; then
145+
echo -e "\n***\n*** Unexpected states released by server on $TEST_CASE. Expected $state_released but released $count.\n***"
146+
cat $SERVER_LOG
147+
RET=1
148+
fi
149+
unset state_released
150+
fi
126151
set -e
127152

128153
kill $SERVER_PID
@@ -133,9 +158,15 @@ for TEST_CASE in "test_grpc_async_infer" \
133158
elif [ "$TEST_CASE" == "test_grpc_async_infer_response_complete_during_cancellation" ]; then
134159
unset TRITONSERVER_DELAY_GRPC_NOTIFICATION
135160
unset TRITONSERVER_DELAY_GRPC_ENQUEUE
136-
elif [ "$TEST_CASE" == "test_grpc_async_infer_cancellation_during_response_complete" ]; then
161+
elif [ "$TEST_CASE" == "test_grpc_async_infer_cancellation_before_finish_0" ]; then
137162
unset TRITONSERVER_DELAY_GRPC_NOTIFICATION
138163
unset TRITONSERVER_DELAY_RESPONSE_COMPLETION
164+
elif [ "$TEST_CASE" == "test_grpc_async_infer_cancellation_before_finish_1" ]; then
165+
unset TRITONSERVER_DELAY_GRPC_PROCESS_ENTRY
166+
unset TRITONSERVER_DELAY_RESPONSE_COMPLETION
167+
elif [ "$TEST_CASE" == "test_grpc_async_infer_cancellation_before_response_complete_and_process_after_final_response" ]; then
168+
unset TRITONSERVER_DELAY_GRPC_NOTIFICATION
169+
unset TRITONSERVER_DELAY_RESPONSE_COMPLETE_EXEC
139170
fi
140171
done
141172

src/grpc/infer_handler.cc

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
// Copyright 2023-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
//
33
// Redistribution and use in source and binary forms, with or without
44
// modification, are permitted provided that the following conditions
@@ -697,9 +697,23 @@ ModelInferHandler::Process(
697697
std::chrono::milliseconds(state->delay_process_ms_));
698698
}
699699

700+
if (is_notification) {
701+
state->context_->SetReceivedNotification(true);
702+
}
703+
700704
// Handle notification for cancellation which can be raised
701705
// asynchronously if detected on the network.
702706
if (state->IsGrpcContextCancelled()) {
707+
if (is_notification) {
708+
// Received the cancellation notification
709+
LOG_VERBOSE(1) << "Cancellation notification received for " << Name()
710+
<< ", rpc_ok=" << rpc_ok << ", context "
711+
<< state->context_->unique_id_ << " step "
712+
<< state->context_->step_ << ", state "
713+
<< state->unique_id_ << " step " << state->step_;
714+
}
715+
716+
bool skip_handle_cancellation = false;
703717
if (rpc_ok && (state->step_ == Steps::START) &&
704718
(state->context_->step_ != Steps::CANCELLED)) {
705719
#ifdef TRITON_ENABLE_TRACING
@@ -715,10 +729,16 @@ ModelInferHandler::Process(
715729
// thread, and cancellation at step START was not reproducible in a
716730
// single thread scenario.
717731
StartNewRequest();
732+
} else if (
733+
state->step_ == Steps::COMPLETE || state->step_ == Steps::FINISH) {
734+
// If the request is completed, simply ignore the cancellation.
735+
skip_handle_cancellation = true;
736+
}
737+
738+
if (!skip_handle_cancellation) {
739+
bool resume = state->context_->HandleCancellation(state, rpc_ok, Name());
740+
return resume;
718741
}
719-
bool resume = state->context_->HandleCancellation(
720-
state, rpc_ok, Name(), is_notification);
721-
return resume;
722742
}
723743

724744

@@ -1023,6 +1043,16 @@ ModelInferHandler::InferResponseComplete(
10231043
// notification.
10241044
std::lock_guard<std::recursive_mutex> lock(state->step_mtx_);
10251045

1046+
if (state->delay_response_complete_exec_ms_ != 0) {
1047+
// Will delay the Process execution of state at step ISSUED by the
1048+
// specified time. This can be used to test the flow when cancellation
1049+
// request issued for the request before InferResponseComplete.
1050+
LOG_INFO << "Delaying InferResponseComplete execution by "
1051+
<< state->delay_response_complete_exec_ms_ << " ms...";
1052+
std::this_thread::sleep_for(
1053+
std::chrono::milliseconds(state->delay_response_complete_exec_ms_));
1054+
}
1055+
10261056
// Increment the callback index if received valid 'iresponse'
10271057
if (iresponse != nullptr) {
10281058
state->cb_count_++;

src/grpc/infer_handler.h

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -730,7 +730,7 @@ class InferHandlerState {
730730
ctx_->AsyncNotifyWhenDone(notify_state_.get());
731731
}
732732

733-
void SetReceivedNotification(bool value) { received_notification_ = true; }
733+
void SetReceivedNotification(bool value) { received_notification_ = value; }
734734

735735
bool ReceivedNotification() { return received_notification_; }
736736

@@ -860,7 +860,8 @@ class InferHandlerState {
860860
std::lock_guard<std::recursive_mutex> lock(state->step_mtx_);
861861
if (state->step_ != Steps::CANCELLED &&
862862
state->step_ != Steps::COMPLETE) {
863-
LOG_VERBOSE(1) << "Issuing cancellation for " << state->unique_id_;
863+
LOG_VERBOSE(1) << "Issuing cancellation for " << state->unique_id_
864+
<< " step " << state->step_;
864865
if (state->inference_request_.get() == nullptr) {
865866
// The context might be holding some states that have
866867
// not been issued to Triton core. Need to skip calling
@@ -895,8 +896,7 @@ class InferHandlerState {
895896
// Returns whether or not to continue cycling through the gRPC
896897
// completion queue or not.
897898
bool HandleCancellation(
898-
InferHandlerStateType* state, bool rpc_ok, const std::string& name,
899-
bool is_notification)
899+
InferHandlerStateType* state, bool rpc_ok, const std::string& name)
900900
{
901901
// Check to avoid early exit in case of triton_grpc_error
902902
if (!IsCancelled()) {
@@ -908,12 +908,6 @@ class InferHandlerState {
908908
<< " step " << state->step_;
909909
return true;
910910
}
911-
if (is_notification) {
912-
LOG_VERBOSE(1) << "Cancellation notification received for " << name
913-
<< ", rpc_ok=" << rpc_ok << ", context "
914-
<< state->context_->unique_id_ << ", "
915-
<< state->unique_id_ << " step " << state->step_;
916-
}
917911

918912
if (state->step_ != Steps::CANCELLATION_ISSUED) {
919913
// If the context has not been cancelled then
@@ -934,18 +928,6 @@ class InferHandlerState {
934928
// next iteration from the completion queue which
935929
// would release the state.
936930
return true;
937-
} else if (is_notification && state->step_ == Steps::CANCELLED) {
938-
// A corner case where InferResponseComplete is called between the
939-
// cancellation reception but before the cancellation notification
940-
// thread enters Process function.
941-
// Should let the InferResponseComplete callback trigger the state
942-
// release.
943-
LOG_VERBOSE(1) << "Waiting for the state enqueued by callback to "
944-
"complete cancellation for "
945-
<< name << ", rpc_ok=" << rpc_ok << ", context "
946-
<< state->context_->unique_id_ << ", "
947-
<< state->unique_id_ << " step " << state->step_;
948-
return true;
949931
} else {
950932
// The cancellation request has been handled so the state can be
951933
// released.
@@ -1140,8 +1122,12 @@ class InferHandlerState {
11401122
delay_response_ms_ = ParseDebugVariable("TRITONSERVER_DELAY_GRPC_RESPONSE");
11411123
delay_complete_ms_ = ParseDebugVariable("TRITONSERVER_DELAY_GRPC_COMPLETE");
11421124
delay_process_ms_ = ParseDebugVariable("TRITONSERVER_DELAY_GRPC_PROCESS");
1125+
delay_process_entry_ms_ =
1126+
ParseDebugVariable("TRITONSERVER_DELAY_GRPC_PROCESS_ENTRY");
11431127
delay_notification_process_entry_ms_ =
11441128
ParseDebugVariable("TRITONSERVER_DELAY_GRPC_NOTIFICATION");
1129+
delay_response_complete_exec_ms_ =
1130+
ParseDebugVariable("TRITONSERVER_DELAY_RESPONSE_COMPLETE_EXEC");
11451131
delay_enqueue_ms_ = ParseDebugVariable("TRITONSERVER_DELAY_GRPC_ENQUEUE");
11461132
delay_response_completion_ms_ =
11471133
ParseDebugVariable("TRITONSERVER_DELAY_RESPONSE_COMPLETION");
@@ -1269,7 +1255,9 @@ class InferHandlerState {
12691255
int delay_response_ms_;
12701256
int delay_complete_ms_;
12711257
int delay_process_ms_;
1258+
int delay_process_entry_ms_;
12721259
int delay_notification_process_entry_ms_;
1260+
int delay_response_complete_exec_ms_;
12731261
int delay_enqueue_ms_;
12741262
int delay_response_completion_ms_;
12751263

@@ -1503,7 +1491,6 @@ InferHandler<
15031491
if (state->step_ == Steps::WAITING_NOTIFICATION) {
15041492
State* state_wrapper = state;
15051493
state = state_wrapper->state_ptr_;
1506-
state->context_->SetReceivedNotification(true);
15071494
is_notification = true;
15081495
LOG_VERBOSE(1) << "Received notification for " << Name() << ", "
15091496
<< state->unique_id_;
@@ -1522,14 +1509,25 @@ InferHandler<
15221509
std::this_thread::sleep_for(std::chrono::milliseconds(
15231510
state->delay_notification_process_entry_ms_));
15241511
}
1512+
} else {
1513+
if (state->delay_process_entry_ms_ != 0) {
1514+
// Will delay the entry to Process by the specified time.
1515+
LOG_INFO << "Delaying the entry to Process thread by "
1516+
<< state->delay_process_entry_ms_ << " ms...";
1517+
std::this_thread::sleep_for(
1518+
std::chrono::milliseconds(state->delay_process_entry_ms_));
1519+
}
15251520
}
1521+
15261522
LOG_VERBOSE(2) << "Grpc::CQ::Next() "
15271523
<< state->context_->DebugString(state);
15281524
if (!Process(state, ok, is_notification)) {
15291525
LOG_VERBOSE(1) << "Done for " << Name() << ", " << state->unique_id_;
15301526
state->context_->EraseState(state);
15311527
StateRelease(state);
15321528
} else {
1529+
// In non-streaming infer mode which has multiple request handlers,
1530+
// there is no guarantee state->context_ is valid beyond this line.
15331531
LOG_VERBOSE(2) << "Returning from " << Name() << ", "
15341532
<< state->unique_id_ << ", " << state->step_;
15351533
}

src/grpc/stream_infer_handler.cc

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
// Copyright 2023-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
//
33
// Redistribution and use in source and binary forms, with or without
44
// modification, are permitted provided that the following conditions
@@ -133,6 +133,9 @@ bool
133133
ModelStreamInferHandler::Process(
134134
InferHandler::State* state, bool rpc_ok, bool is_notification)
135135
{
136+
if (is_notification) {
137+
state->context_->SetReceivedNotification(true);
138+
}
136139
// Because gRPC doesn't allow concurrent writes on the
137140
// the stream we only have a single handler thread that
138141
// reads from the completion queue. Hence, cancellation
@@ -144,8 +147,16 @@ ModelStreamInferHandler::Process(
144147
if (state->context_->ReceivedNotification()) {
145148
std::lock_guard<std::recursive_mutex> lock(state->step_mtx_);
146149
if (state->IsGrpcContextCancelled()) {
147-
bool resume = state->context_->HandleCancellation(
148-
state, rpc_ok, Name(), is_notification);
150+
if (is_notification) {
151+
// This is the cancellation notification
152+
LOG_VERBOSE(1) << "Cancellation notification received for " << Name()
153+
<< ", rpc_ok=" << rpc_ok << ", context "
154+
<< state->context_->unique_id_ << " step "
155+
<< state->context_->step_ << ", state "
156+
<< state->unique_id_ << " step " << state->step_;
157+
}
158+
159+
bool resume = state->context_->HandleCancellation(state, rpc_ok, Name());
149160
return resume;
150161
} else {
151162
if (state->context_->HandleCompletion()) {

src/grpc/stream_infer_handler.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,7 @@ class ModelStreamInferHandler
106106

107107
protected:
108108
void StartNewRequest() override;
109-
bool Process(
110-
State* state, bool rpc_ok, bool is_notification = false) override;
109+
bool Process(State* state, bool rpc_ok, bool is_notification) override;
111110

112111
private:
113112
static void StreamInferResponseComplete(

0 commit comments

Comments
 (0)