@@ -55,9 +55,14 @@ namespace {
55
55
#ifdef TRITON_ENABLE_ENSEMBLE
56
56
57
57
struct EnsembleTensor {
58
- EnsembleTensor (bool isOutput) : ready(false ), isOutput(isOutput) {}
59
- bool ready;
60
- bool isOutput;
58
+ EnsembleTensor (const std::string& name, bool isOutput)
59
+ : name{name}, isOutput(isOutput)
60
+ {
61
+ }
62
+ const std::string name;
63
+
64
+ bool isOutput{false };
65
+ bool ready{false };
61
66
std::vector<EnsembleTensor*> prev_nodes;
62
67
std::vector<EnsembleTensor*> next_nodes;
63
68
};
@@ -112,10 +117,11 @@ BuildEnsembleGraph(
112
117
it->second .isOutput = true ;
113
118
}
114
119
} else {
115
- it = keyed_ensemble_graph
116
- .emplace (
117
- std::make_pair (output_map.second , EnsembleTensor (true )))
118
- .first ;
120
+ it =
121
+ keyed_ensemble_graph
122
+ .emplace (std::make_pair (
123
+ output_map.second , EnsembleTensor (output_map.second , true )))
124
+ .first ;
119
125
}
120
126
tensor_as_output.push_back (&(it->second ));
121
127
}
@@ -135,8 +141,8 @@ BuildEnsembleGraph(
135
141
auto it = keyed_ensemble_graph.find (input_map.second );
136
142
if (it == keyed_ensemble_graph.end ()) {
137
143
it = keyed_ensemble_graph
138
- .emplace (
139
- std::make_pair ( input_map.second , EnsembleTensor (false )))
144
+ .emplace (std::make_pair (
145
+ input_map.second , EnsembleTensor (input_map. second , false )))
140
146
.first ;
141
147
}
142
148
for (auto output : tensor_as_output) {
@@ -233,10 +239,35 @@ ValidateEnsembleSchedulingConfig(const inference::ModelConfig& config)
233
239
" ' is not used" );
234
240
}
235
241
if (!it->second .ready ) {
236
- return Status (
237
- Status::Code::INVALID_ARG, " output '" + output.name () +
238
- " ' for ensemble '" + config.name () +
239
- " ' is not written" );
242
+ std::string error_message = " output '" + output.name () +
243
+ " ' for ensemble '" + config.name () +
244
+ " ' is not written" ;
245
+
246
+ // recurrsively check 'prev_nodes' for the source of not-ready state
247
+ std::vector<EnsembleTensor*>* prev_nodes = &it->second .prev_nodes ;
248
+ auto last_not_ready_node = &it->second ;
249
+ // there can be circular dependency so remember seen names to break it
250
+ std::set<std::string> seen_names;
251
+ while ((prev_nodes != nullptr ) && (!prev_nodes->empty ())) {
252
+ const auto & nodes = *prev_nodes;
253
+ // make sure while loop will terminate if no not-ready source is seen
254
+ prev_nodes = nullptr ;
255
+ for (const auto & node : nodes) {
256
+ if ((!node->ready ) &&
257
+ (seen_names.find (node->name ) == seen_names.end ())) {
258
+ seen_names.emplace (node->name );
259
+ last_not_ready_node = node;
260
+ prev_nodes = &node->prev_nodes ;
261
+ break ;
262
+ }
263
+ }
264
+ }
265
+ // there is not-ready source
266
+ if (last_not_ready_node->name != it->second .name ) {
267
+ error_message += " : at least one of its depending tensors, '" +
268
+ last_not_ready_node->name + " ', is not connected" ;
269
+ }
270
+ return Status (Status::Code::INVALID_ARG, error_message);
240
271
} else {
241
272
outputs.insert (it->first );
242
273
}
0 commit comments