Skip to content

Commit 743f2ff

Browse files
tau233CarlosLeeGit
authored andcommitted
mindspore: set input shape
1 parent 204f406 commit 743f2ff

File tree

1 file changed

+13
-0
lines changed

1 file changed

+13
-0
lines changed

src/drivers/inference_engine/mindspore/mindspore_inference.cc

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,19 @@ void MindSporeInference::PrepareInputTensor(
362362
MBLOG_DEBUG << "input_buffer_list: " << portname << ", model port: " << name
363363
<< ", size: " << input_buffer_list->Size()
364364
<< ", bytes:" << input_buffer_list->GetBytes();
365+
std::vector<size_t> b_shape;
366+
if (!input_buffer_list->At(0)->Get("shape", b_shape) ||
367+
input_shape.size() != b_shape.size()) {
368+
MBLOG_ERROR << "get input shape failed, tensor shape size:"
369+
<< input_shape.size()
370+
<< ", buffer shape size: " << b_shape.size();
371+
return;
372+
}
373+
374+
for (size_t index = 0; index < b_shape.size(); ++index) {
375+
input_shape[index] = b_shape[index];
376+
}
377+
365378
// input batch padding
366379
if (model_need_padding_) {
367380
padding_batch_size_ = input_shape[0] - input_buffer_list->Size();

0 commit comments

Comments
 (0)