@@ -223,7 +223,7 @@ modelbox::Status MindSporeInference::Init(
   for (auto &input_tensor : model_->GetInputs()) {
     // check model info & whether padding
    if (!model_need_padding_ && input_tensor.Shape()[0] > 1) {
-      model_need_padding_ = true;
+      model_need_padding_ = !multi_batch_in_buffer_;
      auto max_batch_size = input_tensor.Shape()[0];
      if (config_batch_size_ > max_batch_size) {
        auto error_msg =
@@ -283,6 +283,9 @@ modelbox::Status MindSporeInference::Open(
     config_file_ = relpath + "/" + config_file_;
   }
 
+  multi_batch_in_buffer_ =
+      merge_config->GetBool("config.multi_batch_in_buffer", false);
+
   ret = Init(unit_desc->GetModelEntry(), merge_config,
              flowunit_device_->GetDeviceManager()->GetDrivers());
   if (ret != modelbox::STATUS_SUCCESS) {
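Taken together, the first two hunks gate batch padding behind the new option: `Open()` reads `config.multi_batch_in_buffer` (defaulting to `false`), and `Init()` only marks a statically batched model (`Shape()[0] > 1`) as needing padding when the option is off. A minimal sketch of that decision, using hypothetical free-standing names (`NeedPadding`, `static_batch_dim`) in place of the real class members:

```cpp
#include <cstdint>
#include <iostream>

// A static batch dimension greater than 1 normally forces padding of
// partial batches, unless the caller promises to deliver the whole
// batch inside a single buffer (multi_batch_in_buffer).
bool NeedPadding(int64_t static_batch_dim, bool multi_batch_in_buffer) {
  if (static_batch_dim <= 1) {
    return false;  // dynamic batch or batch of 1: nothing to pad
  }
  return !multi_batch_in_buffer;  // mirrors model_need_padding_ = !multi_batch_in_buffer_
}

int main() {
  std::cout << NeedPadding(4, false) << "\n";  // 1: pad partial batches up to 4
  std::cout << NeedPadding(4, true) << "\n";   // 0: buffer already carries all 4
  std::cout << NeedPadding(1, false) << "\n";  // 0: no fixed multi-batch shape
}
```

With padding disabled this way, `PrepareInputTensor()` below also stops overwriting `input_shape[0]` with the buffer-list size, so the tensor keeps the model's static batch dimension.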
@@ -384,7 +387,7 @@ void MindSporeInference::PrepareInputTensor(
         const_cast<void *>(input_buffer_list->ConstData()));
   }
   // set current batch size
-  if (!model_need_padding_) {
+  if (!multi_batch_in_buffer_ && !model_need_padding_) {
     input_shape[0] = input_buffer_list->Size();
   }
   MBLOG_DEBUG << "input name: " << name << " shape: ";
@@ -474,6 +477,10 @@ modelbox::Status MindSporeInference::PrepareOutputBufferList(
     size_t buffer_output_batch = tensor_output_batch - padding_batch_size_;
     size_t output_batch_bytes =
         ms_outputs[i].DataSize() / ms_outputs[i].Shape()[0];
+    if (multi_batch_in_buffer_) {
+      buffer_output_batch = 1;
+      output_batch_bytes = ms_outputs[i].DataSize();
+    }
     MBLOG_DEBUG << "tensor_output_batch:" << tensor_output_batch
                 << ", padding_batch_size:" << padding_batch_size_
                 << ", output_batch_size:" << buffer_output_batch;
@@ -491,7 +498,9 @@ modelbox::Status MindSporeInference::PrepareOutputBufferList(
 
     auto tensor_shape = ms_outputs[i].Shape();
     std::vector<size_t> output_shape;
-    tensor_shape[0] = 1;
+    if (!multi_batch_in_buffer_) {
+      tensor_shape[0] = 1;
+    }
     MBLOG_DEBUG << "output shape: ";
     for (const auto &item : tensor_shape) {
       output_shape.push_back(item);
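On the output side, the last two hunks keep the whole batch in one buffer when the option is on: instead of emitting `tensor_output_batch - padding_batch_size_` buffers of `DataSize()/Shape()[0]` bytes each (with the leading shape dimension rewritten to 1), the flowunit emits a single buffer of `DataSize()` bytes and leaves the tensor shape untouched. A compact sketch of that sizing rule, with a hypothetical `OutputSizing` struct standing in for the real member state and assuming the fields mean what their names suggest:

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical stand-ins for the tensor/member state used by
// PrepareOutputBufferList(); only the arithmetic mirrors the diff above.
struct OutputSizing {
  size_t buffer_output_batch;   // number of buffers built for this tensor
  size_t output_batch_bytes;    // bytes carried by each buffer
  std::vector<int64_t> shape;   // shape attached to each buffer
};

OutputSizing SizeOutputs(size_t tensor_output_batch, size_t padding_batch_size,
                         size_t data_size, std::vector<int64_t> tensor_shape,
                         bool multi_batch_in_buffer) {
  OutputSizing s{tensor_output_batch - padding_batch_size,
                 data_size / static_cast<size_t>(tensor_shape[0]),
                 tensor_shape};
  if (multi_batch_in_buffer) {
    // One buffer carries the whole tensor; keep shape[0] as-is.
    s.buffer_output_batch = 1;
    s.output_batch_bytes = data_size;
  } else {
    // One buffer per batch entry; each buffer reports batch dim 1.
    s.shape[0] = 1;
  }
  return s;
}

int main() {
  // Static batch 4, no padding, 4096 bytes of output data in total.
  auto sliced = SizeOutputs(4, 0, 4096, {4, 32, 8}, false);
  std::cout << sliced.buffer_output_batch << " buffers x "
            << sliced.output_batch_bytes << " bytes\n";  // 4 buffers x 1024 bytes
  auto whole = SizeOutputs(4, 0, 4096, {4, 32, 8}, true);
  std::cout << whole.buffer_output_batch << " buffer x "
            << whole.output_batch_bytes << " bytes\n";   // 1 buffer x 4096 bytes
}
```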