|
1 |
| -// Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 1 | +// Copyright 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
2 | 2 | //
|
3 | 3 | // Redistribution and use in source and binary forms, with or without
|
4 | 4 | // modification, are permitted provided that the following conditions
|
@@ -1001,28 +1001,7 @@ InferenceRequest::Normalize()
|
1001 | 1001 | }
|
1002 | 1002 | // Make sure that the request is providing the number of inputs
|
1003 | 1003 | // as is expected by the model.
|
1004 |
| - if ((original_inputs_.size() > (size_t)model_config.input_size()) || |
1005 |
| - (original_inputs_.size() < model_raw_->RequiredInputCount())) { |
1006 |
| - // If no input is marked as optional, then use exact match error message |
1007 |
| - // for consistency / backward compatibility |
1008 |
| - if ((size_t)model_config.input_size() == model_raw_->RequiredInputCount()) { |
1009 |
| - return Status( |
1010 |
| - Status::Code::INVALID_ARG, |
1011 |
| - LogRequest() + "expected " + |
1012 |
| - std::to_string(model_config.input_size()) + " inputs but got " + |
1013 |
| - std::to_string(original_inputs_.size()) + " inputs for model '" + |
1014 |
| - ModelName() + "'"); |
1015 |
| - } else { |
1016 |
| - return Status( |
1017 |
| - Status::Code::INVALID_ARG, |
1018 |
| - LogRequest() + "expected number of inputs between " + |
1019 |
| - std::to_string(model_raw_->RequiredInputCount()) + " and " + |
1020 |
| - std::to_string(model_config.input_size()) + " but got " + |
1021 |
| - std::to_string(original_inputs_.size()) + " inputs for model '" + |
1022 |
| - ModelName() + "'"); |
1023 |
| - } |
1024 |
| - } |
1025 |
| - |
| 1004 | + RETURN_IF_ERROR(ValidateRequestInputs()); |
1026 | 1005 | // Determine the batch size and shape of each input.
|
1027 | 1006 | if (model_config.max_batch_size() == 0) {
|
1028 | 1007 | // Model does not support Triton-style batching so set as
|
@@ -1195,6 +1174,67 @@ InferenceRequest::Normalize()
|
1195 | 1174 | return Status::Success;
|
1196 | 1175 | }
|
1197 | 1176 |
|
| 1177 | +Status |
| 1178 | +InferenceRequest::ValidateRequestInputs() |
| 1179 | +{ |
| 1180 | + const inference::ModelConfig& model_config = model_raw_->Config(); |
| 1181 | + if ((original_inputs_.size() > (size_t)model_config.input_size()) || |
| 1182 | + (original_inputs_.size() < model_raw_->RequiredInputCount())) { |
| 1183 | + // If no input is marked as optional, then use exact match error message |
| 1184 | + // for consistency / backward compatibility |
| 1185 | + std::string missing_required_input_string = "["; |
| 1186 | + std::string original_input_string = "["; |
| 1187 | + |
| 1188 | + for (size_t i = 0; i < (size_t)model_config.input_size(); ++i) { |
| 1189 | + const inference::ModelInput& input = model_config.input(i); |
| 1190 | + if ((!input.optional()) && |
| 1191 | + (original_inputs_.find(input.name()) == original_inputs_.end())) { |
| 1192 | + missing_required_input_string = |
| 1193 | + missing_required_input_string + "'" + input.name() + "'" + ","; |
| 1194 | + } |
| 1195 | + } |
| 1196 | + // Removes the extra "," |
| 1197 | + missing_required_input_string.pop_back(); |
| 1198 | + missing_required_input_string = missing_required_input_string + "]"; |
| 1199 | + |
| 1200 | + for (const auto& pair : original_inputs_) { |
| 1201 | + original_input_string = |
| 1202 | + original_input_string + "'" + pair.first + "'" + ","; |
| 1203 | + } |
| 1204 | + // Removes the extra "," |
| 1205 | + original_input_string.pop_back(); |
| 1206 | + original_input_string = original_input_string + "]"; |
| 1207 | + if (original_inputs_.size() == 0) { |
| 1208 | + original_input_string = "[]"; |
| 1209 | + } |
| 1210 | + if ((size_t)model_config.input_size() == model_raw_->RequiredInputCount()) { |
| 1211 | + // This is response ONLY when there are no optional parameters in the |
| 1212 | + // model |
| 1213 | + return Status( |
| 1214 | + Status::Code::INVALID_ARG, |
| 1215 | + LogRequest() + "expected " + |
| 1216 | + std::to_string(model_config.input_size()) + " inputs but got " + |
| 1217 | + std::to_string(original_inputs_.size()) + " inputs for model '" + |
| 1218 | + ModelName() + "'. Got input(s) " + original_input_string + |
| 1219 | + ", but missing required input(s) " + |
| 1220 | + missing_required_input_string + |
| 1221 | + ". Please provide all required input(s)."); |
| 1222 | + } else { |
| 1223 | + return Status( |
| 1224 | + Status::Code::INVALID_ARG, |
| 1225 | + LogRequest() + "expected number of inputs between " + |
| 1226 | + std::to_string(model_raw_->RequiredInputCount()) + " and " + |
| 1227 | + std::to_string(model_config.input_size()) + " but got " + |
| 1228 | + std::to_string(original_inputs_.size()) + " inputs for model '" + |
| 1229 | + ModelName() + "'. Got input(s) " + original_input_string + |
| 1230 | + ", but missing required input(s) " + |
| 1231 | + missing_required_input_string + |
| 1232 | + ". Please provide all required input(s)."); |
| 1233 | + } |
| 1234 | + } |
| 1235 | + return Status::Success; |
| 1236 | +} |
| 1237 | + |
1198 | 1238 | #ifdef TRITON_ENABLE_STATS
|
1199 | 1239 | void
|
1200 | 1240 | InferenceRequest::ReportStatistics(
|
|
0 commit comments