-
Notifications
You must be signed in to change notification settings - Fork 800
Description
请问使用flagEmbedding做向量化的时候,使用bge-m3模型,在高并发下有最佳实践可以参考吗,目前我自己包装了一个使用flagembedding服务来做向量化,在使用多线程的情况下遇到奇奇怪怪的问题,以下是我的代码
`
import os
import traceback
from concurrent.futures import ThreadPoolExecutor
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from fastapi.encoders import jsonable_encoder
import uvicorn
import asyncio
from log.log_info import Logging
from req.text_input import TextInput, RerankInput, OcrInput, ModelInput
创建一个线程池执行器,用于处理请求
executor = ThreadPoolExecutor(max_workers=10)
通用函数,用于使用线程池执行同步任务
async def run_in_threadpool(func, *args):
loop = asyncio.get_event_loop()
return await loop.run_in_executor(executor, func, *args)
device_type = os.environ.get("DEVICE_TYPE", None)
target_devices = os.environ.get("TARGET_DEVICES", None)
logger = Logging("app").get_logger()
try:
from services.key_bert_service import KeyBertService
from services.bert_doc_segmentation import BertDocSegmentation
from services.trie_service import TrieService
from services.en_words_segmentation import ENWordsSegmentation
bert_doc_segmentation = BertDocSegmentation()
key_bert_service = KeyBertService()
trie_service = TrieService()
en_segment = ENWordsSegmentation()
logger.info("KeyBertService, BertDocSegmentation, TrieService and ENWordsSegmentation initialized successfully")
except ImportError:
logger.error("Exception keyBertService: %s", traceback.format_exc())
logger.warn(
"Warning: KeyBertService, BertDocSegmentation could not be imported. The feature requiring this class may not "
"work.")
try:
from services.bge_m3_embedding_service import BgeM3EmbeddingService
bge_m3_embedding_service = BgeM3EmbeddingService(device_type)
logger.info("BgeM3EmbeddingService initialized successfully")
except ImportError:
# 处理异常情况
logger.error("Exception occurred: %s", traceback.format_exc())
logger.warn(
"Warning: BgeM3EmbeddingService could not be imported. The feature requiring this class may not work.")
定义 API 路由和逻辑
@app.post("/v1/extract_keywords")
async def extract_keywords(input: TextInput):
try:
# rs = key_bert_service.get_keywords(input.input, input.model, input.top_n)
rs = await run_in_threadpool(key_bert_service.get_keywords, input.input, input.model, input.top_n)
return resp(rs)
except Exception as e:
# 处理异常情况
raise HTTPException(status_code=500, detail=str(e))
@app.post("/v1/bge/embeddings")
async def trie_reload(input: TextInput):
try:
# trie_service.reload(input.model)
rs = await run_in_threadpool(bge_m3_embedding_service.embeddings, input.input, input.model)
return resp(rs)
except Exception as e:
logger.error("Exception occurred: %s", traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
def resp(data, code=200, msg="OK"):
return {
"status": code,
"message": msg,
"data": data
}
def resp_json(data, code=200, msg="OK"):
return JSONResponse(content=jsonable_encoder({
"status": code,
"message": msg,
"data": data
}))
if name == "main":
uvicorn.run(app, host="0.0.0.0", port=5001, loop="asyncio")`
以上是接口层的代码
`
import os
from typing import List, Union
from config.app_config import InferenceConfig
from log.log_info import Logging
from FlagEmbedding import BGEM3FlagModel
logger = Logging("embedding").get_logger()
try:
from torch_npu.contrib import transfer_to_npu
import torch_npu
except ImportError:
logger.warn("Warning: Embedding Npu IS not available")
model_path = None
if "EMBEDDING_PATH" in os.environ:
model_path = os.environ["EMBEDDING_PATH"]
class BgeM3EmbeddingService:
def init(self, deviceType: str):
self.deviceType = deviceType
self.models = {}
def embeddings(self, query: Union[str, List[str]], model_name: str):
embed_model = self._load_model(model_name)
output = embed_model.encode(query, return_dense=True, return_sparse=True, return_colbert_vecs=False)
dense_vec = output['dense_vecs'].tolist()
sparse_vec = output['lexical_weights']
if isinstance(query, str):
return {
'dense': dense_vec,
'sparse': {k: float(v) for k, v in sparse_vec.items()}
}
else:
sparse_list = []
for sparse in sparse_vec:
sparse_list.append({k: float(v) for k, v in sparse.items()})
return {
'dense': dense_vec,
'sparse': sparse_list
}
def _load_model(self, model_name: str):
if model_name not in self.models:
logger.info(f"Loading model: {model_name}")
# 加载 模型
final_model_path = model_path
if final_model_path is None:
final_model_path = InferenceConfig.models[model_name]
if self.deviceType is None:
model = BGEM3FlagModel(final_model_path,low_cpu_mem_usage=False)
else:
devices=self.deviceType.split(",")
model = BGEM3FlagModel(final_model_path, devices=devices,low_cpu_mem_usage=False)
self.models[model_name] = model
return self.models[model_name]
`
这块调用flagembedding层的代码,可以看出来我的线程数是10,在这个情况下,并发量一旦很高时候就会遇到
NotImplementedError: Cannot copy out of meta tensor; no data! Please use torch.nn.Module.to_empty() … when moving module from meta …
上述的报错,导致后续所有调用都会报错,而且我看gpu显存占用并没有打满,只用了一半不到,使用的是a100系列的显卡,
请问可以帮忙看下这个问题吗