@@ -63,8 +63,6 @@ modelbox::Status AtcInference::Init(
 modelbox::Status AtcInference::ParseConfig(
     const std::shared_ptr<modelbox::Configuration> &config) {
   device_id_ = config->GetInt32("deviceid");
-  batch_size_ = config->GetInt32("batch_size", 1);
-  MBLOG_INFO << "Model batch size " << batch_size_;
   return modelbox::STATUS_SUCCESS;
 }
@@ -126,17 +124,25 @@ modelbox::Status AtcInference::CheckModelIO(
                                          unit_input_list.end());
   std::set<std::string> unit_output_set(unit_output_list.begin(),
                                         unit_output_list.end());
+
+  auto model_input_size = dynamic_batch_tensor_index_ >= 0
+                              ? model_input_list_.size() - 1
+                              : model_input_list_.size();
+
   if (model_input_list_.empty() || model_output_list_.empty() ||
-      model_input_list_.size() != unit_input_list.size() ||
+      model_input_size != unit_input_list.size() ||
       model_output_list_.size() != unit_output_list.size()) {
-    MBLOG_ERROR << "Model input[" << model_input_list_.size() << "], output["
+    MBLOG_ERROR << "Model input[" << model_input_size << "], output["
                 << model_output_list_.size() << "], FlowUnit input["
                 << unit_input_list.size() << "], output["
                 << unit_output_list.size() << "], these io count is bad";
     return modelbox::STATUS_BADCONF;
   }
 
   for (auto &model_input_name : model_input_list_) {
+    if (model_input_name == ACL_DYNAMIC_TENSOR_NAME) {
+      continue;
+    }
     if (unit_input_set.find(model_input_name) == unit_input_set.end()) {
       MBLOG_ERROR << "Model miss input [" << model_input_name
                   << "] in graph config";
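
The `ACL_DYNAMIC_TENSOR_NAME` input skipped here is the hidden tensor that ATC appends to models converted with dynamic batch support, so it never appears in the flowunit's graph config. For reference, a minimal sketch of locating that hidden input on a loaded model, assuming a valid `aclmdlDesc` obtained via `aclmdlGetDesc` (the helper name is illustrative):

```cpp
#include <acl/acl.h>

// Returns the index of the hidden dynamic-batch input, or -1 for a
// fixed-batch model. ACL_DYNAMIC_TENSOR_NAME only exists on models
// converted by ATC with a dynamic batch profile.
static int FindDynamicBatchTensorIndex(aclmdlDesc *desc) {
  size_t index = 0;
  if (aclmdlGetInputIndexByName(desc, ACL_DYNAMIC_TENSOR_NAME, &index) !=
      ACL_SUCCESS) {
    return -1;  // No hidden input: the model has a static batch.
  }
  return static_cast<int>(index);
}
```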
@@ -161,7 +167,8 @@ void AtcInference::ReadModelInfo() {
   std::stringstream model_info;
   model_info << "Model:" << model_file_ << std::endl;
   auto *desc_ptr = model_desc_.get();
-  LogBatchInfo(desc_ptr, model_info);
+  size_t max_batch_size = 1;
+  SaveBatchInfo(desc_ptr, model_info, max_batch_size);
   model_info << "Input:" << std::endl;
   aclmdlIODims dims;
   for (size_t i = 0; i < input_num; ++i) {
@@ -178,6 +185,17 @@ void AtcInference::ReadModelInfo() {
     model_input_size_.push_back(size);
     auto format = aclmdlGetInputFormat(desc_ptr, i);
     auto data_type = aclmdlGetInputDataType(desc_ptr, i);
+
+    if (name == ACL_DYNAMIC_TENSOR_NAME) {
+      dynamic_batch_tensor_index_ = i;
+      ret = aclrtMalloc(&dynamic_batch_mem_ptr_, model_input_size_[i],
+                        aclrtMemMallocPolicy::ACL_MEM_MALLOC_NORMAL_ONLY);
+      if (ret != ACL_SUCCESS || dynamic_batch_mem_ptr_ == nullptr) {
+        MBLOG_ERROR << "malloc acl memory failed, size: " << size;
+        return;
+      }
+    }
+
     LogTensorInfo(desc_ptr, i, dims, size, format, data_type, model_info);
   }
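
The hidden input needs a persistent device buffer that ACL fills with the selected batch value at execute time; the patch allocates it once while reading the model description and keeps it in `dynamic_batch_mem_ptr_`. A minimal sketch of that one-time allocation, assuming `desc` and an index found as in the earlier sketch:

```cpp
#include <acl/acl.h>

// One-time device allocation for the hidden batch tensor. The caller owns
// the pointer and must release it with aclrtFree (Deinit does so below).
static void *AllocDynamicBatchBuffer(aclmdlDesc *desc, size_t index) {
  void *dev_ptr = nullptr;
  size_t bytes = aclmdlGetInputSizeByIndex(desc, index);
  // ACL_MEM_MALLOC_NORMAL_ONLY requests plain device pages, matching the
  // policy used in ReadModelInfo above.
  if (aclrtMalloc(&dev_ptr, bytes, ACL_MEM_MALLOC_NORMAL_ONLY) !=
      ACL_SUCCESS) {
    return nullptr;
  }
  return dev_ptr;
}
```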
@@ -194,7 +212,14 @@ void AtcInference::ReadModelInfo() {
     std::string name = aclmdlGetOutputNameByIndex(desc_ptr, i);
     model_output_list_.push_back(name);
     auto size = aclmdlGetOutputSizeByIndex(desc_ptr, i);
-    model_output_size_.push_back(size);
+    if (dynamic_batch_tensor_index_ >= 0 && dims.dimCount > 0 &&
+        max_batch_size != (size_t)dims.dims[0]) {
+      MBLOG_ERROR << "model output tensor [" << name
+                  << "] dims error, first dim " << dims.dims[0]
+                  << " does not match max_batch_size: " << max_batch_size;
+    }
+
+    model_output_size_.push_back(size / max_batch_size);
     auto format = aclmdlGetOutputFormat(desc_ptr, i);
     auto data_type = aclmdlGetOutputDataType(desc_ptr, i);
     output_data_type_.push_back(GetModelBoxDataType(data_type));
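
A dynamic-batch model reports each output's size for the maximum batch, so dividing by `max_batch_size` yields the per-sample stride that `PrepareOutput` later multiplies by the actual request batch. A worked example with made-up numbers:

```cpp
#include <cstddef>

// Hypothetical float32 output of shape [max_batch, 1000] with
// max_batch_size = 8: aclmdlGetOutputSizeByIndex reports the whole buffer.
constexpr size_t kMaxBatch = 8;
constexpr size_t kTotalBytes = kMaxBatch * 1000 * sizeof(float);  // 32000
constexpr size_t kPerSampleBytes = kTotalBytes / kMaxBatch;       // 4000
// Each sample owns 4000 bytes regardless of the batch actually executed,
// which is why PrepareOutput can build current_batch_size buffers of
// this fixed size.
static_assert(kPerSampleBytes == 4000, "per-sample slice of max-batch buffer");
```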
@@ -213,15 +238,21 @@ void AtcInference::SaveOutputShape(const aclmdlIODims &dims) {
   output_shape_.push_back(shape);
 }
 
-void AtcInference::LogBatchInfo(aclmdlDesc *desc_ptr,
-                                std::stringstream &model_info) {
+void AtcInference::SaveBatchInfo(aclmdlDesc *desc_ptr,
+                                 std::stringstream &model_info,
+                                 size_t &max_batch_size) {
   aclmdlBatch batch;
   auto ret = aclmdlGetDynamicBatch(desc_ptr, &batch);
   if (ret != ACL_ERROR_NONE) {
     model_info << "Get dynamic batch failed, ret " << ret;
   } else {
     model_info << "Dynamic batch:[";
+    size_t size = 0;
     for (size_t i = 0; i < batch.batchCount; ++i) {
+      dynamic_batch_set_.emplace(batch.batch[i]);
+      if (size < batch.batch[i]) {
+        size = batch.batch[i];
+      }
       model_info << batch.batch[i];
       if (i + 1 == batch.batchCount) {
         model_info << "]";
@@ -232,6 +263,8 @@ void AtcInference::LogBatchInfo(aclmdlDesc *desc_ptr,
 
     if (batch.batchCount == 0) {
       model_info << "]";
+    } else {
+      max_batch_size = size;
     }
   }
   model_info << std::endl;
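
`aclmdlGetDynamicBatch` fills an `aclmdlBatch` with every batch value baked into the model at conversion time; the renamed `SaveBatchInfo` both logs them and records them for the per-request validation in `GetCurrentBatchSize`. A standalone sketch of the same query, assuming a valid `desc`:

```cpp
#include <acl/acl.h>

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <set>

static void PrintSupportedBatches(aclmdlDesc *desc) {
  aclmdlBatch batch{};
  if (aclmdlGetDynamicBatch(desc, &batch) != ACL_SUCCESS) {
    std::printf("no dynamic batch info\n");
    return;
  }
  std::set<uint64_t> supported;  // mirrors dynamic_batch_set_
  uint64_t max_batch = 0;        // mirrors max_batch_size
  for (size_t i = 0; i < batch.batchCount; ++i) {
    supported.emplace(batch.batch[i]);
    max_batch = std::max<uint64_t>(max_batch, batch.batch[i]);
  }
  for (uint64_t b : supported) {
    std::printf("supported batch: %llu\n", static_cast<unsigned long long>(b));
  }
  std::printf("max batch: %llu\n", static_cast<unsigned long long>(max_batch));
}
```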
@@ -290,6 +323,31 @@ modelbox::ModelBoxDataType AtcInference::GetModelBoxDataType(
   return modelbox::ModelBoxDataType::MODELBOX_TYPE_INVALID;
 }
 
+modelbox::Status AtcInference::GetCurrentBatchSize(
+    std::shared_ptr<modelbox::DataContext> &data_ctx, size_t &batch_size) {
+  if (model_input_list_.size() == 0) {
+    MBLOG_ERROR << "model_input_list_ is empty";
+    return modelbox::STATUS_FAULT;
+  }
+
+  auto buffer_list = data_ctx->Input()->at(model_input_list_[0]);
+  if (buffer_list == nullptr) {
+    MBLOG_ERROR << "get current batch size failed";
+    return modelbox::STATUS_FAULT;
+  }
+
+  if ((dynamic_batch_tensor_index_ >= 0) &&
+      (dynamic_batch_set_.find(buffer_list->Size()) ==
+       dynamic_batch_set_.end())) {
+    MBLOG_ERROR << "current model does not support input batch_size: "
+                << buffer_list->Size();
+    return modelbox::STATUS_FAULT;
+  }
+
+  batch_size = buffer_list->Size();
+  return modelbox::STATUS_OK;
+}
+
 modelbox::Status AtcInference::Infer(
     std::shared_ptr<modelbox::DataContext> &data_ctx, aclrtStream stream) {
   auto acl_ret = aclrtSetDevice(device_id_);
@@ -299,24 +357,43 @@ modelbox::Status AtcInference::Infer(
     return {modelbox::STATUS_FAULT, "Set device failed"};
   }
 
-  auto input = CreateDataSet(data_ctx->Input(), model_input_list_);
+  size_t current_batch_size;
+  auto ret = GetCurrentBatchSize(data_ctx, current_batch_size);
+  if (ret != modelbox::STATUS_SUCCESS) {
+    MBLOG_ERROR << "get current batch size failed";
+    return {modelbox::STATUS_FAULT, "Get current batch size failed"};
+  }
+
+  auto input =
+      CreateDataSet(data_ctx->Input(), model_input_list_, current_batch_size);
   if (input == nullptr) {
     MBLOG_ERROR << "Create input for infer failed";
     return {modelbox::STATUS_FAULT, "Create input failed"};
   }
 
-  auto ret = PrepareOutput(data_ctx);
+  ret = PrepareOutput(data_ctx, current_batch_size);
   if (ret != modelbox::STATUS_SUCCESS) {
     MBLOG_ERROR << "Prepare output failed";
     return {modelbox::STATUS_FAULT, "Prepare output failed"};
   }
 
-  auto output = CreateDataSet(data_ctx->Output(), model_output_list_);
+  auto output =
+      CreateDataSet(data_ctx->Output(), model_output_list_, current_batch_size);
   if (output == nullptr) {
     MBLOG_ERROR << "Create output for infer failed";
     return {modelbox::STATUS_FAULT, "Create output failed"};
   }
 
+  if (dynamic_batch_tensor_index_ >= 0) {
+    acl_ret = aclmdlSetDynamicBatchSize(model_id_, input.get(),
+                                        dynamic_batch_tensor_index_,
+                                        current_batch_size);
+    if (acl_ret != ACL_ERROR_NONE) {
+      MBLOG_ERROR << "aclmdlSetDynamicBatchSize failed, ret " << acl_ret;
+      return {modelbox::STATUS_FAULT, "Execute acl set batch_size failed"};
+    }
+  }
+
   acl_ret = ACL_ERROR_NONE;
   if (stream == nullptr) {
     acl_ret = aclmdlExecute(model_id_, input.get(), output.get());
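
`aclmdlSetDynamicBatchSize` writes the chosen batch into the dataset buffer at the hidden tensor's index, so it has to run after the input dataset is fully assembled and before execution, which is exactly the ordering `Infer` follows. A condensed sketch of that sequence (parameter names are illustrative):

```cpp
#include <acl/acl.h>

// Set the per-request batch on a fully-built input dataset, then execute.
// dyn_index is the hidden tensor's input index; batch must be one of the
// profiles the model was converted with, or the set call fails.
static aclError RunWithBatch(uint32_t model_id, aclmdlDataset *input,
                             aclmdlDataset *output, size_t dyn_index,
                             uint64_t batch) {
  aclError ret = aclmdlSetDynamicBatchSize(model_id, input, dyn_index, batch);
  if (ret != ACL_SUCCESS) {
    return ret;
  }
  // Synchronous path; with a stream, aclmdlExecuteAsync would be used instead.
  return aclmdlExecute(model_id, input, output);
}
```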
@@ -334,13 +411,14 @@ modelbox::Status AtcInference::Infer(
 }
 
 modelbox::Status AtcInference::PrepareOutput(
-    std::shared_ptr<modelbox::DataContext> &data_ctx) {
+    std::shared_ptr<modelbox::DataContext> &data_ctx,
+    const size_t &current_batch_size) {
   auto output_count = model_output_list_.size();
   for (size_t i = 0; i < output_count; ++i) {
     auto &name = model_output_list_[i];
     auto buffer_list = data_ctx->Output(name);
     auto &size = model_output_size_[i];
-    std::vector<size_t> shape(batch_size_, size);
+    std::vector<size_t> shape(current_batch_size, size);
     buffer_list->Build(shape);
     buffer_list->Set("shape", output_shape_[i]);
     buffer_list->Set("type", output_data_type_[i]);
@@ -351,7 +429,7 @@ modelbox::Status AtcInference::PrepareOutput(
 
 std::shared_ptr<aclmdlDataset> AtcInference::CreateDataSet(
     const std::shared_ptr<modelbox::BufferListMap> &buffer_list_map,
-    std::vector<std::string> &name_list) {
+    std::vector<std::string> &name_list, const size_t &current_batch_size) {
   auto *data_set_ptr = aclmdlCreateDataset();
   if (data_set_ptr == nullptr) {
     MBLOG_ERROR << "aclmdlCreateDataset return null";
@@ -368,15 +446,32 @@ std::shared_ptr<aclmdlDataset> AtcInference::CreateDataSet(
   });
 
   for (auto &tensor_name : name_list) {
-    auto buffer_list = buffer_list_map->at(tensor_name);
-    if (buffer_list == nullptr) {
-      MBLOG_ERROR << "Create data set for tensor " << tensor_name
-                  << " failed, buffer list is null";
-      return nullptr;
+    void *mem_ptr = nullptr;
+    size_t size;
+    if (tensor_name != ACL_DYNAMIC_TENSOR_NAME) {
+      auto buffer_list = buffer_list_map->at(tensor_name);
+      if (buffer_list == nullptr) {
+        MBLOG_ERROR << "Create data set for tensor " << tensor_name
+                    << " failed, buffer list is null";
+        return nullptr;
+      }
+
+      if (current_batch_size != buffer_list->Size()) {
+        MBLOG_ERROR << "buffer batch_size mismatch, first batch_size: "
+                    << current_batch_size
+                    << ", current tensor: " << tensor_name
+                    << " batch_size: " << buffer_list->Size();
+        return nullptr;
+      }
+
+      mem_ptr = const_cast<void *>(buffer_list->ConstData());
+      size = buffer_list->GetBytes();
+    } else {
+      size = model_input_size_[dynamic_batch_tensor_index_];
+      mem_ptr = dynamic_batch_mem_ptr_;
     }
 
-    auto *data_buffer = aclCreateDataBuffer(
-        const_cast<void *>(buffer_list->ConstData()), buffer_list->GetBytes());
+    auto *data_buffer = aclCreateDataBuffer(mem_ptr, size);
     if (data_buffer == nullptr) {
       MBLOG_ERROR << "Create data set buffer for tensor " << tensor_name
                   << " failed";
@@ -411,5 +506,9 @@ modelbox::Status AtcInference::Deinit() {
     return modelbox::STATUS_FAULT;
   }
 
+  if (dynamic_batch_mem_ptr_ != nullptr) {
+    aclrtFree(dynamic_batch_mem_ptr_);
+  }
+
   return modelbox::STATUS_SUCCESS;
 }