Skip to content

Commit 2bd46a4

Browse files
authored
Improve input mismatch error for inference requests (DLIS-6165) (#330)
Improved the server's response in case of a mismatch in the required inputs for a model.
1 parent 9f1fad2 commit 2bd46a4

File tree

2 files changed

+67
-24
lines changed

2 files changed

+67
-24
lines changed

src/infer_request.cc

Lines changed: 63 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
// Copyright 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
//
33
// Redistribution and use in source and binary forms, with or without
44
// modification, are permitted provided that the following conditions
@@ -1001,28 +1001,7 @@ InferenceRequest::Normalize()
10011001
}
10021002
// Make sure that the request is providing the number of inputs
10031003
// as is expected by the model.
1004-
if ((original_inputs_.size() > (size_t)model_config.input_size()) ||
1005-
(original_inputs_.size() < model_raw_->RequiredInputCount())) {
1006-
// If no input is marked as optional, then use exact match error message
1007-
// for consistency / backward compatibility
1008-
if ((size_t)model_config.input_size() == model_raw_->RequiredInputCount()) {
1009-
return Status(
1010-
Status::Code::INVALID_ARG,
1011-
LogRequest() + "expected " +
1012-
std::to_string(model_config.input_size()) + " inputs but got " +
1013-
std::to_string(original_inputs_.size()) + " inputs for model '" +
1014-
ModelName() + "'");
1015-
} else {
1016-
return Status(
1017-
Status::Code::INVALID_ARG,
1018-
LogRequest() + "expected number of inputs between " +
1019-
std::to_string(model_raw_->RequiredInputCount()) + " and " +
1020-
std::to_string(model_config.input_size()) + " but got " +
1021-
std::to_string(original_inputs_.size()) + " inputs for model '" +
1022-
ModelName() + "'");
1023-
}
1024-
}
1025-
1004+
RETURN_IF_ERROR(ValidateRequestInputs());
10261005
// Determine the batch size and shape of each input.
10271006
if (model_config.max_batch_size() == 0) {
10281007
// Model does not support Triton-style batching so set as
@@ -1195,6 +1174,67 @@ InferenceRequest::Normalize()
11951174
return Status::Success;
11961175
}
11971176

1177+
Status
InferenceRequest::ValidateRequestInputs()
{
  const inference::ModelConfig& model_config = model_raw_->Config();

  // The request is valid when it supplies no more inputs than the model
  // declares and no fewer than the model requires (optional inputs may be
  // omitted).
  if ((original_inputs_.size() > (size_t)model_config.input_size()) ||
      (original_inputs_.size() < model_raw_->RequiredInputCount())) {
    // Build a bracketed, comma-separated list of the required inputs that
    // the request failed to provide, e.g. "['a','b']". The separator is
    // emitted *before* each item after the first so that an empty list
    // correctly renders as "[]" — the previous pop_back()-based trimming
    // consumed the opening '[' when nothing was missing (reachable when the
    // request supplies too many inputs but all required ones are present).
    std::string missing_required_input_string = "[";
    for (size_t i = 0; i < (size_t)model_config.input_size(); ++i) {
      const inference::ModelInput& input = model_config.input(i);
      if ((!input.optional()) &&
          (original_inputs_.find(input.name()) == original_inputs_.end())) {
        if (missing_required_input_string.size() > 1) {
          missing_required_input_string += ",";
        }
        missing_required_input_string += "'" + input.name() + "'";
      }
    }
    missing_required_input_string += "]";

    // Build the list of input names the request actually provided, in the
    // same "['x','y']" format; "[]" when the request supplied none.
    std::string original_input_string = "[";
    for (const auto& pair : original_inputs_) {
      if (original_input_string.size() > 1) {
        original_input_string += ",";
      }
      original_input_string += "'" + pair.first + "'";
    }
    original_input_string += "]";

    if ((size_t)model_config.input_size() == model_raw_->RequiredInputCount()) {
      // Exact-match wording, used ONLY when the model declares no optional
      // inputs — kept for consistency / backward compatibility.
      return Status(
          Status::Code::INVALID_ARG,
          LogRequest() + "expected " +
              std::to_string(model_config.input_size()) + " inputs but got " +
              std::to_string(original_inputs_.size()) + " inputs for model '" +
              ModelName() + "'. Got input(s) " + original_input_string +
              ", but missing required input(s) " +
              missing_required_input_string +
              ". Please provide all required input(s).");
    } else {
      // Range wording when optional inputs make the acceptable count a span.
      return Status(
          Status::Code::INVALID_ARG,
          LogRequest() + "expected number of inputs between " +
              std::to_string(model_raw_->RequiredInputCount()) + " and " +
              std::to_string(model_config.input_size()) + " but got " +
              std::to_string(original_inputs_.size()) + " inputs for model '" +
              ModelName() + "'. Got input(s) " + original_input_string +
              ", but missing required input(s) " +
              missing_required_input_string +
              ". Please provide all required input(s).");
    }
  }
  return Status::Success;
}
1237+
11981238
#ifdef TRITON_ENABLE_STATS
11991239
void
12001240
InferenceRequest::ReportStatistics(

src/infer_request.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
// Copyright 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
//
33
// Redistribution and use in source and binary forms, with or without
44
// modification, are permitted provided that the following conditions
@@ -744,6 +744,9 @@ class InferenceRequest {
744744

745745
Status Normalize();
746746

747+
// Helper for validating Inputs
748+
Status ValidateRequestInputs();
749+
747750
// Helpers for pending request metrics
748751
void IncrementPendingRequestCount();
749752
void DecrementPendingRequestCount();

0 commit comments

Comments
 (0)