
Commit 6ef8fc3

hzhyhx1117 authored and pymumu committed
Torch: add load model mutex
Signed-off-by: howe <hzhyhx1117@163.com>
1 parent eed8570 commit 6ef8fc3

File tree

1 file changed: +4 −0 lines

src/drivers/devices/cuda/flowunit/torch/torch_inference_flowunit.cc

Lines changed: 4 additions & 0 deletions
@@ -21,11 +21,14 @@
 #include <modelbox/base/crypto.h>
 
 #include <fstream>
+#include <mutex>
 
 #include "modelbox/device/cuda/device_cuda.h"
 #include "modelbox/type.h"
 #include "virtualdriver_inference.h"
 
+static std::mutex torch_load_mutex;
+
 static std::map<std::string, c10::ScalarType> type_map = {
     {"FLOAT", torch::kFloat32}, {"DOUBLE", torch::kFloat64},
     {"INT", torch::kInt32}, {"UINT8", torch::kUInt8},
@@ -68,6 +71,7 @@ void TorchInferenceFlowUnit::FillOutput(
 modelbox::Status TorchInferenceFlowUnit::LoadModel(
     const std::string &model_path,
     const std::shared_ptr<modelbox::Configuration> &config) {
+  std::lock_guard<std::mutex> lck(torch_load_mutex);
   try {
     MBLOG_DEBUG << "model_path: " << model_path;
     auto drivers_ptr = GetBindDevice()->GetDeviceManager()->GetDrivers();
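For readers unfamiliar with the pattern, the sketch below shows in isolation what this change does: a file-scope std::mutex plus a std::lock_guard at the top of the load routine serializes concurrent model loads. The LoadModel body and the thread driver are hypothetical stand-ins for illustration; they are not ModelBox or LibTorch code.

// Hypothetical standalone example: serialize model loading with a
// process-wide mutex, mirroring torch_load_mutex in the flowunit.
#include <iostream>
#include <mutex>
#include <string>
#include <thread>
#include <vector>

static std::mutex load_mutex;

// Stand-in for TorchInferenceFlowUnit::LoadModel; the body is illustrative.
void LoadModel(const std::string &model_path) {
  // Only one thread runs the load at a time; the guard unlocks
  // automatically when it leaves scope, even if the load throws.
  std::lock_guard<std::mutex> lck(load_mutex);
  std::cout << "loading " << model_path << std::endl;
}

int main() {
  // Threads loading models concurrently now queue up on the mutex
  // instead of entering the loader at the same time.
  std::vector<std::thread> workers;
  for (int i = 0; i < 4; ++i) {
    workers.emplace_back(LoadModel, "model_" + std::to_string(i) + ".pt");
  }
  for (auto &t : workers) {
    t.join();
  }
  return 0;
}

Using std::lock_guard rather than manual lock()/unlock() means the mutex is released when the guard goes out of scope, including when the try block inside the real LoadModel throws.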

0 commit comments