
Commit b354d4d

feat: Ensemble async callback execution (rework) (#438)
1 parent 2acb246 commit b354d4d

File tree

5 files changed: +124 -31 lines changed

  src/constants.h
  src/ensemble_scheduler/ensemble_scheduler.cc
  src/ensemble_scheduler/ensemble_scheduler.h
  src/server.cc
  src/server.h

src/constants.h

Lines changed: 1 addition & 0 deletions

@@ -62,6 +62,7 @@ constexpr char kPythonBackend[] = "python";
 
 #ifdef TRITON_ENABLE_ENSEMBLE
 constexpr char kEnsemblePlatform[] = "ensemble";
+constexpr uint64_t ENSEMBLE_CB_POOL_SIZE = 8u;
 #endif  // TRITON_ENABLE_ENSEMBLE
 
 constexpr char kTensorRTExecutionAccelerator[] = "tensorrt";

src/ensemble_scheduler/ensemble_scheduler.cc

Lines changed: 93 additions & 28 deletions

@@ -1,4 +1,4 @@
-// Copyright 2019-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright 2019-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Redistribution and use in source and binary forms, with or without
 // modification, are permitted provided that the following conditions
@@ -45,17 +45,42 @@ class EnsembleContext;
 
 using IterationCount = size_t;
 
+// Check if the model is configured to preserve the order of responses.
+// This is critical for async execution of ResponseComplete callbacks.
+inline bool
+preserve_responses_order(const inference::ModelConfig& config)
+{
+  uint64_t total_instance_groups = 0;
+  for (const auto& group : config.instance_group()) {
+    total_instance_groups += group.count();
+  }
+
+  // Case 1: Sequence batching is enabled
+  // Case 2: Dynamic batching is disabled and there is only one instance group
+  // Case 3: Dynamic batching is enabled and preserve_ordering is true
+  // Case 4: Model transaction policy is decoupled (if the final response
+  // callback is not executed in the last step, the RequestTracker object will
+  // be freed prematurely, leading to a segmentation fault)
+  return config.has_sequence_batching() ||
+         (!config.has_dynamic_batching() && total_instance_groups <= 1) ||
+         (config.has_dynamic_batching() &&
+          config.dynamic_batching().preserve_ordering()) ||
+         config.model_transaction_policy().decoupled();
+}
+
 // Request tracker is passed as 'userp' in RequestRelease function and used
 // to manage the lifecycle of the ensemble request
 class RequestTracker {
  public:
   explicit RequestTracker(
       std::unique_ptr<InferenceRequest>&& request, uint64_t compute_start_ns,
       MetricModelReporter* metric_reporter,
-      InferenceStatsAggregator* stats_aggregator)
+      InferenceStatsAggregator* stats_aggregator,
+      triton::common::ThreadPool* callback_pool)
       : inflight_request_counter_(1), request_(std::move(request)),
         compute_start_ns_(compute_start_ns), metric_reporter_(metric_reporter),
-        stats_aggregator_(stats_aggregator), status_(Status::Success)
+        stats_aggregator_(stats_aggregator), status_(Status::Success),
+        callback_pool_(callback_pool)
   {
   }
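For illustration, here is a minimal sketch of how the new preserve_responses_order() helper classifies a config. This is a hypothetical standalone snippet, not part of the commit; it assumes the generated protobuf accessors of inference::ModelConfig (e.g. add_instance_group(), mutable_dynamic_batching()) and the assumed header path below.

// Hypothetical usage sketch for the preserve_responses_order() helper above;
// not part of this commit. Assumes the generated model_config protobuf API.
#include <iostream>

#include "model_config.pb.h"  // assumed header for inference::ModelConfig

int
main()
{
  inference::ModelConfig config;

  // Two instances, dynamic batching enabled, preserve_ordering unset:
  // none of Cases 1-4 apply, so callbacks may be executed asynchronously.
  config.add_instance_group()->set_count(2);
  config.mutable_dynamic_batching();
  std::cout << preserve_responses_order(config) << "\n";  // prints 0

  // Case 3: preserve_ordering requested, so callbacks must stay synchronous.
  config.mutable_dynamic_batching()->set_preserve_ordering(true);
  std::cout << preserve_responses_order(config) << "\n";  // prints 1

  return 0;
}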
@@ -70,6 +95,8 @@ class RequestTracker {
     return context_stats_aggregator_;
   }
 
+  triton::common::ThreadPool* CallbackPool() const { return callback_pool_; }
+
   void IncrementCounter()
   {
     std::lock_guard<std::mutex> lk(mtx_);
@@ -120,6 +147,7 @@ class RequestTracker {
   InferenceStatsAggregator* stats_aggregator_;
   InferenceStatsAggregator context_stats_aggregator_;
   Status status_;
+  triton::common::ThreadPool* const callback_pool_;
 };
 
 // Step is used as 'userp' and keeps ensemble context alive
@@ -129,9 +157,9 @@ class RequestTracker {
 struct Step {
   Step(
       size_t step_idx, const InferenceRequest::SequenceId& correlation_id,
-      uint32_t flags)
+      uint32_t flags, bool preserve_responses_order)
       : correlation_id_(correlation_id), flags_(flags), response_flags_(0),
-        step_idx_(step_idx)
+        preserve_responses_order_(preserve_responses_order), step_idx_(step_idx)
   {
   }
 
@@ -154,7 +182,7 @@ struct Step {
   // returning from the callback.
   uint32_t response_flags_;
   TRITONSERVER_InferenceResponse* response_;
-
+  const bool preserve_responses_order_;
 
   size_t step_idx_;
 };
@@ -237,7 +265,7 @@ class EnsembleContext {
       MetricModelReporter* metric_reporter,
       InferenceStatsAggregator* stats_aggregator, InferenceServer* is,
       EnsembleInfo* info, std::unique_ptr<InferenceRequest>& request,
-      cudaStream_t stream);
+      cudaStream_t stream, triton::common::ThreadPool* callback_pool);
 
   // Perform transition on 'context' state given the information of
   // 'completed_step'
@@ -326,6 +354,8 @@ class EnsembleContext {
   void CacheEnsembleTopLevelRequest(
       std::unique_ptr<InferenceResponse>& response);
 
+  triton::common::ThreadPool* CallbackPool() const { return callback_pool_; }
+
   InferenceServer* is_;
 
   EnsembleInfo* info_;
@@ -375,20 +405,26 @@ class EnsembleContext {
       TRITONSERVER_ResponseAllocator,
       decltype(&TRITONSERVER_ResponseAllocatorDelete)>
       allocator_;
+
+  // The thread pool used to execute ensemble callbacks and reduce e2e latency.
+  // The thread pool is managed by InferenceServer.
+  triton::common::ThreadPool* const callback_pool_;
 };
 
 EnsembleContext::EnsembleContext(
     MetricModelReporter* metric_reporter,
     InferenceStatsAggregator* stats_aggregator, InferenceServer* is,
     EnsembleInfo* info, std::unique_ptr<InferenceRequest>& request,
-    cudaStream_t stream)
+    cudaStream_t stream, triton::common::ThreadPool* callback_pool)
     : is_(is), info_(info), stream_(stream), inflight_step_counter_(0),
-      allocator_(nullptr, TRITONSERVER_ResponseAllocatorDelete)
+      allocator_(nullptr, TRITONSERVER_ResponseAllocatorDelete),
+      callback_pool_(callback_pool)
 {
   uint64_t compute_start_ns = 0;
   INFER_STATS_SET_TIMESTAMP(compute_start_ns);
   request_tracker_ = new RequestTracker(
-      std::move(request), compute_start_ns, metric_reporter, stats_aggregator);
+      std::move(request), compute_start_ns, metric_reporter, stats_aggregator,
+      callback_pool);
 
   auto& lrequest = request_tracker_->Request();
 
@@ -603,29 +639,57 @@ void
 EnsembleContext::RequestComplete(
     TRITONSERVER_InferenceRequest* request, const uint32_t flags, void* userp)
 {
-  if ((flags & TRITONSERVER_REQUEST_RELEASE_ALL) != 0) {
-    LOG_TRITONSERVER_ERROR(
-        TRITONSERVER_InferenceRequestDelete(request),
-        "deleting ensemble inference request");
-    auto request_tracker = reinterpret_cast<RequestTracker*>(userp);
-    if (request_tracker->DecrementCounter()) {
-      delete request_tracker;
+  auto request_tracker = reinterpret_cast<RequestTracker*>(userp);
+  auto pool = request_tracker->CallbackPool();
+  auto fn = [request, flags, request_tracker]() {
+    if ((flags & TRITONSERVER_REQUEST_RELEASE_ALL) != 0) {
+      LOG_TRITONSERVER_ERROR(
+          TRITONSERVER_InferenceRequestDelete(request),
+          "deleting ensemble inference request");
+      if (request_tracker->DecrementCounter()) {
+        delete request_tracker;
+      }
     }
+  };
+
+  // Attempt to enqueue the callback. If all workers are busy and the queue is
+  // at capacity, execute the callback immediately in the current thread.
+  if (pool->TaskQueueSize() < pool->Size()) {
+    pool->Enqueue(fn);
+  } else {
+    fn();
   }
 }
 
 void
 EnsembleContext::ResponseComplete(
     TRITONSERVER_InferenceResponse* response, const uint32_t flags, void* userp)
 {
-  auto step_ptr = std::unique_ptr<Step>(reinterpret_cast<Step*>(userp));
-  step_ptr->response_flags_ = flags;
-  step_ptr->response_ = response;
-
-  EnsembleContext::Proceed(step_ptr->ctx_, step_ptr);
-  // Expecting more responses
-  if ((flags & TRITONSERVER_RESPONSE_COMPLETE_FINAL) == 0) {
-    step_ptr.release();
+  auto step_raw_ptr = reinterpret_cast<Step*>(userp);
+  auto pool = step_raw_ptr->ctx_->CallbackPool();
+  auto fn = [response, flags, step_raw_ptr]() {
+    auto step_ptr = std::unique_ptr<Step>(step_raw_ptr);
+    step_ptr->response_flags_ = flags;
+    step_ptr->response_ = response;
+
+    EnsembleContext::Proceed(step_ptr->ctx_, step_ptr);
+    // Expecting more responses
+    if ((flags & TRITONSERVER_RESPONSE_COMPLETE_FINAL) == 0) {
+      step_ptr.release();
+    }
+  };
+
+  // Attempt to enqueue the callback. If all workers are busy and the queue is
+  // at capacity, execute the callback immediately in the current thread.
+  // Note: The async callback optimization does not guarantee the order of
+  // responses and exploits cases where responses can be out of order. For
+  // models required to preserve the order of responses, the response callbacks
+  // must be executed synchronously in the same thread.
+  if (!step_raw_ptr->preserve_responses_order_ &&
+      pool->TaskQueueSize() < pool->Size()) {
+    pool->Enqueue(fn);
+  } else {
+    fn();
   }
 }
 
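Both callbacks now follow the same dispatch rule: hand the work to the shared pool only while it still has spare queue capacity, otherwise run it inline so the calling thread throttles itself instead of growing an unbounded backlog. Below is a self-contained sketch of that enqueue-or-run-inline pattern; SimpleCallbackPool and dispatch_callback are illustrative stand-ins, not triton::common::ThreadPool.

// Illustrative stand-in for the enqueue-or-run-inline pattern used by
// RequestComplete/ResponseComplete. Not Triton code; all names are
// hypothetical.
#include <condition_variable>
#include <functional>
#include <iostream>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

class SimpleCallbackPool {
 public:
  explicit SimpleCallbackPool(size_t size) : size_(size)
  {
    for (size_t i = 0; i < size_; ++i) {
      workers_.emplace_back([this] {
        for (;;) {
          std::function<void()> task;
          {
            std::unique_lock<std::mutex> lk(mtx_);
            cv_.wait(lk, [this] { return stop_ || !tasks_.empty(); });
            if (stop_ && tasks_.empty()) {
              return;
            }
            task = std::move(tasks_.front());
            tasks_.pop();
          }
          task();
        }
      });
    }
  }

  ~SimpleCallbackPool()
  {
    {
      std::lock_guard<std::mutex> lk(mtx_);
      stop_ = true;
    }
    cv_.notify_all();
    for (auto& w : workers_) {
      w.join();
    }
  }

  size_t Size() const { return size_; }

  size_t TaskQueueSize()
  {
    std::lock_guard<std::mutex> lk(mtx_);
    return tasks_.size();
  }

  void Enqueue(std::function<void()> fn)
  {
    {
      std::lock_guard<std::mutex> lk(mtx_);
      tasks_.push(std::move(fn));
    }
    cv_.notify_one();
  }

 private:
  const size_t size_;
  std::vector<std::thread> workers_;
  std::queue<std::function<void()>> tasks_;
  std::mutex mtx_;
  std::condition_variable cv_;
  bool stop_ = false;
};

// Enqueue when the pool has spare queue capacity; otherwise run inline so the
// producing thread is throttled instead of growing the queue without bound.
void
dispatch_callback(SimpleCallbackPool* pool, std::function<void()> fn)
{
  if (pool->TaskQueueSize() < pool->Size()) {
    pool->Enqueue(std::move(fn));
  } else {
    fn();
  }
}

int
main()
{
  SimpleCallbackPool pool(8 /* mirrors ENSEMBLE_CB_POOL_SIZE */);
  for (int i = 0; i < 32; ++i) {
    dispatch_callback(&pool, [i] { std::cout << "callback " << i << "\n"; });
  }
  return 0;
}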
@@ -971,8 +1035,8 @@ EnsembleContext::InitStep(
   for (const auto& pair : istep.output_to_tensor_) {
     irequest->AddOriginalRequestedOutput(pair.first);
   }
-
-  step->reset(new Step(step_idx, correlation_id, flags));
+  const bool preserve_order = preserve_responses_order(model->Config());
+  step->reset(new Step(step_idx, correlation_id, flags, preserve_order));
 
   irequest->SetId(request_id_);
   irequest->SetCorrelationId(correlation_id);
@@ -1448,7 +1512,7 @@ EnsembleScheduler::Enqueue(std::unique_ptr<InferenceRequest>& request)
   RETURN_IF_ERROR(request->SetState(InferenceRequest::State::EXECUTING));
   std::shared_ptr<EnsembleContext> context(new EnsembleContext(
       metric_reporter_.get(), stats_aggregator_, is_, info_.get(), request,
-      stream_));
+      stream_, callback_pool_));
   EnsembleContext::Proceed(context);
   return Status::Success;
 }
@@ -1537,6 +1601,7 @@ EnsembleScheduler::EnsembleScheduler(
       info_->tensor_to_prev_step_.emplace(pair.second, step_idx);
     }
   }
+  callback_pool_ = is_->EnsembleCallbackPool();
 }
 
 EnsembleScheduler::~EnsembleScheduler()

src/ensemble_scheduler/ensemble_scheduler.h

Lines changed: 8 additions & 1 deletion

@@ -1,4 +1,4 @@
-// Copyright 2019-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright 2019-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Redistribution and use in source and binary forms, with or without
 // modification, are permitted provided that the following conditions
@@ -36,6 +36,7 @@
 #include "scheduler.h"
 #include "scheduler_utils.h"
 #include "status.h"
+#include "triton/common/thread_pool.h"
 
 #ifdef TRITON_ENABLE_GPU
 #include <cuda_runtime_api.h>
@@ -107,6 +108,8 @@ class EnsembleScheduler : public Scheduler {
   // \see Scheduler::Stop()
   void Stop() override {}
 
+  triton::common::ThreadPool* CallbackPool() const { return callback_pool_; }
+
  private:
   EnsembleScheduler(
       InferenceStatsAggregator* const stats_aggregator,
@@ -128,6 +131,10 @@ class EnsembleScheduler : public Scheduler {
   cudaStream_t stream_;
 
   std::atomic<size_t> inflight_count_;
+
+  // Fixed-size thread pool to run callbacks at the end of each ensemble step.
+  // Managed by the server.
+  triton::common::ThreadPool* callback_pool_;
 };
 
 }}  // namespace triton::core

src/server.cc

Lines changed: 8 additions & 1 deletion

@@ -1,4 +1,4 @@
-// Copyright 2018-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright 2018-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Redistribution and use in source and binary forms, with or without
 // modification, are permitted provided that the following conditions
@@ -117,6 +117,13 @@ InferenceServer::InferenceServer()
 #endif  // TRITON_ENABLE_GPU
 
   inflight_request_counter_ = 0;
+
+#ifdef TRITON_ENABLE_ENSEMBLE
+  // TODO: Need to scale the thread pool size smarter, e.g. based on the
+  // instance_group count of composing models.
+  ensemble_cb_pool_.reset(
+      new triton::common::ThreadPool(ENSEMBLE_CB_POOL_SIZE));
+#endif
 }
 
 Status
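The pool is created once per server and fixed at ENSEMBLE_CB_POOL_SIZE workers. As a purely illustrative reading of the TODO above, one possible sizing heuristic could clamp the pool to the total instance count of the composing models; the helper below is hypothetical and not part of this commit.

// Hypothetical sizing heuristic sketched from the TODO above; not part of the
// commit. 'total_composing_instances' would need to be derived from the
// instance_group counts of the ensemble's composing models.
#include <algorithm>
#include <cstdint>

uint64_t
ChooseEnsembleCallbackPoolSize(uint64_t total_composing_instances)
{
  constexpr uint64_t kMinPoolSize = 8;   // matches ENSEMBLE_CB_POOL_SIZE today
  constexpr uint64_t kMaxPoolSize = 64;  // arbitrary illustrative upper bound
  return std::clamp(total_composing_instances, kMinPoolSize, kMaxPoolSize);
}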

src/server.h

Lines changed: 14 additions & 1 deletion

@@ -1,4 +1,4 @@
-// Copyright 2018-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright 2018-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Redistribution and use in source and binary forms, with or without
 // modification, are permitted provided that the following conditions
@@ -332,6 +332,13 @@ class InferenceServer {
     return cache_manager_;
   }
 
+#ifdef TRITON_ENABLE_ENSEMBLE
+  triton::common::ThreadPool* EnsembleCallbackPool() const
+  {
+    return ensemble_cb_pool_.get();
+  }
+#endif  // TRITON_ENABLE_ENSEMBLE
+
  private:
   const std::string version_;
   std::string id_;
@@ -375,6 +382,12 @@ class InferenceServer {
   std::unique_ptr<ModelRepositoryManager> model_repository_manager_;
   std::shared_ptr<TritonBackendManager> backend_manager_;
   std::shared_ptr<TritonCacheManager> cache_manager_;
+
+#ifdef TRITON_ENABLE_ENSEMBLE
+  // The thread pool for all ensemble models to execute callbacks
+  // asynchronously.
+  std::unique_ptr<triton::common::ThreadPool> ensemble_cb_pool_;
+#endif  // TRITON_ENABLE_ENSEMBLE
 };
 
 }}  // namespace triton::core
