Skip to content

Commit 1a2e3d4

Browse files
pansgg authored and CarlosLeeGit committed
support dynamic batchsize for acl infer
1 parent ffa8453 commit 1a2e3d4

File tree

2 files changed

+131
-25
lines changed

2 files changed

+131
-25
lines changed

src/drivers/devices/ascend/flowunit/inference/atc_inference.cc

Lines changed: 120 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,6 @@ modelbox::Status AtcInference::Init(
6363
modelbox::Status AtcInference::ParseConfig(
6464
const std::shared_ptr<modelbox::Configuration> &config) {
6565
device_id_ = config->GetInt32("deviceid");
66-
batch_size_ = config->GetInt32("batch_size", 1);
67-
MBLOG_INFO << "Model batch size " << batch_size_;
6866
return modelbox::STATUS_SUCCESS;
6967
}
7068

@@ -126,17 +124,25 @@ modelbox::Status AtcInference::CheckModelIO(
126124
unit_input_list.end());
127125
std::set<std::string> unit_output_set(unit_output_list.begin(),
128126
unit_output_list.end());
127+
128+
auto model_input_size = dynamic_batch_tensor_index_ >= 0
129+
? model_input_list_.size() - 1
130+
: model_input_list_.size();
131+
129132
if (model_input_list_.empty() || model_output_list_.empty() ||
130-
model_input_list_.size() != unit_input_list.size() ||
133+
model_input_size != unit_input_list.size() ||
131134
model_output_list_.size() != unit_output_list.size()) {
132-
MBLOG_ERROR << "Model input[" << model_input_list_.size() << "], output["
135+
MBLOG_ERROR << "Model input[" << model_input_size << "], output["
133136
<< model_output_list_.size() << "], FlowUnit input["
134137
<< unit_input_list.size() << "], output["
135138
<< unit_output_list.size() << "], these io count is bad";
136139
return modelbox::STATUS_BADCONF;
137140
}
138141

139142
for (auto &model_input_name : model_input_list_) {
143+
if (model_input_name == ACL_DYNAMIC_TENSOR_NAME) {
144+
continue;
145+
}
140146
if (unit_input_set.find(model_input_name) == unit_input_set.end()) {
141147
MBLOG_ERROR << "Model miss input [" << model_input_name
142148
<< "] in graph config";
@@ -161,7 +167,8 @@ void AtcInference::ReadModelInfo() {
161167
std::stringstream model_info;
162168
model_info << "Model:" << model_file_ << std::endl;
163169
auto *desc_ptr = model_desc_.get();
164-
LogBatchInfo(desc_ptr, model_info);
170+
size_t max_batch_size = 1;
171+
SaveBatchInfo(desc_ptr, model_info, max_batch_size);
165172
model_info << "Input:" << std::endl;
166173
aclmdlIODims dims;
167174
for (size_t i = 0; i < input_num; ++i) {
@@ -178,6 +185,17 @@ void AtcInference::ReadModelInfo() {
178185
model_input_size_.push_back(size);
179186
auto format = aclmdlGetInputFormat(desc_ptr, i);
180187
auto data_type = aclmdlGetInputDataType(desc_ptr, i);
188+
189+
if (name == ACL_DYNAMIC_TENSOR_NAME) {
190+
dynamic_batch_tensor_index_ = i;
191+
ret = aclrtMalloc(&dynamic_batch_mem_ptr_, model_input_size_[i],
192+
aclrtMemMallocPolicy::ACL_MEM_MALLOC_NORMAL_ONLY);
193+
if (ret != ACL_SUCCESS || dynamic_batch_mem_ptr_ == nullptr) {
194+
MBLOG_ERROR << "malloc acl memory failed, size: " << size;
195+
return;
196+
}
197+
}
198+
181199
LogTensorInfo(desc_ptr, i, dims, size, format, data_type, model_info);
182200
}
183201

@@ -194,7 +212,14 @@ void AtcInference::ReadModelInfo() {
194212
std::string name = aclmdlGetOutputNameByIndex(desc_ptr, i);
195213
model_output_list_.push_back(name);
196214
auto size = aclmdlGetOutputSizeByIndex(desc_ptr, i);
197-
model_output_size_.push_back(size);
215+
if (dynamic_batch_tensor_index_ >= 0 && dims.dimCount > 0 &&
216+
max_batch_size != (size_t)dims.dims[0]) {
217+
MBLOG_ERROR << "model output tensor [" << name
218+
<< "] dims error, first dims: " << dims.dims[0]
219+
<< " is not same with max_batch_size: " << max_batch_size;
220+
}
221+
222+
model_output_size_.push_back(size / max_batch_size);
198223
auto format = aclmdlGetOutputFormat(desc_ptr, i);
199224
auto data_type = aclmdlGetOutputDataType(desc_ptr, i);
200225
output_data_type_.push_back(GetModelBoxDataType(data_type));
@@ -213,15 +238,21 @@ void AtcInference::SaveOutputShape(const aclmdlIODims &dims) {
213238
output_shape_.push_back(shape);
214239
}
215240

216-
void AtcInference::LogBatchInfo(aclmdlDesc *desc_ptr,
217-
std::stringstream &model_info) {
241+
void AtcInference::SaveBatchInfo(aclmdlDesc *desc_ptr,
242+
std::stringstream &model_info,
243+
size_t &max_batch_size) {
218244
aclmdlBatch batch;
219245
auto ret = aclmdlGetDynamicBatch(desc_ptr, &batch);
220246
if (ret != ACL_ERROR_NONE) {
221247
model_info << "Get dynamic batch failed, ret " << ret;
222248
} else {
223249
model_info << "Dynamic batch:[";
250+
size_t size = 0;
224251
for (size_t i = 0; i < batch.batchCount; ++i) {
252+
dynamic_batch_set_.emplace(batch.batch[i]);
253+
if (size < batch.batch[i]) {
254+
size = batch.batch[i];
255+
}
225256
model_info << batch.batch[i];
226257
if (i + 1 == batch.batchCount) {
227258
model_info << "]";
@@ -232,6 +263,8 @@ void AtcInference::LogBatchInfo(aclmdlDesc *desc_ptr,
232263

233264
if (batch.batchCount == 0) {
234265
model_info << "]";
266+
} else {
267+
max_batch_size = size;
235268
}
236269
}
237270
model_info << std::endl;
@@ -290,6 +323,31 @@ modelbox::ModelBoxDataType AtcInference::GetModelBoxDataType(
290323
return modelbox::ModelBoxDataType::MODELBOX_TYPE_INVALID;
291324
}
292325

326+
modelbox::Status AtcInference::GetCurrentBatchSize(
327+
std::shared_ptr<modelbox::DataContext> &data_ctx, size_t &batch_size) {
328+
if (model_input_list_.size() == 0) {
329+
MBLOG_ERROR << "model_input_list_ is empty ";
330+
return modelbox::STATUS_FAULT;
331+
}
332+
333+
auto buffer_list = data_ctx->Input()->at(model_input_list_[0]);
334+
if (buffer_list == nullptr) {
335+
MBLOG_ERROR << "get current batch size failed ";
336+
return modelbox::STATUS_FAULT;
337+
}
338+
339+
if ((dynamic_batch_tensor_index_ >= 0) &&
340+
(dynamic_batch_set_.find(buffer_list->Size()) ==
341+
dynamic_batch_set_.end())) {
342+
MBLOG_ERROR << "current model is not support input batch_size: "
343+
<< buffer_list->Size();
344+
return modelbox::STATUS_FAULT;
345+
}
346+
347+
batch_size = buffer_list->Size();
348+
return modelbox::STATUS_OK;
349+
}
350+
293351
modelbox::Status AtcInference::Infer(
294352
std::shared_ptr<modelbox::DataContext> &data_ctx, aclrtStream stream) {
295353
auto acl_ret = aclrtSetDevice(device_id_);
@@ -299,24 +357,43 @@ modelbox::Status AtcInference::Infer(
299357
return {modelbox::STATUS_FAULT, "Set device failed"};
300358
}
301359

302-
auto input = CreateDataSet(data_ctx->Input(), model_input_list_);
360+
size_t current_batch_size;
361+
auto ret = GetCurrentBatchSize(data_ctx, current_batch_size);
362+
if (ret != modelbox::STATUS_SUCCESS) {
363+
MBLOG_ERROR << "get current batch size failed";
364+
return {modelbox::STATUS_FAULT, "Get current batch size failed"};
365+
}
366+
367+
auto input =
368+
CreateDataSet(data_ctx->Input(), model_input_list_, current_batch_size);
303369
if (input == nullptr) {
304370
MBLOG_ERROR << "Create input for infer failed";
305371
return {modelbox::STATUS_FAULT, "Create input failed"};
306372
}
307373

308-
auto ret = PrepareOutput(data_ctx);
374+
ret = PrepareOutput(data_ctx, current_batch_size);
309375
if (ret != modelbox::STATUS_SUCCESS) {
310376
MBLOG_ERROR << "Prepare output failed";
311377
return {modelbox::STATUS_FAULT, "Prepare output failed"};
312378
}
313379

314-
auto output = CreateDataSet(data_ctx->Output(), model_output_list_);
380+
auto output =
381+
CreateDataSet(data_ctx->Output(), model_output_list_, current_batch_size);
315382
if (output == nullptr) {
316383
MBLOG_ERROR << "Create output for infer failed";
317384
return {modelbox::STATUS_FAULT, "Create output failed"};
318385
}
319386

387+
if (dynamic_batch_tensor_index_ >= 0) {
388+
acl_ret = aclmdlSetDynamicBatchSize(model_id_, input.get(),
389+
dynamic_batch_tensor_index_,
390+
current_batch_size);
391+
if (acl_ret != ACL_ERROR_NONE) {
392+
MBLOG_ERROR << "aclmdlSetDynamicBatchSize failed, ret " << acl_ret;
393+
return {modelbox::STATUS_FAULT, "Execute acl set batch_size failed"};
394+
}
395+
}
396+
320397
acl_ret = ACL_ERROR_NONE;
321398
if (stream == nullptr) {
322399
acl_ret = aclmdlExecute(model_id_, input.get(), output.get());
@@ -334,13 +411,14 @@ modelbox::Status AtcInference::Infer(
334411
}
335412

336413
modelbox::Status AtcInference::PrepareOutput(
337-
std::shared_ptr<modelbox::DataContext> &data_ctx) {
414+
std::shared_ptr<modelbox::DataContext> &data_ctx,
415+
const size_t &current_batch_size) {
338416
auto output_count = model_output_list_.size();
339417
for (size_t i = 0; i < output_count; ++i) {
340418
auto &name = model_output_list_[i];
341419
auto buffer_list = data_ctx->Output(name);
342420
auto &size = model_output_size_[i];
343-
std::vector<size_t> shape(batch_size_, size);
421+
std::vector<size_t> shape(current_batch_size, size);
344422
buffer_list->Build(shape);
345423
buffer_list->Set("shape", output_shape_[i]);
346424
buffer_list->Set("type", output_data_type_[i]);
@@ -351,7 +429,7 @@ modelbox::Status AtcInference::PrepareOutput(
351429

352430
std::shared_ptr<aclmdlDataset> AtcInference::CreateDataSet(
353431
const std::shared_ptr<modelbox::BufferListMap> &buffer_list_map,
354-
std::vector<std::string> &name_list) {
432+
std::vector<std::string> &name_list, const size_t &current_batch_size) {
355433
auto *data_set_ptr = aclmdlCreateDataset();
356434
if (data_set_ptr == nullptr) {
357435
MBLOG_ERROR << "aclmdlCreateDataset return null";
@@ -368,15 +446,32 @@ std::shared_ptr<aclmdlDataset> AtcInference::CreateDataSet(
368446
});
369447

370448
for (auto &tensor_name : name_list) {
371-
auto buffer_list = buffer_list_map->at(tensor_name);
372-
if (buffer_list == nullptr) {
373-
MBLOG_ERROR << "Create data set for tensor " << tensor_name
374-
<< " failed, buffer list is null";
375-
return nullptr;
449+
void *mem_ptr = nullptr;
450+
size_t size;
451+
if (tensor_name != ACL_DYNAMIC_TENSOR_NAME) {
452+
auto buffer_list = buffer_list_map->at(tensor_name);
453+
if (buffer_list == nullptr) {
454+
MBLOG_ERROR << "Create data set for tensor " << tensor_name
455+
<< " failed, buffer list is null";
456+
return nullptr;
457+
}
458+
459+
if (current_batch_size != buffer_list->Size()) {
460+
MBLOG_ERROR << "buffer bacth_size is not same, first bacth_size: "
461+
<< current_batch_size
462+
<< " , current tensor: " << tensor_name
463+
<< " bacth_size:" << buffer_list->Size();
464+
return nullptr;
465+
}
466+
467+
mem_ptr = const_cast<void *>(buffer_list->ConstData());
468+
size = buffer_list->GetBytes();
469+
} else {
470+
size = model_input_size_[dynamic_batch_tensor_index_];
471+
mem_ptr = dynamic_batch_mem_ptr_;
376472
}
377473

378-
auto *data_buffer = aclCreateDataBuffer(
379-
const_cast<void *>(buffer_list->ConstData()), buffer_list->GetBytes());
474+
auto *data_buffer = aclCreateDataBuffer(mem_ptr, size);
380475
if (data_buffer == nullptr) {
381476
MBLOG_ERROR << "Create data set buffer for tensor " << tensor_name
382477
<< "failed";
@@ -411,5 +506,9 @@ modelbox::Status AtcInference::Deinit() {
411506
return modelbox::STATUS_FAULT;
412507
}
413508

509+
if (dynamic_batch_mem_ptr_ != nullptr) {
510+
aclrtFree(dynamic_batch_mem_ptr_);
511+
}
512+
414513
return modelbox::STATUS_SUCCESS;
415514
}

src/drivers/devices/ascend/flowunit/inference/atc_inference.h

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,8 @@ class AtcInference {
5656

5757
void SaveOutputShape(const aclmdlIODims &dims);
5858

59-
void LogBatchInfo(aclmdlDesc *desc_ptr, std::stringstream &model_info);
59+
void SaveBatchInfo(aclmdlDesc *desc_ptr, std::stringstream &model_info,
60+
size_t &max_batch_size);
6061

6162
void LogTensorInfo(aclmdlDesc *desc_ptr, size_t index, aclmdlIODims &dims,
6263
size_t size, aclFormat format, aclDataType data_type,
@@ -69,15 +70,21 @@ class AtcInference {
6970
modelbox::ModelBoxDataType GetModelBoxDataType(aclDataType data_type);
7071

7172
modelbox::Status PrepareOutput(
72-
std::shared_ptr<modelbox::DataContext> &data_ctx);
73+
std::shared_ptr<modelbox::DataContext> &data_ctx,
74+
const size_t &current_batch_size);
7375

7476
std::shared_ptr<aclmdlDataset> CreateDataSet(
7577
const std::shared_ptr<modelbox::BufferListMap> &buffer_list_map,
76-
std::vector<std::string> &name_list);
78+
std::vector<std::string> &name_list, const size_t &current_batch_size);
79+
80+
modelbox::Status GetCurrentBatchSize(
81+
std::shared_ptr<modelbox::DataContext> &data_ctx, size_t &batch_size);
7782

7883
int32_t device_id_{0};
7984
std::string model_file_;
80-
int32_t batch_size_{1};
85+
int32_t dynamic_batch_tensor_index_{-1};
86+
void *dynamic_batch_mem_ptr_{nullptr};
87+
std::set<size_t> dynamic_batch_set_;
8188

8289
uint32_t model_id_{0};
8390
bool is_model_load_{false};

0 commit comments

Comments
 (0)