-// Copyright 2019-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright 2019-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Redistribution and use in source and binary forms, with or without
 // modification, are permitted provided that the following conditions
@@ -52,12 +52,10 @@ class RequestTracker {
   explicit RequestTracker(
       std::unique_ptr<InferenceRequest>&& request, uint64_t compute_start_ns,
       MetricModelReporter* metric_reporter,
-      InferenceStatsAggregator* stats_aggregator,
-      triton::common::ThreadPool* callback_pool)
+      InferenceStatsAggregator* stats_aggregator)
       : inflight_request_counter_(1), request_(std::move(request)),
         compute_start_ns_(compute_start_ns), metric_reporter_(metric_reporter),
-        stats_aggregator_(stats_aggregator), status_(Status::Success),
-        callback_pool_(callback_pool)
+        stats_aggregator_(stats_aggregator), status_(Status::Success)
   {
   }

@@ -72,8 +70,6 @@ class RequestTracker {
     return context_stats_aggregator_;
   }

-  triton::common::ThreadPool* CallbackPool() const { return callback_pool_; }
-
   void IncrementCounter()
   {
     std::lock_guard<std::mutex> lk(mtx_);
@@ -124,7 +120,6 @@ class RequestTracker {
   InferenceStatsAggregator* stats_aggregator_;
   InferenceStatsAggregator context_stats_aggregator_;
   Status status_;
-  triton::common::ThreadPool* const callback_pool_;
 };

 // Step is used as 'userp' and keeps ensemble context alive
@@ -242,7 +237,7 @@ class EnsembleContext {
       MetricModelReporter* metric_reporter,
       InferenceStatsAggregator* stats_aggregator, InferenceServer* is,
       EnsembleInfo* info, std::unique_ptr<InferenceRequest>& request,
-      cudaStream_t stream, triton::common::ThreadPool* callback_pool);
+      cudaStream_t stream);

   // Perform transition on 'context' state given the information of
   // 'completed_step'
@@ -331,8 +326,6 @@ class EnsembleContext {
   void CacheEnsembleTopLevelRequest(
       std::unique_ptr<InferenceResponse>& response);

-  triton::common::ThreadPool* CallbackPool() const { return callback_pool_; }
-
   InferenceServer* is_;

   EnsembleInfo* info_;
@@ -382,26 +375,20 @@ class EnsembleContext {
       TRITONSERVER_ResponseAllocator,
       decltype(&TRITONSERVER_ResponseAllocatorDelete)>
       allocator_;
-
-  // The thread pool used to execute ensemble callbacks and reduce e2e latency.
-  // The thread pool is managed by InferenceServer.
-  triton::common::ThreadPool* const callback_pool_;
 };

 EnsembleContext::EnsembleContext(
     MetricModelReporter* metric_reporter,
     InferenceStatsAggregator* stats_aggregator, InferenceServer* is,
     EnsembleInfo* info, std::unique_ptr<InferenceRequest>& request,
-    cudaStream_t stream, triton::common::ThreadPool* callback_pool)
+    cudaStream_t stream)
     : is_(is), info_(info), stream_(stream), inflight_step_counter_(0),
-      allocator_(nullptr, TRITONSERVER_ResponseAllocatorDelete),
-      callback_pool_(callback_pool)
+      allocator_(nullptr, TRITONSERVER_ResponseAllocatorDelete)
 {
   uint64_t compute_start_ns = 0;
   INFER_STATS_SET_TIMESTAMP(compute_start_ns);
   request_tracker_ = new RequestTracker(
-      std::move(request), compute_start_ns, metric_reporter, stats_aggregator,
-      callback_pool);
+      std::move(request), compute_start_ns, metric_reporter, stats_aggregator);

   auto& lrequest = request_tracker_->Request();

@@ -616,52 +603,29 @@ void
 EnsembleContext::RequestComplete(
     TRITONSERVER_InferenceRequest* request, const uint32_t flags, void* userp)
 {
-  auto request_tracker = reinterpret_cast<RequestTracker*>(userp);
-  auto pool = request_tracker->CallbackPool();
-  auto fn = [request, flags, request_tracker]() {
-    if ((flags & TRITONSERVER_REQUEST_RELEASE_ALL) != 0) {
-      LOG_TRITONSERVER_ERROR(
-          TRITONSERVER_InferenceRequestDelete(request),
-          "deleting ensemble inference request");
-      if (request_tracker->DecrementCounter()) {
-        delete request_tracker;
-      }
+  if ((flags & TRITONSERVER_REQUEST_RELEASE_ALL) != 0) {
+    LOG_TRITONSERVER_ERROR(
+        TRITONSERVER_InferenceRequestDelete(request),
+        "deleting ensemble inference request");
+    auto request_tracker = reinterpret_cast<RequestTracker*>(userp);
+    if (request_tracker->DecrementCounter()) {
+      delete request_tracker;
     }
-  };
-
-  // Attempt to enqueue the callback. If all workers are busy and queue is at
-  // capacity, execute the callback immediately.
-  if (pool->TaskQueueSize() < pool->Size()) {
-    pool->Enqueue(fn);
-  } else {
-    fn();
   }
 }

 void
 EnsembleContext::ResponseComplete(
     TRITONSERVER_InferenceResponse* response, const uint32_t flags, void* userp)
 {
-  auto step_raw_ptr = reinterpret_cast<Step*>(userp);
-  auto pool = step_raw_ptr->ctx_->CallbackPool();
-  auto fn = [response, flags, step_raw_ptr]() {
-    auto step_ptr = std::unique_ptr<Step>(step_raw_ptr);
-    step_ptr->response_flags_ = flags;
-    step_ptr->response_ = response;
-
-    EnsembleContext::Proceed(step_ptr->ctx_, step_ptr);
-    // Expecting more responses
-    if ((flags & TRITONSERVER_RESPONSE_COMPLETE_FINAL) == 0) {
-      step_ptr.release();
-    }
-  };
-
-  // Attempt to enqueue the callback. If all workers are busy and queue is at
-  // capacity, execute the callback immediately.
-  if (pool->TaskQueueSize() < pool->Size()) {
-    pool->Enqueue(fn);
-  } else {
-    fn();
+  auto step_ptr = std::unique_ptr<Step>(reinterpret_cast<Step*>(userp));
+  step_ptr->response_flags_ = flags;
+  step_ptr->response_ = response;
+
+  EnsembleContext::Proceed(step_ptr->ctx_, step_ptr);
+  // Expecting more responses
+  if ((flags & TRITONSERVER_RESPONSE_COMPLETE_FINAL) == 0) {
+    step_ptr.release();
   }
 }

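For context on the '-' side of the hunk above: the removed code dispatched each completion callback to triton::common::ThreadPool only while the queue held fewer tasks than there are workers, and otherwise ran the callback inline on the calling thread; after this change the callbacks always run inline. Below is a minimal, self-contained C++ sketch of that enqueue-or-run-inline pattern. TinyPool is a hypothetical stand-in written for illustration, not Triton's ThreadPool; only the method names Size/TaskQueueSize/Enqueue mirror the calls shown in the diff.

#include <condition_variable>
#include <cstddef>
#include <functional>
#include <iostream>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

// Hypothetical stand-in for a fixed-size worker pool with a task queue.
class TinyPool {
 public:
  explicit TinyPool(size_t n)
  {
    for (size_t i = 0; i < n; ++i) {
      workers_.emplace_back([this] {
        for (;;) {
          std::function<void()> task;
          {
            std::unique_lock<std::mutex> lk(mtx_);
            cv_.wait(lk, [this] { return stop_ || !tasks_.empty(); });
            if (stop_ && tasks_.empty()) {
              return;
            }
            task = std::move(tasks_.front());
            tasks_.pop();
          }
          task();
        }
      });
    }
  }
  ~TinyPool()
  {
    {
      std::lock_guard<std::mutex> lk(mtx_);
      stop_ = true;
    }
    cv_.notify_all();
    for (auto& w : workers_) {
      w.join();
    }
  }
  size_t Size() const { return workers_.size(); }
  size_t TaskQueueSize()
  {
    std::lock_guard<std::mutex> lk(mtx_);
    return tasks_.size();
  }
  void Enqueue(std::function<void()> fn)
  {
    {
      std::lock_guard<std::mutex> lk(mtx_);
      tasks_.push(std::move(fn));
    }
    cv_.notify_one();
  }

 private:
  std::vector<std::thread> workers_;
  std::queue<std::function<void()>> tasks_;
  std::mutex mtx_;
  std::condition_variable cv_;
  bool stop_ = false;
};

int
main()
{
  TinyPool pool(2);
  auto fn = [] { std::cout << "callback executed\n"; };
  // The pattern removed above: offload the callback while the pool has
  // headroom, otherwise execute it immediately on the calling thread.
  if (pool.TaskQueueSize() < pool.Size()) {
    pool.Enqueue(fn);
  } else {
    fn();
  }
  return 0;
}
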
@@ -1484,7 +1448,7 @@ EnsembleScheduler::Enqueue(std::unique_ptr<InferenceRequest>& request)
   RETURN_IF_ERROR(request->SetState(InferenceRequest::State::EXECUTING));
   std::shared_ptr<EnsembleContext> context(new EnsembleContext(
       metric_reporter_.get(), stats_aggregator_, is_, info_.get(), request,
-      stream_, callback_pool_));
+      stream_));
   EnsembleContext::Proceed(context);
   return Status::Success;
 }
@@ -1573,7 +1537,6 @@ EnsembleScheduler::EnsembleScheduler(
       info_->tensor_to_prev_step_.emplace(pair.second, step_idx);
     }
   }
-  callback_pool_ = is_->EnsembleCallbackPool();
 }

 EnsembleScheduler::~EnsembleScheduler()