-// Copyright 2019-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright 2019-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Redistribution and use in source and binary forms, with or without
 // modification, are permitted provided that the following conditions
@@ -52,10 +52,12 @@ class RequestTracker {
   explicit RequestTracker(
       std::unique_ptr<InferenceRequest>&& request, uint64_t compute_start_ns,
       MetricModelReporter* metric_reporter,
-      InferenceStatsAggregator* stats_aggregator)
+      InferenceStatsAggregator* stats_aggregator,
+      triton::common::ThreadPool* callback_pool)
       : inflight_request_counter_(1), request_(std::move(request)),
         compute_start_ns_(compute_start_ns), metric_reporter_(metric_reporter),
-        stats_aggregator_(stats_aggregator), status_(Status::Success)
+        stats_aggregator_(stats_aggregator), status_(Status::Success),
+        callback_pool_(callback_pool)
   {
   }

@@ -70,6 +72,8 @@ class RequestTracker {
     return context_stats_aggregator_;
   }

+  triton::common::ThreadPool* CallbackPool() const { return callback_pool_; }
+
   void IncrementCounter()
   {
     std::lock_guard<std::mutex> lk(mtx_);
@@ -120,6 +124,7 @@ class RequestTracker {
   InferenceStatsAggregator* stats_aggregator_;
   InferenceStatsAggregator context_stats_aggregator_;
   Status status_;
+  triton::common::ThreadPool* const callback_pool_;
 };

 // Step is used as 'userp' and keeps ensemble context alive
@@ -237,7 +242,7 @@ class EnsembleContext {
       MetricModelReporter* metric_reporter,
       InferenceStatsAggregator* stats_aggregator, InferenceServer* is,
       EnsembleInfo* info, std::unique_ptr<InferenceRequest>& request,
-      cudaStream_t stream);
+      cudaStream_t stream, triton::common::ThreadPool* callback_pool);

   // Perform transition on 'context' state given the information of
   // 'completed_step'
@@ -326,6 +331,8 @@ class EnsembleContext {
   void CacheEnsembleTopLevelRequest(
       std::unique_ptr<InferenceResponse>& response);

+  triton::common::ThreadPool* CallbackPool() const { return callback_pool_; }
+
   InferenceServer* is_;

   EnsembleInfo* info_;
@@ -375,20 +382,26 @@ class EnsembleContext {
       TRITONSERVER_ResponseAllocator,
       decltype(&TRITONSERVER_ResponseAllocatorDelete)>
       allocator_;
+
+  // The thread pool used to execute ensemble callbacks and reduce e2e latency.
+  // The thread pool is managed by InferenceServer.
+  triton::common::ThreadPool* const callback_pool_;
 };

 EnsembleContext::EnsembleContext(
     MetricModelReporter* metric_reporter,
     InferenceStatsAggregator* stats_aggregator, InferenceServer* is,
     EnsembleInfo* info, std::unique_ptr<InferenceRequest>& request,
-    cudaStream_t stream)
+    cudaStream_t stream, triton::common::ThreadPool* callback_pool)
     : is_(is), info_(info), stream_(stream), inflight_step_counter_(0),
-      allocator_(nullptr, TRITONSERVER_ResponseAllocatorDelete)
+      allocator_(nullptr, TRITONSERVER_ResponseAllocatorDelete),
+      callback_pool_(callback_pool)
 {
   uint64_t compute_start_ns = 0;
   INFER_STATS_SET_TIMESTAMP(compute_start_ns);
   request_tracker_ = new RequestTracker(
-      std::move(request), compute_start_ns, metric_reporter, stats_aggregator);
+      std::move(request), compute_start_ns, metric_reporter, stats_aggregator,
+      callback_pool);

   auto& lrequest = request_tracker_->Request();

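Note on the new dependency: the callback hunks below exercise only three members of triton::common::ThreadPool — Size(), TaskQueueSize(), and Enqueue(). The following is a minimal interface sketch inferred from those call sites; the real class lives in the triton-inference-server/common repository and its exact declaration is an assumption here.

#include <cstddef>
#include <functional>

namespace triton { namespace common {

// Interface sketch (assumption): only the operations this change relies on.
class ThreadPool {
 public:
  // Number of worker threads owned by the pool.
  std::size_t Size() const;
  // Number of tasks currently waiting for a free worker.
  std::size_t TaskQueueSize();
  // Schedule 'task' to run asynchronously on a worker thread.
  void Enqueue(std::function<void()> task);
};

}}  // namespace triton::common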
@@ -603,29 +616,52 @@ void
 EnsembleContext::RequestComplete(
     TRITONSERVER_InferenceRequest* request, const uint32_t flags, void* userp)
 {
-  if ((flags & TRITONSERVER_REQUEST_RELEASE_ALL) != 0) {
-    LOG_TRITONSERVER_ERROR(
-        TRITONSERVER_InferenceRequestDelete(request),
-        "deleting ensemble inference request");
-    auto request_tracker = reinterpret_cast<RequestTracker*>(userp);
-    if (request_tracker->DecrementCounter()) {
-      delete request_tracker;
+  auto request_tracker = reinterpret_cast<RequestTracker*>(userp);
+  auto pool = request_tracker->CallbackPool();
+  auto fn = [request, flags, request_tracker]() {
+    if ((flags & TRITONSERVER_REQUEST_RELEASE_ALL) != 0) {
+      LOG_TRITONSERVER_ERROR(
+          TRITONSERVER_InferenceRequestDelete(request),
+          "deleting ensemble inference request");
+      if (request_tracker->DecrementCounter()) {
+        delete request_tracker;
+      }
     }
+  };
+
+  // Attempt to enqueue the callback. If all workers are busy and queue is at
+  // capacity, execute the callback immediately.
+  if (pool->TaskQueueSize() < pool->Size()) {
+    pool->Enqueue(fn);
+  } else {
+    fn();
   }
 }

 void
 EnsembleContext::ResponseComplete(
     TRITONSERVER_InferenceResponse* response, const uint32_t flags, void* userp)
 {
-  auto step_ptr = std::unique_ptr<Step>(reinterpret_cast<Step*>(userp));
-  step_ptr->response_flags_ = flags;
-  step_ptr->response_ = response;
-
-  EnsembleContext::Proceed(step_ptr->ctx_, step_ptr);
-  // Expecting more responses
-  if ((flags & TRITONSERVER_RESPONSE_COMPLETE_FINAL) == 0) {
-    step_ptr.release();
+  auto step_raw_ptr = reinterpret_cast<Step*>(userp);
+  auto pool = step_raw_ptr->ctx_->CallbackPool();
+  auto fn = [response, flags, step_raw_ptr]() {
+    auto step_ptr = std::unique_ptr<Step>(step_raw_ptr);
+    step_ptr->response_flags_ = flags;
+    step_ptr->response_ = response;
+
+    EnsembleContext::Proceed(step_ptr->ctx_, step_ptr);
+    // Expecting more responses
+    if ((flags & TRITONSERVER_RESPONSE_COMPLETE_FINAL) == 0) {
+      step_ptr.release();
+    }
+  };
+
+  // Attempt to enqueue the callback. If all workers are busy and queue is at
+  // capacity, execute the callback immediately.
+  if (pool->TaskQueueSize() < pool->Size()) {
+    pool->Enqueue(fn);
+  } else {
+    fn();
   }
 }

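Both callbacks above apply the same dispatch rule: hand the work to the shared pool only when a worker is likely idle (TaskQueueSize() < Size()), otherwise run it inline on the calling thread so the task queue cannot grow without bound. Below is a self-contained sketch of that pattern; the toy pool and all names in it are illustrative assumptions, not the Triton implementation.

#include <condition_variable>
#include <cstddef>
#include <deque>
#include <functional>
#include <iostream>
#include <mutex>
#include <thread>
#include <vector>

// Toy pool exposing the same three calls the callbacks rely on.
class ToyPool {
 public:
  explicit ToyPool(std::size_t thread_count)
  {
    for (std::size_t i = 0; i < thread_count; ++i) {
      workers_.emplace_back([this] { Worker(); });
    }
  }

  ~ToyPool()
  {
    {
      std::lock_guard<std::mutex> lk(mtx_);
      stop_ = true;
    }
    cv_.notify_all();
    for (auto& w : workers_) {
      w.join();
    }
  }

  std::size_t Size() const { return workers_.size(); }

  std::size_t TaskQueueSize()
  {
    std::lock_guard<std::mutex> lk(mtx_);
    return tasks_.size();
  }

  void Enqueue(std::function<void()> task)
  {
    {
      std::lock_guard<std::mutex> lk(mtx_);
      tasks_.push_back(std::move(task));
    }
    cv_.notify_one();
  }

 private:
  void Worker()
  {
    for (;;) {
      std::function<void()> task;
      {
        std::unique_lock<std::mutex> lk(mtx_);
        cv_.wait(lk, [this] { return stop_ || !tasks_.empty(); });
        if (stop_ && tasks_.empty()) {
          return;
        }
        task = std::move(tasks_.front());
        tasks_.pop_front();
      }
      task();
    }
  }

  std::vector<std::thread> workers_;
  std::deque<std::function<void()>> tasks_;
  std::mutex mtx_;
  std::condition_variable cv_;
  bool stop_ = false;
};

// The same decision RequestComplete/ResponseComplete make: enqueue when a
// worker is likely idle, otherwise run on the calling thread so the queue
// stays bounded. The size check is a heuristic, not an atomic guarantee.
void DispatchOrRunInline(ToyPool* pool, std::function<void()> fn)
{
  if (pool->TaskQueueSize() < pool->Size()) {
    pool->Enqueue(std::move(fn));
  } else {
    fn();
  }
}

int main()
{
  ToyPool pool(2);
  for (int i = 0; i < 8; ++i) {
    DispatchOrRunInline(&pool, [i] { std::cout << "callback " << i << "\n"; });
  }
  return 0;
}

Running the callback inline under saturation trades latency on the caller's thread for a bounded queue; when the pool has spare capacity, offloading lets the completion thread return sooner, which is the end-to-end latency win this change targets.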
@@ -1448,7 +1484,7 @@ EnsembleScheduler::Enqueue(std::unique_ptr<InferenceRequest>& request)
   RETURN_IF_ERROR(request->SetState(InferenceRequest::State::EXECUTING));
   std::shared_ptr<EnsembleContext> context(new EnsembleContext(
       metric_reporter_.get(), stats_aggregator_, is_, info_.get(), request,
-      stream_));
+      stream_, callback_pool_));
   EnsembleContext::Proceed(context);
   return Status::Success;
 }
@@ -1537,6 +1573,7 @@ EnsembleScheduler::EnsembleScheduler(
       info_->tensor_to_prev_step_.emplace(pair.second, step_idx);
     }
   }
+  callback_pool_ = is_->EnsembleCallbackPool();
 }

 EnsembleScheduler::~EnsembleScheduler()
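Each EnsembleScheduler instance borrows the pool from the server via is_->EnsembleCallbackPool() rather than creating its own, so concurrently running ensembles share one bounded set of callback threads. The sketch below shows one way such ownership could look; the accessor name comes from this diff, while the class name, header, member, and constructor shown here are assumptions.

#include <memory>

#include "thread_pool_sketch.h"  // assumed header for the ThreadPool interface sketched earlier

// Hypothetical ownership model: the server constructs one pool at startup and
// hands out a raw, non-owning pointer; schedulers never own, resize, or
// destroy it.
class InferenceServerSketch {
 public:
  explicit InferenceServerSketch(
      std::unique_ptr<triton::common::ThreadPool> pool)
      : ensemble_cb_pool_(std::move(pool))
  {
  }

  triton::common::ThreadPool* EnsembleCallbackPool() const
  {
    return ensemble_cb_pool_.get();
  }

 private:
  std::unique_ptr<triton::common::ThreadPool> ensemble_cb_pool_;
};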