
Commit 109c69f

Revert "feat: Ensemble asynchronous callback executions (#429)" (#436)
This reverts commit 56e97eb.
1 parent 6fe7384 commit 109c69f

File tree

5 files changed: +26 -91 lines

src/constants.h

Lines changed: 0 additions & 1 deletion

@@ -62,7 +62,6 @@ constexpr char kPythonBackend[] = "python";
 
 #ifdef TRITON_ENABLE_ENSEMBLE
 constexpr char kEnsemblePlatform[] = "ensemble";
-constexpr uint64_t ENSEMBLE_CB_POOL_SIZE = 8u;
 #endif  // TRITON_ENABLE_ENSEMBLE
 
 constexpr char kTensorRTExecutionAccelerator[] = "tensorrt";

src/ensemble_scheduler/ensemble_scheduler.cc

Lines changed: 23 additions & 60 deletions

@@ -1,4 +1,4 @@
-// Copyright 2019-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright 2019-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Redistribution and use in source and binary forms, with or without
 // modification, are permitted provided that the following conditions
@@ -52,12 +52,10 @@ class RequestTracker {
   explicit RequestTracker(
       std::unique_ptr<InferenceRequest>&& request, uint64_t compute_start_ns,
       MetricModelReporter* metric_reporter,
-      InferenceStatsAggregator* stats_aggregator,
-      triton::common::ThreadPool* callback_pool)
+      InferenceStatsAggregator* stats_aggregator)
       : inflight_request_counter_(1), request_(std::move(request)),
         compute_start_ns_(compute_start_ns), metric_reporter_(metric_reporter),
-        stats_aggregator_(stats_aggregator), status_(Status::Success),
-        callback_pool_(callback_pool)
+        stats_aggregator_(stats_aggregator), status_(Status::Success)
   {
   }
 
@@ -72,8 +70,6 @@ class RequestTracker {
     return context_stats_aggregator_;
   }
 
-  triton::common::ThreadPool* CallbackPool() const { return callback_pool_; }
-
   void IncrementCounter()
   {
     std::lock_guard<std::mutex> lk(mtx_);
@@ -124,7 +120,6 @@ class RequestTracker {
   InferenceStatsAggregator* stats_aggregator_;
   InferenceStatsAggregator context_stats_aggregator_;
   Status status_;
-  triton::common::ThreadPool* const callback_pool_;
 };
 
 // Step is used as 'userp' and keeps ensemble context alive
@@ -242,7 +237,7 @@ class EnsembleContext {
       MetricModelReporter* metric_reporter,
       InferenceStatsAggregator* stats_aggregator, InferenceServer* is,
       EnsembleInfo* info, std::unique_ptr<InferenceRequest>& request,
-      cudaStream_t stream, triton::common::ThreadPool* callback_pool);
+      cudaStream_t stream);
 
   // Perform transition on 'context' state given the information of
   // 'completed_step'
@@ -331,8 +326,6 @@ class EnsembleContext {
   void CacheEnsembleTopLevelRequest(
       std::unique_ptr<InferenceResponse>& response);
 
-  triton::common::ThreadPool* CallbackPool() const { return callback_pool_; }
-
   InferenceServer* is_;
 
   EnsembleInfo* info_;
@@ -382,26 +375,20 @@ class EnsembleContext {
       TRITONSERVER_ResponseAllocator,
       decltype(&TRITONSERVER_ResponseAllocatorDelete)>
       allocator_;
-
-  // The thread pool used to execute ensemble callbacks and reduce e2e latency.
-  // The thread pool is managed by InferenceServer.
-  triton::common::ThreadPool* const callback_pool_;
 };
 
 EnsembleContext::EnsembleContext(
     MetricModelReporter* metric_reporter,
     InferenceStatsAggregator* stats_aggregator, InferenceServer* is,
     EnsembleInfo* info, std::unique_ptr<InferenceRequest>& request,
-    cudaStream_t stream, triton::common::ThreadPool* callback_pool)
+    cudaStream_t stream)
     : is_(is), info_(info), stream_(stream), inflight_step_counter_(0),
-      allocator_(nullptr, TRITONSERVER_ResponseAllocatorDelete),
-      callback_pool_(callback_pool)
+      allocator_(nullptr, TRITONSERVER_ResponseAllocatorDelete)
 {
   uint64_t compute_start_ns = 0;
   INFER_STATS_SET_TIMESTAMP(compute_start_ns);
   request_tracker_ = new RequestTracker(
-      std::move(request), compute_start_ns, metric_reporter, stats_aggregator,
-      callback_pool);
+      std::move(request), compute_start_ns, metric_reporter, stats_aggregator);
 
   auto& lrequest = request_tracker_->Request();
 
@@ -616,52 +603,29 @@ void
 EnsembleContext::RequestComplete(
     TRITONSERVER_InferenceRequest* request, const uint32_t flags, void* userp)
 {
-  auto request_tracker = reinterpret_cast<RequestTracker*>(userp);
-  auto pool = request_tracker->CallbackPool();
-  auto fn = [request, flags, request_tracker]() {
-    if ((flags & TRITONSERVER_REQUEST_RELEASE_ALL) != 0) {
-      LOG_TRITONSERVER_ERROR(
-          TRITONSERVER_InferenceRequestDelete(request),
-          "deleting ensemble inference request");
-      if (request_tracker->DecrementCounter()) {
-        delete request_tracker;
-      }
+  if ((flags & TRITONSERVER_REQUEST_RELEASE_ALL) != 0) {
+    LOG_TRITONSERVER_ERROR(
+        TRITONSERVER_InferenceRequestDelete(request),
+        "deleting ensemble inference request");
+    auto request_tracker = reinterpret_cast<RequestTracker*>(userp);
+    if (request_tracker->DecrementCounter()) {
+      delete request_tracker;
     }
-  };
-
-  // Attempt to enqueue the callback. If all workers are busy and queue is at
-  // capacity, execute the callback immediately.
-  if (pool->TaskQueueSize() < pool->Size()) {
-    pool->Enqueue(fn);
-  } else {
-    fn();
   }
 }
 
 void
 EnsembleContext::ResponseComplete(
     TRITONSERVER_InferenceResponse* response, const uint32_t flags, void* userp)
 {
-  auto step_raw_ptr = reinterpret_cast<Step*>(userp);
-  auto pool = step_raw_ptr->ctx_->CallbackPool();
-  auto fn = [response, flags, step_raw_ptr]() {
-    auto step_ptr = std::unique_ptr<Step>(step_raw_ptr);
-    step_ptr->response_flags_ = flags;
-    step_ptr->response_ = response;
-
-    EnsembleContext::Proceed(step_ptr->ctx_, step_ptr);
-    // Expecting more responses
-    if ((flags & TRITONSERVER_RESPONSE_COMPLETE_FINAL) == 0) {
-      step_ptr.release();
-    }
-  };
-
-  // Attempt to enqueue the callback. If all workers are busy and queue is at
-  // capacity, execute the callback immediately.
-  if (pool->TaskQueueSize() < pool->Size()) {
-    pool->Enqueue(fn);
-  } else {
-    fn();
+  auto step_ptr = std::unique_ptr<Step>(reinterpret_cast<Step*>(userp));
+  step_ptr->response_flags_ = flags;
+  step_ptr->response_ = response;
+
+  EnsembleContext::Proceed(step_ptr->ctx_, step_ptr);
+  // Expecting more responses
+  if ((flags & TRITONSERVER_RESPONSE_COMPLETE_FINAL) == 0) {
+    step_ptr.release();
   }
 }
 
@@ -1484,7 +1448,7 @@ EnsembleScheduler::Enqueue(std::unique_ptr<InferenceRequest>& request)
   RETURN_IF_ERROR(request->SetState(InferenceRequest::State::EXECUTING));
   std::shared_ptr<EnsembleContext> context(new EnsembleContext(
       metric_reporter_.get(), stats_aggregator_, is_, info_.get(), request,
-      stream_, callback_pool_));
+      stream_));
   EnsembleContext::Proceed(context);
   return Status::Success;
 }
@@ -1573,7 +1537,6 @@ EnsembleScheduler::EnsembleScheduler(
       info_->tensor_to_prev_step_.emplace(pair.second, step_idx);
     }
   }
-  callback_pool_ = is_->EnsembleCallbackPool();
 }
 
 EnsembleScheduler::~EnsembleScheduler()
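
The removed path in RequestComplete and ResponseComplete above follows a bounded "enqueue or run inline" pattern: dispatch the completion callback to a fixed-size pool when it has headroom, otherwise run it synchronously on the completing thread. Below is a minimal, self-contained C++ sketch of that pattern; FixedPool is a hypothetical stand-in for triton::common::ThreadPool that keeps only the three calls the reverted code relies on (Size(), TaskQueueSize(), Enqueue()) and is not Triton's actual implementation.

#include <condition_variable>
#include <cstddef>
#include <functional>
#include <iostream>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

// Hypothetical fixed-size worker pool; stands in for triton::common::ThreadPool.
class FixedPool {
 public:
  explicit FixedPool(std::size_t n)
  {
    for (std::size_t i = 0; i < n; ++i) {
      workers_.emplace_back([this] {
        for (;;) {
          std::function<void()> task;
          {
            std::unique_lock<std::mutex> lk(mtx_);
            cv_.wait(lk, [this] { return stop_ || !tasks_.empty(); });
            if (stop_ && tasks_.empty()) {
              return;  // queue drained and pool shutting down
            }
            task = std::move(tasks_.front());
            tasks_.pop();
          }
          task();
        }
      });
    }
  }

  ~FixedPool()
  {
    {
      std::lock_guard<std::mutex> lk(mtx_);
      stop_ = true;
    }
    cv_.notify_all();
    for (auto& w : workers_) {
      w.join();
    }
  }

  std::size_t Size() const { return workers_.size(); }

  std::size_t TaskQueueSize()
  {
    std::lock_guard<std::mutex> lk(mtx_);
    return tasks_.size();
  }

  void Enqueue(std::function<void()> fn)
  {
    {
      std::lock_guard<std::mutex> lk(mtx_);
      tasks_.push(std::move(fn));
    }
    cv_.notify_one();
  }

 private:
  std::vector<std::thread> workers_;
  std::queue<std::function<void()>> tasks_;
  std::mutex mtx_;
  std::condition_variable cv_;
  bool stop_ = false;
};

int main()
{
  FixedPool pool(2);
  auto fn = [] { std::cout << "callback ran\n"; };

  // The reverted pattern: run asynchronously while the pool has headroom,
  // otherwise fall back to executing the callback on the caller's thread.
  if (pool.TaskQueueSize() < pool.Size()) {
    pool.Enqueue(fn);
  } else {
    fn();
  }
  return 0;  // ~FixedPool() drains the queue and joins the workers
}

The TaskQueueSize() < Size() check caps queued work at roughly one pending task per worker; anything beyond that executes synchronously on the completing thread, which is exactly the behavior this revert makes unconditional.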

src/ensemble_scheduler/ensemble_scheduler.h

Lines changed: 1 addition & 8 deletions

@@ -1,4 +1,4 @@
-// Copyright 2019-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright 2019-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Redistribution and use in source and binary forms, with or without
 // modification, are permitted provided that the following conditions
@@ -36,7 +36,6 @@
 #include "scheduler.h"
 #include "scheduler_utils.h"
 #include "status.h"
-#include "triton/common/thread_pool.h"
 
 #ifdef TRITON_ENABLE_GPU
 #include <cuda_runtime_api.h>
@@ -108,8 +107,6 @@ class EnsembleScheduler : public Scheduler {
   // \see Scheduler::Stop()
   void Stop() override {}
 
-  triton::common::ThreadPool* CallbackPool() const { return callback_pool_; }
-
  private:
   EnsembleScheduler(
       InferenceStatsAggregator* const stats_aggregator,
@@ -131,10 +128,6 @@ class EnsembleScheduler : public Scheduler {
   cudaStream_t stream_;
 
   std::atomic<size_t> inflight_count_;
-
-  // Fixed-size thread pool to run callbacks at end of each ensemble step.
-  // Managed by the server.
-  triton::common::ThreadPool* callback_pool_;
 };
 
 }}  // namespace triton::core

src/server.cc

Lines changed: 1 addition & 8 deletions

@@ -1,4 +1,4 @@
-// Copyright 2018-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright 2018-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Redistribution and use in source and binary forms, with or without
 // modification, are permitted provided that the following conditions
@@ -117,13 +117,6 @@ InferenceServer::InferenceServer()
 #endif  // TRITON_ENABLE_GPU
 
   inflight_request_counter_ = 0;
-
-#ifdef TRITON_ENABLE_ENSEMBLE
-  // TODO: Need to scale the thread pool size smarter, e.g. based on the
-  // instance_group count of composing models.
-  ensemble_cb_pool_.reset(
-      new triton::common::ThreadPool(ENSEMBLE_CB_POOL_SIZE));
-#endif
 }
 
 Status
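
The removed TODO above raises the sizing question: the reverted code hard-coded ENSEMBLE_CB_POOL_SIZE at 8 and noted that the pool should instead scale with the instance_group count of composing models. As a hypothetical illustration only (not Triton code), one alternative is to derive a default from available hardware parallelism:

#include <algorithm>
#include <cstddef>
#include <thread>

// Hypothetical sizing helper; the reverted code used a fixed size of 8.
std::size_t DefaultCallbackPoolSize()
{
  // hardware_concurrency() may return 0 when the value is not computable,
  // so clamp to at least one worker.
  return std::max(1u, std::thread::hardware_concurrency());
}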

src/server.h

Lines changed: 1 addition & 14 deletions

@@ -1,4 +1,4 @@
-// Copyright 2018-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright 2018-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Redistribution and use in source and binary forms, with or without
 // modification, are permitted provided that the following conditions
@@ -332,13 +332,6 @@ class InferenceServer {
     return cache_manager_;
   }
 
-#ifdef TRITON_ENABLE_ENSEMBLE
-  triton::common::ThreadPool* EnsembleCallbackPool() const
-  {
-    return ensemble_cb_pool_.get();
-  }
-#endif  // TRITON_ENABLE_ENSEMBLE
-
  private:
  const std::string version_;
  std::string id_;
@@ -382,12 +375,6 @@ class InferenceServer {
  std::unique_ptr<ModelRepositoryManager> model_repository_manager_;
  std::shared_ptr<TritonBackendManager> backend_manager_;
  std::shared_ptr<TritonCacheManager> cache_manager_;
-
-#ifdef TRITON_ENABLE_ENSEMBLE
-  // The thread pool for all ensemble models to execute callbacks
-  // asynchronously.
-  std::unique_ptr<triton::common::ThreadPool> ensemble_cb_pool_;
-#endif  // TRITON_ENABLE_ENSEMBLE
 };
 
 }}  // namespace triton::core
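
The two hunks above undo the ownership wiring for the pool: InferenceServer owned it through a unique_ptr and handed out a raw pointer via EnsembleCallbackPool(), which each EnsembleScheduler cached. A compact sketch of that shape, with hypothetical stand-in types rather than Triton's real classes:

#include <memory>

struct ThreadPoolStub {};  // stands in for triton::common::ThreadPool

class ServerStub {
 public:
  ServerStub() : ensemble_cb_pool_(new ThreadPoolStub) {}

  // Hands out a non-owning pointer, mirroring EnsembleCallbackPool() above.
  ThreadPoolStub* EnsembleCallbackPool() const { return ensemble_cb_pool_.get(); }

 private:
  std::unique_ptr<ThreadPoolStub> ensemble_cb_pool_;  // owning
};

class SchedulerStub {
 public:
  explicit SchedulerStub(ServerStub* server)
      : callback_pool_(server->EnsembleCallbackPool())  // borrowed, non-owning
  {
  }

 private:
  ThreadPoolStub* const callback_pool_;
};

The scheduler's raw callback_pool_ pointer is safe only because the server and its pool outlive the schedulers, which the "managed by the server" comments in the reverted code imply; the revert deletes this wiring along with the pool itself.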
