
fix: Revert async execution of ensemble model ResponseComplete callback #435

Closed · wants to merge 2 commits
2 changes: 1 addition & 1 deletion src/constants.h
@@ -62,7 +62,7 @@ constexpr char kPythonBackend[] = "python";

#ifdef TRITON_ENABLE_ENSEMBLE
constexpr char kEnsemblePlatform[] = "ensemble";
constexpr uint64_t ENSEMBLE_CB_POOL_SIZE = 8u;
Contributor Author

Decrease the thread pool size because fewer callbacks will run asynchronously.

Contributor

why?

Contributor

Do you even need this constant?

Contributor Author (@yinggeh, Apr 29, 2025)

Yes. The RequestComplete callback still uses the thread pool to run asynchronously, which benefits ensemble model throughput.

Contributor Author

void
EnsembleContext::RequestComplete(
    TRITONSERVER_InferenceRequest* request, const uint32_t flags, void* userp)
{
  auto request_tracker = reinterpret_cast<RequestTracker*>(userp);
  auto pool = request_tracker->CallbackPool();
  auto fn = [request, flags, request_tracker]() {
    if ((flags & TRITONSERVER_REQUEST_RELEASE_ALL) != 0) {
      LOG_TRITONSERVER_ERROR(
          TRITONSERVER_InferenceRequestDelete(request),
          "deleting ensemble inference request");
      if (request_tracker->DecrementCounter()) {
        delete request_tracker;
      }
    }
  };
  // Attempt to enqueue the callback. If all workers are busy and queue is at
  // capacity, execute the callback immediately.
  if (pool->TaskQueueSize() < pool->Size()) {
    pool->Enqueue(fn);
  } else {
    fn();
  }
}

Contributor

How was this number picked, and why this number?

Contributor Author

The previous pool size of 8 was explained in #429 (comment).
In a non-decoupled ensemble model, this PR removes half of the async callbacks, so I reduced the pool size by half, to 4.

Contributor

My concern is how general those experiments were.

Contributor Author

Could you elaborate? If the thread pool queue is full, the new task executes synchronously as before to avoid delay. See:

  // Attempt to enqueue the callback. If all workers are busy and queue is at
  // capacity, execute the callback immediately.
  if (pool->TaskQueueSize() < pool->Size()) {
    pool->Enqueue(fn);
  } else {
    fn();
  }
}
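
Below is a minimal, self-contained sketch of that "enqueue if there is room, otherwise run inline" dispatch. It is not Triton's triton::common::ThreadPool; TinyPool and DispatchCallback are illustrative names, and the sketch only assumes the three operations the callback above actually uses: Size(), TaskQueueSize(), and Enqueue().

// Hypothetical stand-in for the callback thread pool; illustration only.
#include <condition_variable>
#include <functional>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

class TinyPool {
 public:
  explicit TinyPool(size_t size) : size_(size)
  {
    for (size_t i = 0; i < size_; ++i) {
      workers_.emplace_back([this] { Work(); });
    }
  }
  ~TinyPool()
  {
    {
      std::lock_guard<std::mutex> lk(mu_);
      stop_ = true;
    }
    cv_.notify_all();
    for (auto& w : workers_) {
      w.join();
    }
  }
  size_t Size() const { return size_; }
  size_t TaskQueueSize()
  {
    std::lock_guard<std::mutex> lk(mu_);
    return tasks_.size();
  }
  void Enqueue(std::function<void()> fn)
  {
    {
      std::lock_guard<std::mutex> lk(mu_);
      tasks_.push(std::move(fn));
    }
    cv_.notify_one();
  }

 private:
  void Work()
  {
    for (;;) {
      std::function<void()> fn;
      {
        std::unique_lock<std::mutex> lk(mu_);
        cv_.wait(lk, [this] { return stop_ || !tasks_.empty(); });
        if (stop_ && tasks_.empty()) {
          return;
        }
        fn = std::move(tasks_.front());
        tasks_.pop();
      }
      fn();
    }
  }

  const size_t size_;
  std::vector<std::thread> workers_;
  std::queue<std::function<void()>> tasks_;
  std::mutex mu_;
  std::condition_variable cv_;
  bool stop_ = false;
};

// Mirrors the dispatch in the callback: bound the backlog to the number of
// workers; if the queue is already that deep, run the task inline so the
// caller is never delayed by more than the callback itself.
void
DispatchCallback(TinyPool* pool, std::function<void()> fn)
{
  if (pool->TaskQueueSize() < pool->Size()) {
    pool->Enqueue(std::move(fn));
  } else {
    fn();  // fall back to synchronous execution
  }
}

Bounding the backlog to the pool size keeps at most one pending task per worker, so the asynchronous path only helps when workers are actually available.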

constexpr uint64_t ENSEMBLE_CB_POOL_SIZE = 4u;
#endif // TRITON_ENABLE_ENSEMBLE

constexpr char kTensorRTExecutionAccelerator[] = "tensorrt";
37 changes: 9 additions & 28 deletions src/ensemble_scheduler/ensemble_scheduler.cc
@@ -331,8 +331,6 @@ class EnsembleContext {
void CacheEnsembleTopLevelRequest(
std::unique_ptr<InferenceResponse>& response);

triton::common::ThreadPool* CallbackPool() const { return callback_pool_; }

InferenceServer* is_;

EnsembleInfo* info_;
@@ -382,10 +380,6 @@ class EnsembleContext {
TRITONSERVER_ResponseAllocator,
decltype(&TRITONSERVER_ResponseAllocatorDelete)>
allocator_;

// The thread pool used to execute ensemble callbacks and reduce e2e latency.
// The thread pool is managed by InferenceServer.
triton::common::ThreadPool* const callback_pool_;
};

EnsembleContext::EnsembleContext(
@@ -394,8 +388,7 @@ EnsembleContext::EnsembleContext(
EnsembleInfo* info, std::unique_ptr<InferenceRequest>& request,
cudaStream_t stream, triton::common::ThreadPool* callback_pool)
: is_(is), info_(info), stream_(stream), inflight_step_counter_(0),
allocator_(nullptr, TRITONSERVER_ResponseAllocatorDelete),
callback_pool_(callback_pool)
allocator_(nullptr, TRITONSERVER_ResponseAllocatorDelete)
{
uint64_t compute_start_ns = 0;
INFER_STATS_SET_TIMESTAMP(compute_start_ns);
@@ -642,26 +635,14 @@ void
EnsembleContext::ResponseComplete(
TRITONSERVER_InferenceResponse* response, const uint32_t flags, void* userp)
{
auto step_raw_ptr = reinterpret_cast<Step*>(userp);
auto pool = step_raw_ptr->ctx_->CallbackPool();
auto fn = [response, flags, step_raw_ptr]() {
auto step_ptr = std::unique_ptr<Step>(step_raw_ptr);
step_ptr->response_flags_ = flags;
step_ptr->response_ = response;

EnsembleContext::Proceed(step_ptr->ctx_, step_ptr);
// Expecting more responses
if ((flags & TRITONSERVER_RESPONSE_COMPLETE_FINAL) == 0) {
step_ptr.release();
}
};

// Attempt to enqueue the callback. If all workers are busy and queue is at
// capacity, execute the callback immediately.
if (pool->TaskQueueSize() < pool->Size()) {
pool->Enqueue(fn);
} else {
fn();
auto step_ptr = std::unique_ptr<Step>(reinterpret_cast<Step*>(userp));
step_ptr->response_flags_ = flags;
step_ptr->response_ = response;

EnsembleContext::Proceed(step_ptr->ctx_, step_ptr);
// Expecting more responses
if ((flags & TRITONSERVER_RESPONSE_COMPLETE_FINAL) == 0) {
step_ptr.release();
}
}
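
For readers puzzled by the step_ptr.release() call above, here is a minimal, self-contained sketch of the ownership pattern, under the assumption that it behaves like the Triton code: wrap the raw pointer in a unique_ptr so the step is deleted exactly once, on the callback carrying the final flag, and release it on every non-final callback. StepLike, OnResponse, and RESPONSE_COMPLETE_FINAL are hypothetical stand-ins, not Triton types.

#include <cstdint>
#include <iostream>
#include <memory>

// Stand-in for TRITONSERVER_RESPONSE_COMPLETE_FINAL; illustration only.
constexpr uint32_t RESPONSE_COMPLETE_FINAL = 1u;

struct StepLike {
  ~StepLike() { std::cout << "step destroyed\n"; }
};

void
OnResponse(void* userp, uint32_t flags)
{
  // Take ownership for the duration of this callback.
  auto step = std::unique_ptr<StepLike>(reinterpret_cast<StepLike*>(userp));
  // ... process the response ...
  if ((flags & RESPONSE_COMPLETE_FINAL) == 0) {
    // More responses are expected for this step; keep it alive.
    step.release();
  }
  // On the final response the unique_ptr goes out of scope and deletes it.
}

int
main()
{
  auto* step = new StepLike();
  OnResponse(step, 0);                        // non-final: step survives
  OnResponse(step, RESPONSE_COMPLETE_FINAL);  // final: step is destroyed here
}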
