
Commit c01d4cc

Revert "Revert "feat: Ensemble asynchronous callback executions (#429)" (#436)"
This reverts commit 109c69f.
1 parent 109c69f · commit c01d4cc

5 files changed: +91 -26 lines

src/constants.h

Lines changed: 1 addition & 0 deletions
@@ -62,6 +62,7 @@ constexpr char kPythonBackend[] = "python";
 
 #ifdef TRITON_ENABLE_ENSEMBLE
 constexpr char kEnsemblePlatform[] = "ensemble";
+constexpr uint64_t ENSEMBLE_CB_POOL_SIZE = 8u;
 #endif  // TRITON_ENABLE_ENSEMBLE
 
 constexpr char kTensorRTExecutionAccelerator[] = "tensorrt";

src/ensemble_scheduler/ensemble_scheduler.cc

Lines changed: 60 additions & 23 deletions
@@ -1,4 +1,4 @@
-// Copyright 2019-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright 2019-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Redistribution and use in source and binary forms, with or without
 // modification, are permitted provided that the following conditions
@@ -52,10 +52,12 @@ class RequestTracker {
   explicit RequestTracker(
       std::unique_ptr<InferenceRequest>&& request, uint64_t compute_start_ns,
       MetricModelReporter* metric_reporter,
-      InferenceStatsAggregator* stats_aggregator)
+      InferenceStatsAggregator* stats_aggregator,
+      triton::common::ThreadPool* callback_pool)
       : inflight_request_counter_(1), request_(std::move(request)),
         compute_start_ns_(compute_start_ns), metric_reporter_(metric_reporter),
-        stats_aggregator_(stats_aggregator), status_(Status::Success)
+        stats_aggregator_(stats_aggregator), status_(Status::Success),
+        callback_pool_(callback_pool)
   {
   }
 
@@ -70,6 +72,8 @@ class RequestTracker {
     return context_stats_aggregator_;
   }
 
+  triton::common::ThreadPool* CallbackPool() const { return callback_pool_; }
+
   void IncrementCounter()
   {
     std::lock_guard<std::mutex> lk(mtx_);
@@ -120,6 +124,7 @@ class RequestTracker {
   InferenceStatsAggregator* stats_aggregator_;
   InferenceStatsAggregator context_stats_aggregator_;
   Status status_;
+  triton::common::ThreadPool* const callback_pool_;
 };
 
 // Step is used as 'userp' and keeps ensemble context alive
@@ -237,7 +242,7 @@ class EnsembleContext {
       MetricModelReporter* metric_reporter,
       InferenceStatsAggregator* stats_aggregator, InferenceServer* is,
       EnsembleInfo* info, std::unique_ptr<InferenceRequest>& request,
-      cudaStream_t stream);
+      cudaStream_t stream, triton::common::ThreadPool* callback_pool);
 
   // Perform transition on 'context' state given the information of
   // 'completed_step'
@@ -326,6 +331,8 @@ class EnsembleContext {
   void CacheEnsembleTopLevelRequest(
       std::unique_ptr<InferenceResponse>& response);
 
+  triton::common::ThreadPool* CallbackPool() const { return callback_pool_; }
+
   InferenceServer* is_;
 
   EnsembleInfo* info_;
@@ -375,20 +382,26 @@ class EnsembleContext {
       TRITONSERVER_ResponseAllocator,
       decltype(&TRITONSERVER_ResponseAllocatorDelete)>
       allocator_;
+
+  // The thread pool used to execute ensemble callbacks and reduce e2e latency.
+  // The thread pool is managed by InferenceServer.
+  triton::common::ThreadPool* const callback_pool_;
 };
 
 EnsembleContext::EnsembleContext(
     MetricModelReporter* metric_reporter,
     InferenceStatsAggregator* stats_aggregator, InferenceServer* is,
     EnsembleInfo* info, std::unique_ptr<InferenceRequest>& request,
-    cudaStream_t stream)
+    cudaStream_t stream, triton::common::ThreadPool* callback_pool)
     : is_(is), info_(info), stream_(stream), inflight_step_counter_(0),
-      allocator_(nullptr, TRITONSERVER_ResponseAllocatorDelete)
+      allocator_(nullptr, TRITONSERVER_ResponseAllocatorDelete),
+      callback_pool_(callback_pool)
 {
   uint64_t compute_start_ns = 0;
   INFER_STATS_SET_TIMESTAMP(compute_start_ns);
   request_tracker_ = new RequestTracker(
-      std::move(request), compute_start_ns, metric_reporter, stats_aggregator);
+      std::move(request), compute_start_ns, metric_reporter, stats_aggregator,
+      callback_pool);
 
   auto& lrequest = request_tracker_->Request();
 
@@ -603,29 +616,52 @@ void
 EnsembleContext::RequestComplete(
     TRITONSERVER_InferenceRequest* request, const uint32_t flags, void* userp)
 {
-  if ((flags & TRITONSERVER_REQUEST_RELEASE_ALL) != 0) {
-    LOG_TRITONSERVER_ERROR(
-        TRITONSERVER_InferenceRequestDelete(request),
-        "deleting ensemble inference request");
-    auto request_tracker = reinterpret_cast<RequestTracker*>(userp);
-    if (request_tracker->DecrementCounter()) {
-      delete request_tracker;
+  auto request_tracker = reinterpret_cast<RequestTracker*>(userp);
+  auto pool = request_tracker->CallbackPool();
+  auto fn = [request, flags, request_tracker]() {
+    if ((flags & TRITONSERVER_REQUEST_RELEASE_ALL) != 0) {
+      LOG_TRITONSERVER_ERROR(
+          TRITONSERVER_InferenceRequestDelete(request),
+          "deleting ensemble inference request");
+      if (request_tracker->DecrementCounter()) {
+        delete request_tracker;
+      }
     }
+  };
+
+  // Attempt to enqueue the callback. If all workers are busy and queue is at
+  // capacity, execute the callback immediately.
+  if (pool->TaskQueueSize() < pool->Size()) {
+    pool->Enqueue(fn);
+  } else {
+    fn();
   }
 }
 
 void
 EnsembleContext::ResponseComplete(
     TRITONSERVER_InferenceResponse* response, const uint32_t flags, void* userp)
 {
-  auto step_ptr = std::unique_ptr<Step>(reinterpret_cast<Step*>(userp));
-  step_ptr->response_flags_ = flags;
-  step_ptr->response_ = response;
-
-  EnsembleContext::Proceed(step_ptr->ctx_, step_ptr);
-  // Expecting more responses
-  if ((flags & TRITONSERVER_RESPONSE_COMPLETE_FINAL) == 0) {
-    step_ptr.release();
+  auto step_raw_ptr = reinterpret_cast<Step*>(userp);
+  auto pool = step_raw_ptr->ctx_->CallbackPool();
+  auto fn = [response, flags, step_raw_ptr]() {
+    auto step_ptr = std::unique_ptr<Step>(step_raw_ptr);
+    step_ptr->response_flags_ = flags;
+    step_ptr->response_ = response;
+
+    EnsembleContext::Proceed(step_ptr->ctx_, step_ptr);
+    // Expecting more responses
+    if ((flags & TRITONSERVER_RESPONSE_COMPLETE_FINAL) == 0) {
+      step_ptr.release();
+    }
+  };
+
+  // Attempt to enqueue the callback. If all workers are busy and queue is at
+  // capacity, execute the callback immediately.
+  if (pool->TaskQueueSize() < pool->Size()) {
+    pool->Enqueue(fn);
+  } else {
+    fn();
   }
 }
 
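Both callbacks now funnel through the same guard: hand the completion work to the shared pool while it has spare capacity, otherwise run it inline so a saturated pool can never stall the completing thread. Below is a minimal standalone sketch of that dispatch pattern; it assumes only the `Enqueue`/`TaskQueueSize`/`Size` surface used above, and the `CallbackPoolLike` type and `EnqueueOrRunInline` helper are illustrative names, not part of the commit.

```cpp
#include <cstddef>
#include <functional>
#include <utility>

// Illustrative stand-in for the pool interface the callbacks rely on.
struct CallbackPoolLike {
  virtual std::size_t Size() = 0;           // number of worker threads
  virtual std::size_t TaskQueueSize() = 0;  // tasks currently waiting
  virtual void Enqueue(std::function<void()> task) = 0;
  virtual ~CallbackPoolLike() = default;
};

// Hypothetical helper mirroring the guard in RequestComplete and
// ResponseComplete: enqueue while the backlog is smaller than the worker
// count, otherwise execute the callback on the caller's thread.
inline void
EnqueueOrRunInline(CallbackPoolLike* pool, std::function<void()> fn)
{
  if (pool->TaskQueueSize() < pool->Size()) {
    pool->Enqueue(std::move(fn));
  } else {
    fn();
  }
}
```

Note that the condition compares queue depth against worker count rather than tracking idle workers exactly; it is a cheap saturation heuristic, and the inline fallback guarantees forward progress under load.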

@@ -1448,7 +1484,7 @@ EnsembleScheduler::Enqueue(std::unique_ptr<InferenceRequest>& request)
   RETURN_IF_ERROR(request->SetState(InferenceRequest::State::EXECUTING));
   std::shared_ptr<EnsembleContext> context(new EnsembleContext(
       metric_reporter_.get(), stats_aggregator_, is_, info_.get(), request,
-      stream_));
+      stream_, callback_pool_));
   EnsembleContext::Proceed(context);
   return Status::Success;
 }
@@ -1537,6 +1573,7 @@ EnsembleScheduler::EnsembleScheduler(
       info_->tensor_to_prev_step_.emplace(pair.second, step_idx);
     }
   }
+  callback_pool_ = is_->EnsembleCallbackPool();
 }
 
 EnsembleScheduler::~EnsembleScheduler()

src/ensemble_scheduler/ensemble_scheduler.h

Lines changed: 8 additions & 1 deletion
@@ -1,4 +1,4 @@
-// Copyright 2019-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright 2019-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Redistribution and use in source and binary forms, with or without
 // modification, are permitted provided that the following conditions
@@ -36,6 +36,7 @@
 #include "scheduler.h"
 #include "scheduler_utils.h"
 #include "status.h"
+#include "triton/common/thread_pool.h"
 
 #ifdef TRITON_ENABLE_GPU
 #include <cuda_runtime_api.h>
@@ -107,6 +108,8 @@ class EnsembleScheduler : public Scheduler {
   // \see Scheduler::Stop()
   void Stop() override {}
 
+  triton::common::ThreadPool* CallbackPool() const { return callback_pool_; }
+
  private:
   EnsembleScheduler(
       InferenceStatsAggregator* const stats_aggregator,
@@ -128,6 +131,10 @@ class EnsembleScheduler : public Scheduler {
   cudaStream_t stream_;
 
   std::atomic<size_t> inflight_count_;
+
+  // Fixed-size thread pool to run callbacks at end of each ensemble step.
+  // Managed by the server.
+  triton::common::ThreadPool* callback_pool_;
 };
 
 }}  // namespace triton::core

src/server.cc

Lines changed: 8 additions & 1 deletion
@@ -1,4 +1,4 @@
-// Copyright 2018-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright 2018-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Redistribution and use in source and binary forms, with or without
 // modification, are permitted provided that the following conditions
@@ -117,6 +117,13 @@ InferenceServer::InferenceServer()
 #endif  // TRITON_ENABLE_GPU
 
   inflight_request_counter_ = 0;
+
+#ifdef TRITON_ENABLE_ENSEMBLE
+  // TODO: Need to scale the thread pool size smarter, e.g. based on the
+  // instance_group count of composing models.
+  ensemble_cb_pool_.reset(
+      new triton::common::ThreadPool(ENSEMBLE_CB_POOL_SIZE));
+#endif
 }
 
 Status
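`triton::common::ThreadPool` itself is defined in the triton-inference-server/common repository, not in this diff. For readers without that code handy, a minimal fixed-size pool exposing the same surface used here (`Enqueue`, `Size`, `TaskQueueSize`) could look like the sketch below; it is an illustrative stand-in, not the actual implementation.

```cpp
#include <condition_variable>
#include <cstddef>
#include <deque>
#include <functional>
#include <mutex>
#include <thread>
#include <vector>

// Illustrative fixed-size worker pool with the interface the diff relies on.
class SimpleThreadPool {
 public:
  explicit SimpleThreadPool(std::size_t thread_count)
  {
    for (std::size_t i = 0; i < thread_count; ++i) {
      workers_.emplace_back([this] {
        for (;;) {
          std::function<void()> task;
          {
            std::unique_lock<std::mutex> lk(mtx_);
            cv_.wait(lk, [this] { return stop_ || !queue_.empty(); });
            if (stop_ && queue_.empty()) {
              return;  // drained and shutting down
            }
            task = std::move(queue_.front());
            queue_.pop_front();
          }
          task();  // run outside the lock
        }
      });
    }
  }

  ~SimpleThreadPool()
  {
    {
      std::lock_guard<std::mutex> lk(mtx_);
      stop_ = true;
    }
    cv_.notify_all();
    for (auto& w : workers_) {
      w.join();
    }
  }

  // Number of worker threads (fixed at construction).
  std::size_t Size() const { return workers_.size(); }

  // Number of tasks waiting to be picked up by a worker.
  std::size_t TaskQueueSize()
  {
    std::lock_guard<std::mutex> lk(mtx_);
    return queue_.size();
  }

  void Enqueue(std::function<void()> task)
  {
    {
      std::lock_guard<std::mutex> lk(mtx_);
      queue_.push_back(std::move(task));
    }
    cv_.notify_one();
  }

 private:
  std::vector<std::thread> workers_;
  std::deque<std::function<void()>> queue_;
  std::mutex mtx_;
  std::condition_variable cv_;
  bool stop_ = false;
};
```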

src/server.h

Lines changed: 14 additions & 1 deletion
@@ -1,4 +1,4 @@
-// Copyright 2018-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright 2018-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Redistribution and use in source and binary forms, with or without
 // modification, are permitted provided that the following conditions
@@ -332,6 +332,13 @@ class InferenceServer {
     return cache_manager_;
   }
 
+#ifdef TRITON_ENABLE_ENSEMBLE
+  triton::common::ThreadPool* EnsembleCallbackPool() const
+  {
+    return ensemble_cb_pool_.get();
+  }
+#endif  // TRITON_ENABLE_ENSEMBLE
+
  private:
   const std::string version_;
   std::string id_;
@@ -375,6 +382,12 @@ class InferenceServer {
   std::unique_ptr<ModelRepositoryManager> model_repository_manager_;
   std::shared_ptr<TritonBackendManager> backend_manager_;
   std::shared_ptr<TritonCacheManager> cache_manager_;
+
+#ifdef TRITON_ENABLE_ENSEMBLE
+  // The thread pool for all ensemble models to execute callbacks
+  // asynchronously.
+  std::unique_ptr<triton::common::ThreadPool> ensemble_cb_pool_;
+#endif  // TRITON_ENABLE_ENSEMBLE
 };
 
 }}  // namespace triton::core
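Putting the pieces together: the server owns one pool sized by `ENSEMBLE_CB_POOL_SIZE`, each `EnsembleScheduler` borrows a raw pointer to it, and every step-completion callback goes through the enqueue-or-inline guard. A hedged usage sketch, reusing the illustrative `SimpleThreadPool` from above:

```cpp
#include <cstdio>

int main()
{
  // One shared pool, sized like the new ENSEMBLE_CB_POOL_SIZE constant (8).
  SimpleThreadPool pool(8);

  auto fn = [] { std::puts("ensemble step callback"); };

  // The dispatch guard from the diff: enqueue while the backlog is below
  // the worker count, otherwise fall back to the caller's thread.
  if (pool.TaskQueueSize() < pool.Size()) {
    pool.Enqueue(fn);
  } else {
    fn();
  }

  // ~SimpleThreadPool drains any queued tasks, then joins the workers.
  return 0;
}
```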
