Skip to content

Commit 669a7cd

Browse files
authored
Add recursive lookup for better ensemble disconnectivity reporting (#239)
* Add recursive lookup for better ensemble disconnectivity reporting * Fix up
1 parent 854da96 commit 669a7cd

File tree

1 file changed

+44
-13
lines changed

1 file changed

+44
-13
lines changed

src/model_config_utils.cc

Lines changed: 44 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,14 @@ namespace {
5555
#ifdef TRITON_ENABLE_ENSEMBLE
5656

5757
struct EnsembleTensor {
58-
EnsembleTensor(bool isOutput) : ready(false), isOutput(isOutput) {}
59-
bool ready;
60-
bool isOutput;
58+
EnsembleTensor(const std::string& name, bool isOutput)
59+
: name{name}, isOutput(isOutput)
60+
{
61+
}
62+
const std::string name;
63+
64+
bool isOutput{false};
65+
bool ready{false};
6166
std::vector<EnsembleTensor*> prev_nodes;
6267
std::vector<EnsembleTensor*> next_nodes;
6368
};
@@ -112,10 +117,11 @@ BuildEnsembleGraph(
112117
it->second.isOutput = true;
113118
}
114119
} else {
115-
it = keyed_ensemble_graph
116-
.emplace(
117-
std::make_pair(output_map.second, EnsembleTensor(true)))
118-
.first;
120+
it =
121+
keyed_ensemble_graph
122+
.emplace(std::make_pair(
123+
output_map.second, EnsembleTensor(output_map.second, true)))
124+
.first;
119125
}
120126
tensor_as_output.push_back(&(it->second));
121127
}
@@ -135,8 +141,8 @@ BuildEnsembleGraph(
135141
auto it = keyed_ensemble_graph.find(input_map.second);
136142
if (it == keyed_ensemble_graph.end()) {
137143
it = keyed_ensemble_graph
138-
.emplace(
139-
std::make_pair(input_map.second, EnsembleTensor(false)))
144+
.emplace(std::make_pair(
145+
input_map.second, EnsembleTensor(input_map.second, false)))
140146
.first;
141147
}
142148
for (auto output : tensor_as_output) {
@@ -233,10 +239,35 @@ ValidateEnsembleSchedulingConfig(const inference::ModelConfig& config)
233239
"' is not used");
234240
}
235241
if (!it->second.ready) {
236-
return Status(
237-
Status::Code::INVALID_ARG, "output '" + output.name() +
238-
"' for ensemble '" + config.name() +
239-
"' is not written");
242+
std::string error_message = "output '" + output.name() +
243+
"' for ensemble '" + config.name() +
244+
"' is not written";
245+
246+
// recursively check 'prev_nodes' for the source of not-ready state
247+
std::vector<EnsembleTensor*>* prev_nodes = &it->second.prev_nodes;
248+
auto last_not_ready_node = &it->second;
249+
// there can be circular dependency so remember seen names to break it
250+
std::set<std::string> seen_names;
251+
while ((prev_nodes != nullptr) && (!prev_nodes->empty())) {
252+
const auto& nodes = *prev_nodes;
253+
// make sure while loop will terminate if no not-ready source is seen
254+
prev_nodes = nullptr;
255+
for (const auto& node : nodes) {
256+
if ((!node->ready) &&
257+
(seen_names.find(node->name) == seen_names.end())) {
258+
seen_names.emplace(node->name);
259+
last_not_ready_node = node;
260+
prev_nodes = &node->prev_nodes;
261+
break;
262+
}
263+
}
264+
}
265+
// there is a not-ready source
266+
if (last_not_ready_node->name != it->second.name) {
267+
error_message += ": at least one of its depending tensors, '" +
268+
last_not_ready_node->name + "', is not connected";
269+
}
270+
return Status(Status::Code::INVALID_ARG, error_message);
240271
} else {
241272
outputs.insert(it->first);
242273
}

0 commit comments

Comments
 (0)