Commit 817aaf4

Support top level response caching for ensemble models (#338) (#352)

1 parent 06b8f6e commit 817aaf4

5 files changed: +144 −33

src/dynamic_batch_scheduler.cc

Lines changed: 2 additions & 33 deletions
@@ -39,14 +39,6 @@
 
 namespace triton { namespace core {
 
-uint64_t
-CaptureTimeNs()
-{
-  return std::chrono::duration_cast<std::chrono::nanoseconds>(
-             std::chrono::steady_clock::now().time_since_epoch())
-      .count();
-}
-
 bool
 IsStaleState(Payload::State payload_state)
 {
@@ -753,32 +745,9 @@ DynamicBatchScheduler::CacheLookUp(
     std::unique_ptr<InferenceRequest>& request,
     std::unique_ptr<InferenceResponse>& cached_response)
 {
-  Status status;
   auto cache = model_->Server()->CacheManager()->Cache();
-  std::unique_ptr<InferenceResponse> local_response;
-  request->ResponseFactory()->CreateResponse(&local_response);
-  // Hash request into cache key
-  std::string key = "";
-  if (!request->CacheKeyIsSet()) {
-    status = cache->Hash(*request, &key);
-    if (!status.IsOk()) {
-      LOG_ERROR << "Failed to hash request: " << status.Message();
-      return;
-    }
-    request->SetCacheKey(key);
-  } else {
-    key = request->CacheKey();
-  }
-
-  // Lookup and capture timestamps
-  {
-    request->CaptureCacheLookupStartNs();
-    status = cache->Lookup(local_response.get(), key);
-    request->CaptureCacheLookupEndNs();
-  }
-
-  if (status.IsOk() && (local_response != nullptr)) {
-    cached_response = std::move(local_response);
+  bool is_lookup_success = CacheLookUpUtil(request, cached_response, cache);
+  if (is_lookup_success) {
 #ifdef TRITON_ENABLE_STATS
     // Update model metrics/stats on cache hits
     // Backends will update metrics as normal on cache misses
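
The deletions here are a pure extraction rather than a behavior change: the hash-and-lookup sequence removed above reappears as the shared CacheLookUpUtil helper in src/scheduler_utils.cc (below), with the bare returns replaced by return false / return true so callers can branch on the lookup result.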

src/ensemble_scheduler/ensemble_scheduler.cc

Lines changed: 87 additions & 0 deletions
@@ -61,6 +61,10 @@ class RequestTracker {
 
   std::unique_ptr<InferenceRequest>& Request() { return request_; }
 
+  InferenceStatsAggregator* StatsAggregator() { return stats_aggregator_; }
+
+  MetricModelReporter* MetricReporter() { return metric_reporter_; }
+
   InferenceStatsAggregator& ContextStatsAggregator()
   {
     return context_stats_aggregator_;
@@ -316,6 +320,9 @@ class EnsembleContext {
       const std::set<std::pair<std::string, IterationCount>>& updated_tensors,
       std::unique_ptr<InferenceResponse>* response);
 
+  void CacheEnsembleTopLevelRequest(
+      std::unique_ptr<InferenceResponse>& response);
+
   InferenceServer* is_;
 
   EnsembleInfo* info_;
@@ -1033,6 +1040,50 @@ EnsembleContext::ReshapeTensorDims(
   return res;
 }
 
+// Cache the top-level ensemble response under the request's cache key
+void
+EnsembleContext::CacheEnsembleTopLevelRequest(
+    std::unique_ptr<InferenceResponse>& response)
+{
+  const std::string key = request_tracker_->Request()->CacheKey();
+  const bool is_key_set = request_tracker_->Request()->CacheKeyIsSet();
+
+#ifdef TRITON_ENABLE_STATS
+  const uint64_t lookup_end_ns =
+      request_tracker_->Request()->CacheLookupEndNs();
+  const uint64_t lookup_start_ns =
+      request_tracker_->Request()->CacheLookupStartNs();
+#endif
+
+  if (!is_key_set) {
+    LOG_ERROR << "Request cache key was not set correctly.";
+  }
+
+  auto cache = is_->CacheManager()->Cache();
+#ifdef TRITON_ENABLE_STATS
+  const uint64_t insert_start_ns = CaptureTimeNs();
+#endif
+  auto status = cache->Insert(response.get(), key);
+  if (!status.IsOk()) {
+    LOG_ERROR << "Failed to insert key [" << key
+              << "] into response cache: " << status.Message();
+  }
+
+#ifdef TRITON_ENABLE_STATS
+  const uint64_t insert_end_ns = CaptureTimeNs();
+  uint64_t lookup_ns = lookup_end_ns - lookup_start_ns;
+  if (lookup_start_ns > lookup_end_ns) {
+    lookup_ns = 0;
+    LOG_ERROR << "Request lookup duration was not set correctly.";
+  }
+  uint64_t insert_ns = insert_end_ns - insert_start_ns;
+  uint64_t cache_miss_ns = lookup_ns + insert_ns;
+  request_tracker_->StatsAggregator()->UpdateSuccessCacheMiss(
+      request_tracker_->MetricReporter(), cache_miss_ns);
+#endif
+}
+
+
 Status
 EnsembleContext::FinishEnsemble(std::unique_ptr<InferenceResponse>&& response)
 {
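
For the stats path, cache_miss_ns is the sum of two windows: the lookup window (CacheLookupStartNs to CacheLookupEndNs) captured on the request during the Enqueue-time lookup, plus the insert window measured around cache->Insert() here. For example, a 60 µs failed lookup followed by a 40 µs insert is reported as a 100 µs cache miss; if the lookup timestamps are inverted, lookup_ns is zeroed and only the insert time is charged.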
@@ -1053,6 +1104,10 @@ EnsembleContext::FinishEnsemble(std::unique_ptr<InferenceResponse>&& response)
           ? TRITONSERVER_RESPONSE_COMPLETE_FINAL
           : 0;
   if (response != nullptr) {
+    // Cache the response if caching is enabled.
+    if (info_->is_cache_enabled_) {
+      CacheEnsembleTopLevelRequest(response);
+    }
     InferenceResponse::Send(std::move(response), flags);
     response_sent_ = true;
   } else if (flags != 0) {
@@ -1319,6 +1374,21 @@ EnsembleScheduler::Create(
   return Status::Success;
 }
 
+
+void
+EnsembleScheduler::CacheLookUp(
+    std::unique_ptr<InferenceRequest>& request,
+    std::unique_ptr<InferenceResponse>& cached_response)
+{
+  auto cache = is_->CacheManager()->Cache();
+  bool is_lookup_success = CacheLookUpUtil(request, cached_response, cache);
+  if (is_lookup_success) {
+#ifdef TRITON_ENABLE_STATS
+    request->ReportStatisticsCacheHit(metric_reporter_.get());
+#endif
+  }
+}
+
 Status
 EnsembleScheduler::Enqueue(std::unique_ptr<InferenceRequest>& request)
 {
@@ -1333,6 +1403,19 @@ EnsembleScheduler::Enqueue(std::unique_ptr<InferenceRequest>& request)
       TRITONSERVER_TRACE_TENSOR_QUEUE_INPUT, "EnsembleScheduler Enqueue");
 #endif  // TRITON_ENABLE_TRACING
 
+  std::unique_ptr<InferenceResponse> cached_response;
+  if (info_->is_cache_enabled_) {
+    CacheLookUp(request, cached_response);
+  }
+
+  if (cached_response != nullptr) {
+    InferenceResponse::Send(
+        std::move(cached_response), TRITONSERVER_RESPONSE_COMPLETE_FINAL);
+    InferenceRequest::Release(
+        std::move(request), TRITONSERVER_REQUEST_RELEASE_ALL);
+    return Status::Success;
+  }
+
   // Add additional callback to keep track of in-flight count
   ++inflight_count_;
   request->AddInternalReleaseCallback(
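
On a cache hit the ensemble short-circuits entirely: the cached response is sent with TRITONSERVER_RESPONSE_COMPLETE_FINAL and the request is released right away, so no composing model is scheduled and inflight_count_ is never incremented.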
@@ -1387,6 +1470,10 @@ EnsembleScheduler::EnsembleScheduler(
   // This config field is filled internally for ensemble models
   info_->is_decoupled_ = config.model_transaction_policy().decoupled();
 
+  // Response caching is enabled in both the model config and the server.
+  info_->is_cache_enabled_ =
+      config.response_cache().enable() && is_->ResponseCacheEnabled();
+
   for (const auto& input : config.input()) {
     info_->tensor_to_step_.emplace(input.name(), std::set<size_t>());
     if (input.optional()) {
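
With this wiring, top-level caching for an ensemble takes two switches: a response_cache block in the ensemble's config.pbtxt (the field config.response_cache().enable() reads above) and a cache enabled server-wide (what is_->ResponseCacheEnabled() checks, e.g. via tritonserver's --cache-config option). A minimal, hypothetical ensemble config sketch; the model name, shapes, and elided scheduling steps are illustrative, only the response_cache block is the field this commit reads:

# config.pbtxt for a hypothetical ensemble model
name: "my_ensemble"
platform: "ensemble"
max_batch_size: 8
input [ { name: "INPUT0", data_type: TYPE_FP32, dims: [ 16 ] } ]
output [ { name: "OUTPUT0", data_type: TYPE_FP32, dims: [ 16 ] } ]
response_cache { enable: true }
ensemble_scheduling { ... }  # composing steps elided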

src/ensemble_scheduler/ensemble_scheduler.h

Lines changed: 8 additions & 0 deletions
@@ -34,6 +34,7 @@
 #include "model_config.pb.h"
 #include "model_config_utils.h"
 #include "scheduler.h"
+#include "scheduler_utils.h"
 #include "status.h"
 
 #ifdef TRITON_ENABLE_GPU
@@ -65,6 +66,8 @@ struct EnsembleInfo {
 
   bool is_decoupled_;
 
+  bool is_cache_enabled_;
+
   // the ensemble output (re)shape expected by the ensemble
   std::unordered_map<std::string, triton::common::DimsList>
       ensemble_output_shape_;
@@ -97,6 +100,7 @@ class EnsembleScheduler : public Scheduler {
   // \see Scheduler::Enqueue()
   Status Enqueue(std::unique_ptr<InferenceRequest>& request) override;
 
+
   // \see Scheduler::InflightInferenceCount()
   size_t InflightInferenceCount() override { return inflight_count_; }
 
@@ -109,6 +113,10 @@
       InferenceServer* const server, const ModelIdentifier& model_id,
       const inference::ModelConfig& config);
 
+  void CacheLookUp(
+      std::unique_ptr<InferenceRequest>& request,
+      std::unique_ptr<InferenceResponse>& cached_response);
+
   std::shared_ptr<MetricModelReporter> metric_reporter_;
   InferenceStatsAggregator* const stats_aggregator_;
   InferenceServer* const is_;

src/scheduler_utils.cc

Lines changed: 38 additions & 0 deletions
@@ -33,6 +33,44 @@
 
 namespace triton { namespace core {
 
+uint64_t
+CaptureTimeNs()
+{
+  return std::chrono::duration_cast<std::chrono::nanoseconds>(
+             std::chrono::steady_clock::now().time_since_epoch())
+      .count();
+}
+
+bool
+CacheLookUpUtil(
+    std::unique_ptr<InferenceRequest>& request,
+    std::unique_ptr<InferenceResponse>& cached_response,
+    std::shared_ptr<TritonCache> cache)
+{
+  Status status;
+  std::unique_ptr<InferenceResponse> local_response;
+  request->ResponseFactory()->CreateResponse(&local_response);
+  std::string key = "";
+  if (!request->CacheKeyIsSet()) {
+    status = cache->Hash(*request, &key);
+    if (!status.IsOk()) {
+      LOG_ERROR << "Failed to hash request: " << status.Message();
+      return false;
+    }
+    request->SetCacheKey(key);
+  } else {
+    key = request->CacheKey();
+  }
+  request->CaptureCacheLookupStartNs();
+  status = cache->Lookup(local_response.get(), key);
+  request->CaptureCacheLookupEndNs();
+  if (status.IsOk() && (local_response != nullptr)) {
+    cached_response = std::move(local_response);
+    return true;
+  }
+  return false;
+}
+
 Status
 RequiredEqualInputs::Initialize(
     const std::unique_ptr<InferenceRequest>& request,
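
To make the helper's contract concrete, here is a standalone C++ sketch with toy stand-in types (ToyRequest, ToyResponse, ToyCache, and ToyCacheLookUpUtil are all hypothetical; only the control flow mirrors CacheLookUpUtil above): the key is hashed at most once and stored on the request, and the function returns true only when the lookup produced a response.

#include <iostream>
#include <memory>
#include <optional>
#include <string>
#include <unordered_map>

// Toy stand-ins -- not Triton types.
struct ToyRequest {
  std::string payload;
  std::optional<std::string> cache_key;  // hashed at most once, then reused
};

using ToyResponse = std::string;

struct ToyCache {
  std::unordered_map<std::string, ToyResponse> entries;
  std::string Hash(const ToyRequest& r) const { return "key:" + r.payload; }
  bool Lookup(const std::string& key, ToyResponse* out) const
  {
    auto it = entries.find(key);
    if (it == entries.end()) {
      return false;
    }
    *out = it->second;
    return true;
  }
};

// Mirrors the CacheLookUpUtil flow: ensure the request carries a cache key,
// look it up, and fill cached_response only on a hit.
bool
ToyCacheLookUpUtil(
    ToyRequest& request, std::unique_ptr<ToyResponse>& cached_response,
    ToyCache& cache)
{
  if (!request.cache_key) {
    request.cache_key = cache.Hash(request);  // hash once, store on request
  }
  ToyResponse local;
  if (cache.Lookup(*request.cache_key, &local)) {
    cached_response = std::make_unique<ToyResponse>(std::move(local));
    return true;
  }
  return false;  // miss: the key stays set for a later Insert
}

int
main()
{
  ToyCache cache;
  ToyRequest req{"add_two_tensors", std::nullopt};
  std::unique_ptr<ToyResponse> resp;
  // First call misses but leaves the hashed key on the request.
  std::cout << "hit: " << ToyCacheLookUpUtil(req, resp, cache) << "\n";  // 0
  // Simulate what CacheEnsembleTopLevelRequest's Insert would do.
  cache.entries[*req.cache_key] = "cached result";
  std::cout << "hit: " << ToyCacheLookUpUtil(req, resp, cache) << "\n";  // 1
}

This hash-once behavior is what lets CacheEnsembleTopLevelRequest reuse request->CacheKey() at insert time without rehashing.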

src/scheduler_utils.h

Lines changed: 9 additions & 0 deletions
@@ -28,10 +28,19 @@
 #include <deque>
 #include <unordered_map>
 
+#include "cache_manager.h"
 #include "scheduler.h"
 
 namespace triton { namespace core {
 
+uint64_t CaptureTimeNs();
+// Utility function called by a scheduler to look up whether a request has a
+// cached response; on a hit the response is returned via cached_response.
+bool CacheLookUpUtil(
+    std::unique_ptr<InferenceRequest>& request,
+    std::unique_ptr<InferenceResponse>& cached_response,
+    std::shared_ptr<TritonCache> cache);
+
 struct RequiredEqualInputs {
  public:
   RequiredEqualInputs() : init_(false), has_optional_input_(false) {}
