
Does FlagEmbedding natively support high-concurrency vectorization? #1536

@Stefan3Zz

Description


When using FlagEmbedding with the bge-m3 model for vectorization, is there a best practice for high-concurrency scenarios? I have wrapped FlagEmbedding in a service of my own, and under multithreading I run into strange problems. Here is my code:

```python
import os
import traceback
from concurrent.futures import ThreadPoolExecutor

from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from fastapi.encoders import jsonable_encoder
import uvicorn
import asyncio

from log.log_info import Logging
from req.text_input import TextInput, RerankInput, OcrInput, ModelInput

# FastAPI app instance (the route decorators below require it)
app = FastAPI()

# Thread pool executor used to handle requests
executor = ThreadPoolExecutor(max_workers=10)

# Generic helper that runs a synchronous task on the thread pool
async def run_in_threadpool(func, *args):
    loop = asyncio.get_event_loop()
    return await loop.run_in_executor(executor, func, *args)

device_type = os.environ.get("DEVICE_TYPE", None)
target_devices = os.environ.get("TARGET_DEVICES", None)
logger = Logging("app").get_logger()

try:
    from services.key_bert_service import KeyBertService
    from services.bert_doc_segmentation import BertDocSegmentation
    from services.trie_service import TrieService
    from services.en_words_segmentation import ENWordsSegmentation

    bert_doc_segmentation = BertDocSegmentation()
    key_bert_service = KeyBertService()
    trie_service = TrieService()
    en_segment = ENWordsSegmentation()
    logger.info("KeyBertService, BertDocSegmentation, TrieService and ENWordsSegmentation initialized successfully")

except ImportError:
    logger.error("Exception keyBertService: %s", traceback.format_exc())
    logger.warn(
        "Warning: KeyBertService, BertDocSegmentation could not be imported. The feature requiring this class may not "
        "work.")

try:
    from services.bge_m3_embedding_service import BgeM3EmbeddingService

    bge_m3_embedding_service = BgeM3EmbeddingService(device_type)
    logger.info("BgeM3EmbeddingService initialized successfully")

except ImportError:
    # Handle the failure case
    logger.error("Exception occurred: %s", traceback.format_exc())
    logger.warn(
        "Warning: BgeM3EmbeddingService could not be imported. The feature requiring this class may not work.")

# API routes and handlers

@app.post("/v1/extract_keywords")
async def extract_keywords(input: TextInput):
    try:
        # rs = key_bert_service.get_keywords(input.input, input.model, input.top_n)
        rs = await run_in_threadpool(key_bert_service.get_keywords, input.input, input.model, input.top_n)
        return resp(rs)
    except Exception as e:
        # Handle the failure case
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/v1/bge/embeddings")
async def bge_embeddings(input: TextInput):
    try:
        rs = await run_in_threadpool(bge_m3_embedding_service.embeddings, input.input, input.model)
        return resp(rs)
    except Exception as e:
        logger.error("Exception occurred: %s", traceback.format_exc())
        raise HTTPException(status_code=500, detail=str(e))

def resp(data, code=200, msg="OK"):
    return {
        "status": code,
        "message": msg,
        "data": data
    }

def resp_json(data, code=200, msg="OK"):
    return JSONResponse(content=jsonable_encoder({
        "status": code,
        "message": msg,
        "data": data
    }))

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=5001, loop="asyncio")
```

The above is the API-layer code.
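For reference, the kind of client that triggers the problem looks roughly like this (a minimal sketch: the model name, request count, and concurrency level are placeholders, and the payload fields mirror how the endpoint reads `TextInput`):

```python
import concurrent.futures

import requests  # test-only dependency, not part of the service

URL = "http://localhost:5001/v1/bge/embeddings"

def call_once(i: int) -> bool:
    # Payload fields correspond to input.input and input.model in the endpoint;
    # "bge-m3" is a placeholder for whatever key InferenceConfig.models uses.
    payload = {"input": f"test sentence {i}", "model": "bge-m3"}
    r = requests.post(URL, json=payload, timeout=60)
    return r.status_code == 200

# Keep many requests in flight at once; with enough concurrency the service
# starts returning 500s once the meta-tensor error described below is triggered.
with concurrent.futures.ThreadPoolExecutor(max_workers=50) as pool:
    results = list(pool.map(call_once, range(200)))

print(f"{sum(results)}/{len(results)} requests succeeded")
```

The service layer that this endpoint delegates to is below.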

```python
import os
from typing import List, Union

from config.app_config import InferenceConfig
from log.log_info import Logging
from FlagEmbedding import BGEM3FlagModel

logger = Logging("embedding").get_logger()

try:
    from torch_npu.contrib import transfer_to_npu
    import torch_npu
except ImportError:
    logger.warn("Warning: Embedding NPU is not available")

model_path = None
if "EMBEDDING_PATH" in os.environ:
    model_path = os.environ["EMBEDDING_PATH"]

class BgeM3EmbeddingService:
    def __init__(self, device_type: str):
        self.device_type = device_type
        self.models = {}

    def embeddings(self, query: Union[str, List[str]], model_name: str):
        embed_model = self._load_model(model_name)
        output = embed_model.encode(query, return_dense=True, return_sparse=True, return_colbert_vecs=False)
        dense_vec = output['dense_vecs'].tolist()
        sparse_vec = output['lexical_weights']
        if isinstance(query, str):
            return {
                'dense': dense_vec,
                'sparse': {k: float(v) for k, v in sparse_vec.items()}
            }
        else:
            sparse_list = []
            for sparse in sparse_vec:
                sparse_list.append({k: float(v) for k, v in sparse.items()})
            return {
                'dense': dense_vec,
                'sparse': sparse_list
            }

    def _load_model(self, model_name: str):
        if model_name not in self.models:
            logger.info(f"Loading model: {model_name}")
            # Load the model lazily on first use
            final_model_path = model_path
            if final_model_path is None:
                final_model_path = InferenceConfig.models[model_name]
            if self.device_type is None:
                model = BGEM3FlagModel(final_model_path, low_cpu_mem_usage=False)
            else:
                devices = self.device_type.split(",")
                model = BGEM3FlagModel(final_model_path, devices=devices, low_cpu_mem_usage=False)
            self.models[model_name] = model
        return self.models[model_name]
```

In this FlagEmbedding-layer code you can see that my thread pool size is 10. With that setup, as soon as the concurrency gets high I run into

NotImplementedError: Cannot copy out of meta tensor; no data! Please use torch.nn.Module.to_empty() … when moving module from meta …

and once that error has occurred, every subsequent call fails as well. GPU memory is nowhere near saturated (less than half used) on an A100-series card. Could someone help me look into this?
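One suspicion I have (not confirmed) is that the lazy `_load_model` can race: two threads that both miss the `self.models` cache will construct the same model concurrently, which might plausibly leave a module stuck on the meta device. A minimal sketch of the guard I am considering, using double-checked locking (the names `model_path`, `InferenceConfig`, and `BGEM3FlagModel` come from my service module above):

```python
import threading

class BgeM3EmbeddingService:
    def __init__(self, device_type: str):
        self.device_type = device_type
        self.models = {}
        self._load_lock = threading.Lock()  # serializes lazy model construction

    def _load_model(self, model_name: str):
        # Double-checked locking: skip the lock entirely once the model is
        # cached; the lock only serializes concurrent first requests.
        if model_name not in self.models:
            with self._load_lock:
                if model_name not in self.models:
                    final_model_path = model_path
                    if final_model_path is None:
                        final_model_path = InferenceConfig.models[model_name]
                    if self.device_type is None:
                        model = BGEM3FlagModel(final_model_path, low_cpu_mem_usage=False)
                    else:
                        model = BGEM3FlagModel(final_model_path,
                                               devices=self.device_type.split(","),
                                               low_cpu_mem_usage=False)
                    self.models[model_name] = model
        return self.models[model_name]
```

I am not sure whether this alone explains the meta-tensor error, or whether `encode` itself also needs to be serialized per model, so any guidance on the intended concurrency model for `BGEM3FlagModel` would be appreciated.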
