-// Copyright 2019-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright 2019-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Redistribution and use in source and binary forms, with or without
 // modification, are permitted provided that the following conditions
@@ -45,17 +45,42 @@ class EnsembleContext;
 
 using IterationCount = size_t;
 
+// Check if the model is configured to preserve the order of responses.
+// This is critical for async execution of ResponseComplete callbacks.
+inline bool
+preserve_responses_order(const inference::ModelConfig& config)
+{
+  uint64_t total_instance_groups = 0;
+  for (const auto& group : config.instance_group()) {
+    total_instance_groups += group.count();
+  }
+
+  // Case 1: Sequence batching is enabled
+  // Case 2: Dynamic batching is disabled and there is only one instance group
+  // Case 3: Dynamic batching is enabled and preserve_ordering is true
+  // Case 4: Model transaction policy is decoupled (if the final response
+  // callback is not executed in the last step, the RequestTracker object will
+  // be freed prematurely and lead to a segmentation fault)
+  return config.has_sequence_batching() ||
+         (!config.has_dynamic_batching() && total_instance_groups <= 1) ||
+         (config.has_dynamic_batching() &&
+          config.dynamic_batching().preserve_ordering()) ||
+         config.model_transaction_policy().decoupled();
+}
+
 // Request tracker is passed as 'userp' in RequestRelease function and used
 // to manage the lifecycle of the ensemble request
 class RequestTracker {
  public:
   explicit RequestTracker(
       std::unique_ptr<InferenceRequest>&& request, uint64_t compute_start_ns,
       MetricModelReporter* metric_reporter,
-      InferenceStatsAggregator* stats_aggregator)
+      InferenceStatsAggregator* stats_aggregator,
+      triton::common::ThreadPool* callback_pool)
       : inflight_request_counter_(1), request_(std::move(request)),
         compute_start_ns_(compute_start_ns), metric_reporter_(metric_reporter),
-        stats_aggregator_(stats_aggregator), status_(Status::Success)
+        stats_aggregator_(stats_aggregator), status_(Status::Success),
+        callback_pool_(callback_pool)
   {
   }
 
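The ordering check above keys entirely off the model configuration protobuf. As a rough illustration (not part of the change), the sketch below builds two hypothetical inference::ModelConfig objects and shows which one would allow the asynchronous callback path; it assumes the preserve_responses_order helper from the hunk above is in scope, along with the generated model_config.pb.h header.

// Illustrative sketch only: exercises preserve_responses_order() as defined
// above. Assumes "model_config.pb.h" (generated from Triton's
// model_config.proto) and the helper itself are visible in this translation
// unit.
#include <iostream>

#include "model_config.pb.h"

int
main()
{
  inference::ModelConfig config;

  // One instance group with a single instance and no dynamic batching:
  // responses are inherently ordered, so the async callback path is skipped.
  config.add_instance_group()->set_count(1);
  std::cout << preserve_responses_order(config) << std::endl;  // prints 1

  // Enabling dynamic batching without preserve_ordering makes out-of-order
  // completion acceptable, so callbacks may be dispatched to the thread pool.
  config.mutable_dynamic_batching();
  std::cout << preserve_responses_order(config) << std::endl;  // prints 0

  return 0;
}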
@@ -70,6 +95,8 @@ class RequestTracker {
     return context_stats_aggregator_;
   }
 
+  triton::common::ThreadPool* CallbackPool() const { return callback_pool_; }
+
   void IncrementCounter()
   {
     std::lock_guard<std::mutex> lk(mtx_);
@@ -120,6 +147,7 @@ class RequestTracker {
   InferenceStatsAggregator* stats_aggregator_;
   InferenceStatsAggregator context_stats_aggregator_;
   Status status_;
+  triton::common::ThreadPool* const callback_pool_;
 };
 
 // Step is used as 'userp' and keeps ensemble context alive
@@ -129,9 +157,9 @@ class RequestTracker {
 struct Step {
   Step(
       size_t step_idx, const InferenceRequest::SequenceId& correlation_id,
-      uint32_t flags)
+      uint32_t flags, bool preserve_responses_order)
       : correlation_id_(correlation_id), flags_(flags), response_flags_(0),
-        step_idx_(step_idx)
+        preserve_responses_order_(preserve_responses_order), step_idx_(step_idx)
   {
   }
 
@@ -154,7 +182,7 @@ struct Step {
   // returning from the callback.
   uint32_t response_flags_;
   TRITONSERVER_InferenceResponse* response_;
-
+  const bool preserve_responses_order_;
 
   size_t step_idx_;
 };
@@ -237,7 +265,7 @@ class EnsembleContext {
       MetricModelReporter* metric_reporter,
       InferenceStatsAggregator* stats_aggregator, InferenceServer* is,
       EnsembleInfo* info, std::unique_ptr<InferenceRequest>& request,
-      cudaStream_t stream);
+      cudaStream_t stream, triton::common::ThreadPool* callback_pool);
 
   // Perform transition on 'context' state given the information of
   // 'completed_step'
@@ -326,6 +354,8 @@ class EnsembleContext {
   void CacheEnsembleTopLevelRequest(
       std::unique_ptr<InferenceResponse>& response);
 
+  triton::common::ThreadPool* CallbackPool() const { return callback_pool_; }
+
   InferenceServer* is_;
 
   EnsembleInfo* info_;
@@ -375,20 +405,26 @@ class EnsembleContext {
       TRITONSERVER_ResponseAllocator,
       decltype(&TRITONSERVER_ResponseAllocatorDelete)>
       allocator_;
+
+  // The thread pool used to execute ensemble callbacks and reduce e2e latency.
+  // The thread pool is managed by InferenceServer.
+  triton::common::ThreadPool* const callback_pool_;
 };
 
 EnsembleContext::EnsembleContext(
     MetricModelReporter* metric_reporter,
     InferenceStatsAggregator* stats_aggregator, InferenceServer* is,
     EnsembleInfo* info, std::unique_ptr<InferenceRequest>& request,
-    cudaStream_t stream)
+    cudaStream_t stream, triton::common::ThreadPool* callback_pool)
     : is_(is), info_(info), stream_(stream), inflight_step_counter_(0),
-      allocator_(nullptr, TRITONSERVER_ResponseAllocatorDelete)
+      allocator_(nullptr, TRITONSERVER_ResponseAllocatorDelete),
+      callback_pool_(callback_pool)
 {
   uint64_t compute_start_ns = 0;
   INFER_STATS_SET_TIMESTAMP(compute_start_ns);
   request_tracker_ = new RequestTracker(
-      std::move(request), compute_start_ns, metric_reporter, stats_aggregator);
+      std::move(request), compute_start_ns, metric_reporter, stats_aggregator,
+      callback_pool);
 
   auto& lrequest = request_tracker_->Request();
 
@@ -603,29 +639,57 @@ void
 EnsembleContext::RequestComplete(
     TRITONSERVER_InferenceRequest* request, const uint32_t flags, void* userp)
 {
-  if ((flags & TRITONSERVER_REQUEST_RELEASE_ALL) != 0) {
-    LOG_TRITONSERVER_ERROR(
-        TRITONSERVER_InferenceRequestDelete(request),
-        "deleting ensemble inference request");
-    auto request_tracker = reinterpret_cast<RequestTracker*>(userp);
-    if (request_tracker->DecrementCounter()) {
-      delete request_tracker;
+  auto request_tracker = reinterpret_cast<RequestTracker*>(userp);
+  auto pool = request_tracker->CallbackPool();
+  auto fn = [request, flags, request_tracker]() {
+    if ((flags & TRITONSERVER_REQUEST_RELEASE_ALL) != 0) {
+      LOG_TRITONSERVER_ERROR(
+          TRITONSERVER_InferenceRequestDelete(request),
+          "deleting ensemble inference request");
+      if (request_tracker->DecrementCounter()) {
+        delete request_tracker;
+      }
     }
+  };
+
+  // Attempt to enqueue the callback. If all workers are busy and the queue
+  // is at capacity, execute the callback immediately in the current thread.
+  if (pool->TaskQueueSize() < pool->Size()) {
+    pool->Enqueue(fn);
+  } else {
+    fn();
   }
 }
 
 void
 EnsembleContext::ResponseComplete(
     TRITONSERVER_InferenceResponse* response, const uint32_t flags, void* userp)
 {
-  auto step_ptr = std::unique_ptr<Step>(reinterpret_cast<Step*>(userp));
-  step_ptr->response_flags_ = flags;
-  step_ptr->response_ = response;
-
-  EnsembleContext::Proceed(step_ptr->ctx_, step_ptr);
-  // Expecting more responses
-  if ((flags & TRITONSERVER_RESPONSE_COMPLETE_FINAL) == 0) {
-    step_ptr.release();
+  auto step_raw_ptr = reinterpret_cast<Step*>(userp);
+  auto pool = step_raw_ptr->ctx_->CallbackPool();
+  auto fn = [response, flags, step_raw_ptr]() {
+    auto step_ptr = std::unique_ptr<Step>(step_raw_ptr);
+    step_ptr->response_flags_ = flags;
+    step_ptr->response_ = response;
+
+    EnsembleContext::Proceed(step_ptr->ctx_, step_ptr);
+    // Expecting more responses
+    if ((flags & TRITONSERVER_RESPONSE_COMPLETE_FINAL) == 0) {
+      step_ptr.release();
+    }
+  };
+
+  // Attempt to enqueue the callback. If all workers are busy and the queue
+  // is at capacity, execute the callback immediately in the current thread.
+  // Note: The async callback optimization does not guarantee the order of
+  // responses and exploits cases where responses can be out-of-order. For
+  // models required to preserve the order of responses, the response
+  // callbacks must be executed synchronously in the same thread.
+  if (!step_raw_ptr->preserve_responses_order_ &&
+      pool->TaskQueueSize() < pool->Size()) {
+    pool->Enqueue(fn);
+  } else {
+    fn();
   }
 }
 
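Both callbacks follow the same dispatch rule: hand the work to the shared pool only while the pending-task backlog is smaller than the worker count, otherwise run it inline. A minimal standalone sketch of that rule follows; it reuses the triton::common::ThreadPool calls already referenced in the hunk (Size, TaskQueueSize, Enqueue), while the header path, constructor argument, and the EnqueueOrRunInline helper name are assumptions made for illustration.

// Minimal sketch of the enqueue-or-run-inline dispatch used by the callbacks
// above. ThreadPool method names match the calls in the diff; everything else
// here (helper name, pool size, header path) is an illustrative assumption.
#include <functional>
#include <iostream>

#include "triton/common/thread_pool.h"

void
EnqueueOrRunInline(triton::common::ThreadPool* pool, std::function<void()> fn)
{
  // Offload only while there is headroom in the pool; otherwise execute the
  // callback synchronously so the caller never blocks behind a full queue.
  if (pool->TaskQueueSize() < pool->Size()) {
    pool->Enqueue(std::move(fn));
  } else {
    fn();
  }
}

int
main()
{
  triton::common::ThreadPool pool(4 /* worker threads, assumed ctor arg */);
  EnqueueOrRunInline(&pool, []() { std::cout << "callback ran" << std::endl; });
  return 0;  // Assumption: the pool drains queued work on destruction.
}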
@@ -971,8 +1035,8 @@ EnsembleContext::InitStep(
   for (const auto& pair : istep.output_to_tensor_) {
     irequest->AddOriginalRequestedOutput(pair.first);
   }
-
-  step->reset(new Step(step_idx, correlation_id, flags));
+  const bool preserve_order = preserve_responses_order(model->Config());
+  step->reset(new Step(step_idx, correlation_id, flags, preserve_order));
 
   irequest->SetId(request_id_);
   irequest->SetCorrelationId(correlation_id);
@@ -1448,7 +1512,7 @@ EnsembleScheduler::Enqueue(std::unique_ptr<InferenceRequest>& request)
   RETURN_IF_ERROR(request->SetState(InferenceRequest::State::EXECUTING));
   std::shared_ptr<EnsembleContext> context(new EnsembleContext(
       metric_reporter_.get(), stats_aggregator_, is_, info_.get(), request,
-      stream_));
+      stream_, callback_pool_));
   EnsembleContext::Proceed(context);
   return Status::Success;
 }
@@ -1537,6 +1601,7 @@ EnsembleScheduler::EnsembleScheduler(
       info_->tensor_to_prev_step_.emplace(pair.second, step_idx);
     }
   }
+  callback_pool_ = is_->EnsembleCallbackPool();
 }
 
 EnsembleScheduler::~EnsembleScheduler()
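The scheduler does not own the pool; it borrows a raw pointer from the server so that all ensemble schedulers share one set of callback workers. A rough sketch of what a server-owned accessor could look like is below; the class name, member name, pool size, and construction policy are assumptions for illustration, not the actual InferenceServer::EnsembleCallbackPool() implementation.

// Hypothetical sketch of a server-owned callback pool accessor. The real
// InferenceServer may differ; only the accessor name comes from the diff.
#include <memory>

#include "triton/common/thread_pool.h"

class ServerSketch {
 public:
  // Returns a pool shared by every ensemble scheduler created by this server.
  triton::common::ThreadPool* EnsembleCallbackPool()
  {
    return ensemble_cb_pool_.get();
  }

 private:
  // Assumed fixed size; the pool outlives the schedulers that borrow it.
  std::unique_ptr<triton::common::ThreadPool> ensemble_cb_pool_ =
      std::make_unique<triton::common::ThreadPool>(8);
};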