Skip to content

Commit fc015d1

Browse files
committed
fix: 修复xinference向量模型添加失败的缺陷
1 parent f30d3d7 commit fc015d1

File tree

2 files changed

+73
-3
lines changed

2 files changed

+73
-3
lines changed

apps/setting/models_provider/impl/xinference_model_provider/credential/embedding.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
1515
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
1616
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
1717
try:
18-
model_list = provider.get_base_model_list(model_credential.get('api_base'), 'embedding')
18+
model_list = provider.get_base_model_list(model_credential.get('api_base'), model_credential.get('api_key'),
19+
'embedding')
1920
except Exception as e:
2021
raise AppApiException(ValidCode.valid_error.value, "API 域名无效")
2122
exist = provider.get_model_info_by_name(model_list, model_name)
@@ -36,3 +37,4 @@ def build_model(self, model_info: Dict[str, object]):
3637
return self
3738

3839
api_base = forms.TextInputField('API 域名', required=True)
40+
api_key = forms.PasswordInputField('API Key', required=True)
Lines changed: 70 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,26 @@
11
# coding=utf-8
22
import threading
3-
from typing import Dict
3+
from typing import Dict, Optional, List, Any
44

55
from langchain_community.embeddings import XinferenceEmbeddings
6+
from langchain_core.embeddings import Embeddings
67

78
from setting.models_provider.base_model_provider import MaxKBBaseModel
89

910

10-
class XinferenceEmbedding(MaxKBBaseModel, XinferenceEmbeddings):
11+
class XinferenceEmbedding(MaxKBBaseModel, Embeddings):
12+
client: Any
13+
server_url: Optional[str]
14+
"""URL of the xinference server"""
15+
model_uid: Optional[str]
16+
"""UID of the launched model"""
17+
1118
@staticmethod
1219
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
1320
return XinferenceEmbedding(
1421
model_uid=model_name,
1522
server_url=model_credential.get('api_base'),
23+
api_key=model_credential.get('api_key'),
1624
)
1725

1826
def down_model(self):
@@ -22,3 +30,63 @@ def start_down_model_thread(self):
2230
thread = threading.Thread(target=self.down_model)
2331
thread.daemon = True
2432
thread.start()
33+
34+
def __init__(
35+
self, server_url: Optional[str] = None, model_uid: Optional[str] = None,
36+
api_key: Optional[str] = None
37+
):
38+
try:
39+
from xinference.client import RESTfulClient
40+
except ImportError:
41+
try:
42+
from xinference_client import RESTfulClient
43+
except ImportError as e:
44+
raise ImportError(
45+
"Could not import RESTfulClient from xinference. Please install it"
46+
" with `pip install xinference` or `pip install xinference_client`."
47+
) from e
48+
49+
if server_url is None:
50+
raise ValueError("Please provide server URL")
51+
52+
if model_uid is None:
53+
raise ValueError("Please provide the model UID")
54+
55+
self.server_url = server_url
56+
57+
self.model_uid = model_uid
58+
59+
self.api_key = api_key
60+
61+
self.client = RESTfulClient(server_url, api_key)
62+
63+
def embed_documents(self, texts: List[str]) -> List[List[float]]:
64+
"""Embed a list of documents using Xinference.
65+
Args:
66+
texts: The list of texts to embed.
67+
Returns:
68+
List of embeddings, one for each text.
69+
"""
70+
71+
model = self.client.get_model(self.model_uid)
72+
73+
embeddings = [
74+
model.create_embedding(text)["data"][0]["embedding"] for text in texts
75+
]
76+
return [list(map(float, e)) for e in embeddings]
77+
78+
def embed_query(self, text: str) -> List[float]:
79+
"""Embed a query of documents using Xinference.
80+
Args:
81+
text: The text to embed.
82+
Returns:
83+
Embeddings for the text.
84+
"""
85+
86+
model = self.client.get_model(self.model_uid)
87+
88+
embedding_res = model.create_embedding(text)
89+
90+
embedding = embedding_res["data"][0]["embedding"]
91+
92+
return list(map(float, embedding))

0 commit comments

Comments
 (0)