
Commit 1dcf5bb

Fix state transitions for re-running requests (#251)
1 parent c6c45ff · commit 1dcf5bb

File tree

3 files changed: 38 additions, 49 deletions

src/infer_request.cc

Lines changed: 37 additions & 42 deletions
@@ -106,86 +106,81 @@ InferenceRequest::InferenceRequest(
     : needs_normalization_(true), model_raw_(model),
       requested_model_version_(requested_model_version), flags_(0),
       correlation_id_(0), batch_size_(0), timeout_us_(0), collect_stats_(true),
-      state_(InferenceRequest::State::INITIALIZED), null_request_(false),
-      decrement_pending_count_(false)
+      state_(InferenceRequest::State::INITIALIZED), null_request_(false)
 {
   SetPriority(0);
 }

-InferenceRequest::~InferenceRequest()
-{
-  // If request has been enqueued but hasn't started executing by destruction
-  // time, an error occurred and the pending request count will need to be
-  // decremented.
-  DecrementPendingRequestCount();
-}
-
-
 Status
 InferenceRequest::SetState(InferenceRequest::State new_state)
 {
+  LOG_VERBOSE(1) << LogRequest() << "Setting state from " << state_ << " to "
+                 << new_state;
   // No-op if this is already the current state, or if this is a null request.
   if (new_state == state_ || null_request_) {
     return Status::Success;
   }

-  // Allow RELEASED state transition from any state for now.
-  // Not all requests will follow linear transition, such as null requests
-  // used for padding batches, and ensemble requests.
-  if (new_state == InferenceRequest::State::RELEASED) {
-    state_ = new_state;
-    return Status::Success;
-  }
-
   // Generate error when called rather than copying it into every case below.
   const auto generate_error = [&]() {
     std::stringstream ss;
     ss << LogRequest() << "Invalid request state transition from " << state_
        << " to " << new_state;
-    return Status(Status::Code::INVALID_ARG, ss.str());
+    return Status(Status::Code::INTERNAL, ss.str());
   };

   // Define state transitions
   switch (state_) {
     case InferenceRequest::State::INITIALIZED: {
-      if (new_state != InferenceRequest::State::STARTED) {
+      if (new_state == InferenceRequest::State::PENDING) {
+        IncrementPendingRequestCount();
+      } else if (new_state == InferenceRequest::State::RELEASED) {
+        // No-op when moving from initialized to released, just releasing early.
+      } else {
         return generate_error();
       }
-      state_ = new_state;
-      IncrementPendingRequestCount();
       break;
     }
-    case InferenceRequest::State::STARTED: {
-      if (new_state != InferenceRequest::State::EXECUTING) {
+    case InferenceRequest::State::PENDING: {
+      // Request may move from pending to either execution when scheduled to
+      // backend, or released early due to some error.
+      if (new_state == InferenceRequest::State::EXECUTING ||
+          new_state == InferenceRequest::State::RELEASED) {
+        DecrementPendingRequestCount();
+      } else {
+        // Unexpected state transition
         return generate_error();
       }
-      state_ = new_state;
-      DecrementPendingRequestCount();
       break;
     }
     case InferenceRequest::State::EXECUTING: {
       if (new_state != InferenceRequest::State::RELEASED) {
         return generate_error();
       }
-      state_ = new_state;
       break;
     }
     case InferenceRequest::State::RELEASED: {
-      // No state transition currently supported after release.
-      return generate_error();
+      if (new_state != InferenceRequest::State::INITIALIZED) {
+        // Only transition currently supported after release is to start over
+        // again, such as re-using request objects for multiple inferences.
+        return generate_error();
+      }
+      break;
     }
   }
+  state_ = new_state;
   return Status::Success;
 }

 void
 InferenceRequest::IncrementPendingRequestCount()
 {
 #ifdef TRITON_ENABLE_METRICS
+  // Pending request count should always be 0 or 1 per-request. If a request
+  // increments the count, it should not be incremented again until decremented.
   auto reporter = model_raw_->MetricReporter();
   if (reporter) {
     reporter->IncrementGauge(kPendingRequestMetric, 1);
-    decrement_pending_count_ = true;
   }
 #endif  // TRITON_ENABLE_METRICS
 }
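
The new transition table amounts to a small request-lifecycle state machine: INITIALIZED -> PENDING -> EXECUTING -> RELEASED, with RELEASED -> INITIALIZED closing the loop so a request object can be re-run, and with early release permitted from INITIALIZED and PENDING. Below is a minimal, self-contained C++ sketch of that table; the enum values mirror the diff, but TransitionAllowed and the driver in main are illustrative stand-ins, not code from this repository.

#include <cstddef>
#include <iostream>

enum class State { INITIALIZED, PENDING, EXECUTING, RELEASED };

// Mirrors the switch in InferenceRequest::SetState() above.
bool TransitionAllowed(State from, State to) {
  switch (from) {
    case State::INITIALIZED:
      // Enqueue normally, or release early on error.
      return to == State::PENDING || to == State::RELEASED;
    case State::PENDING:
      // Scheduled onto a backend, or released early on error.
      return to == State::EXECUTING || to == State::RELEASED;
    case State::EXECUTING:
      return to == State::RELEASED;
    case State::RELEASED:
      // The only edge out of RELEASED: start over for a re-run.
      return to == State::INITIALIZED;
  }
  return false;
}

int main() {
  // Two consecutive passes through the lifecycle, i.e. a re-run request.
  const State lifecycle[] = {
      State::INITIALIZED, State::PENDING, State::EXECUTING, State::RELEASED,
      State::INITIALIZED, State::PENDING, State::EXECUTING, State::RELEASED};
  State current = lifecycle[0];
  for (std::size_t i = 1; i < sizeof(lifecycle) / sizeof(lifecycle[0]); ++i) {
    std::cout << (TransitionAllowed(current, lifecycle[i]) ? "ok" : "INVALID")
              << '\n';
    current = lifecycle[i];
  }
  return 0;
}

Every step of the two-pass trace prints "ok", which is exactly what the old table rejected: it had no edge out of RELEASED.
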
@@ -194,13 +189,11 @@ void
 InferenceRequest::DecrementPendingRequestCount()
 {
 #ifdef TRITON_ENABLE_METRICS
-  // Only decrement if count has been incremented, and not already decremented.
-  if (decrement_pending_count_) {
-    auto reporter = model_raw_->MetricReporter();
-    if (reporter) {
-      reporter->DecrementGauge(kPendingRequestMetric, 1);
-    }
-    decrement_pending_count_ = false;
+  // Pending request count should always be 0 or 1 per-request. A request should
+  // not decrement the count unless it has already been incremented.
+  auto reporter = model_raw_->MetricReporter();
+  if (reporter) {
+    reporter->DecrementGauge(kPendingRequestMetric, 1);
   }
 #endif  // TRITON_ENABLE_METRICS
 }
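
With the destructor and the decrement_pending_count_ flag removed, the pending-count gauge is kept balanced by the state machine itself: it is incremented only on the edge into PENDING and decremented only on the edges out of it, so each request contributes at most 1 to the gauge at any time. A hedged sketch of that invariant follows; the gauge variable and Transition helper are hypothetical stand-ins for the metrics reporter, not Triton's API.

#include <atomic>
#include <cassert>
#include <iostream>

enum class State { INITIALIZED, PENDING, EXECUTING, RELEASED };

// Hypothetical stand-in for reporter->IncrementGauge(kPendingRequestMetric, 1)
// and its decrement counterpart.
std::atomic<int> pending_request_gauge{0};

void Transition(State& state, State next) {
  // The gauge is touched only on edges into and out of PENDING, so every
  // increment is paired with exactly one decrement and no per-request
  // bookkeeping flag is needed.
  if (state != State::PENDING && next == State::PENDING) {
    pending_request_gauge.fetch_add(1);
  } else if (state == State::PENDING && next != State::PENDING) {
    pending_request_gauge.fetch_sub(1);
  }
  state = next;
}

int main() {
  State s = State::INITIALIZED;
  Transition(s, State::PENDING);
  Transition(s, State::RELEASED);  // released early, e.g. on a scheduling error
  assert(pending_request_gauge.load() == 0);  // balanced on the error path too
  std::cout << "pending=" << pending_request_gauge.load() << '\n';
  return 0;
}
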
@@ -376,7 +369,7 @@ InferenceRequest::OutputBufferProperties(
 Status
 InferenceRequest::Run(std::unique_ptr<InferenceRequest>& request)
 {
-  RETURN_IF_ERROR(request->SetState(InferenceRequest::State::STARTED));
+  RETURN_IF_ERROR(request->SetState(InferenceRequest::State::PENDING));
   return request->model_raw_->Enqueue(request);
 }

@@ -849,8 +842,10 @@ InferenceRequest::PrepareForInference()
   request_start_ns_ = 0;
 #endif  // TRITON_ENABLE_STATS

-  LOG_VERBOSE(1) << LogRequest() << "prepared: " << *this;
+  // Help enforce that PrepareForInference() is called prior to Run().
+  RETURN_IF_ERROR(SetState(InferenceRequest::State::INITIALIZED));

+  LOG_VERBOSE(1) << LogRequest() << "prepared: " << *this;
   return Status::Success;
 }
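
Resetting to INITIALIZED here pairs with the RELEASED -> INITIALIZED case above: a caller that skips PrepareForInference() and calls Run() on a released request now gets an error from SetState() rather than a silently broken lifecycle. A small mock of that enforcement follows; the Status struct and free SetState() function are simplified stand-ins, not the repository's signatures.

#include <iostream>
#include <string>

struct Status {
  bool ok;
  std::string msg;
};

enum class State { INITIALIZED, PENDING, EXECUTING, RELEASED };

Status SetState(State& state, State next) {
  // Only the edges from this commit's transition table are accepted.
  const bool allowed =
      (state == State::INITIALIZED &&
       (next == State::PENDING || next == State::RELEASED)) ||
      (state == State::PENDING &&
       (next == State::EXECUTING || next == State::RELEASED)) ||
      (state == State::EXECUTING && next == State::RELEASED) ||
      (state == State::RELEASED && next == State::INITIALIZED);
  if (!allowed) {
    return {false, "invalid request state transition"};
  }
  state = next;
  return {true, ""};
}

int main() {
  State s = State::RELEASED;  // a request that already finished one inference
  // Run() without PrepareForInference() maps to RELEASED -> PENDING: rejected.
  std::cout << SetState(s, State::PENDING).ok << '\n';      // prints 0
  // PrepareForInference() maps to RELEASED -> INITIALIZED: accepted.
  std::cout << SetState(s, State::INITIALIZED).ok << '\n';  // prints 1
  // Now Run() maps to INITIALIZED -> PENDING: accepted.
  std::cout << SetState(s, State::PENDING).ok << '\n';      // prints 1
  return 0;
}
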

@@ -1580,8 +1575,8 @@ operator<<(std::ostream& out, const InferenceRequest::State& state)
       out << "INITIALIZED";
       break;
     }
-    case InferenceRequest::State::STARTED: {
-      out << "STARTED";
+    case InferenceRequest::State::PENDING: {
+      out << "PENDING";
       break;
     }
     case InferenceRequest::State::EXECUTING: {

src/infer_request.h

Lines changed: 1 addition & 5 deletions
@@ -63,7 +63,7 @@ class InferenceRequest {
     INITIALIZED,

     // The request has been enqueued, but is not yet executing.
-    STARTED,
+    PENDING,

     // The request has been picked up by a backend model instance for execution,
     // but hasn't been released yet.
@@ -291,7 +291,6 @@ class InferenceRequest {
       const int64_t requested_model_version);

   InferenceRequest(Model* model, const int64_t requested_model_version);
-  ~InferenceRequest();

   const std::string& ModelName() const;
   int64_t RequestedModelVersion() const { return requested_model_version_; }
@@ -799,9 +798,6 @@ class InferenceRequest {
   // Whether this is a null request used for direct sequence batch padding or
   // not.
   bool null_request_;
-  // Catch-all to correctly decrement pending count if needed on destruction
-  // if request doesn't follow normal execution path (error, unused, ensembles)
-  bool decrement_pending_count_;
 };

 std::ostream& operator<<(std::ostream& out, const InferenceRequest& request);

src/test/response_cache_test.cc

Lines changed: 0 additions & 2 deletions
@@ -66,8 +66,6 @@ InferenceRequest::InferenceRequest(
   response_factory_.reset(new InferenceResponseFactory());
 }

-InferenceRequest::~InferenceRequest() {}
-
 InferenceRequest::Input::Input(
     const std::string& name, const inference::DataType datatype,
     const int64_t* shape, const uint64_t dim_count)
