1
1
# coding=utf-8
2
2
import threading
3
- from typing import Dict
3
+ from typing import Dict , Optional , List , Any
4
4
5
5
from langchain_community .embeddings import XinferenceEmbeddings
6
+ from langchain_core .embeddings import Embeddings
6
7
7
8
from setting .models_provider .base_model_provider import MaxKBBaseModel
8
9
9
10
10
- class XinferenceEmbedding (MaxKBBaseModel , XinferenceEmbeddings ):
11
+ class XinferenceEmbedding (MaxKBBaseModel , Embeddings ):
12
+ client : Any
13
+ server_url : Optional [str ]
14
+ """URL of the xinference server"""
15
+ model_uid : Optional [str ]
16
+ """UID of the launched model"""
17
+
11
18
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
    """Build an XinferenceEmbedding from a model name and its credentials.

    Args:
        model_type: Not used by this factory.
        model_name: Passed through as the ``model_uid`` of the new instance.
        model_credential: Mapping read for ``'api_base'`` (server URL) and
            ``'api_key'``; a missing key yields ``None``.
        **model_kwargs: Accepted but not used here.

    Returns:
        A new ``XinferenceEmbedding`` instance.
    """
    base_url = model_credential.get('api_base')
    secret = model_credential.get('api_key')
    return XinferenceEmbedding(
        model_uid=model_name,
        server_url=base_url,
        api_key=secret,
    )
17
25
18
26
def down_model (self ):
@@ -22,3 +30,63 @@ def start_down_model_thread(self):
22
30
thread = threading .Thread (target = self .down_model )
23
31
thread .daemon = True
24
32
thread .start ()
33
+
34
def __init__(
    self,
    server_url: Optional[str] = None,
    model_uid: Optional[str] = None,
    api_key: Optional[str] = None,
):
    """Create the embedding wrapper and its Xinference REST client.

    Args:
        server_url: URL of the xinference server; must not be ``None``.
        model_uid: UID of the launched model; must not be ``None``.
        api_key: Optional key forwarded to ``RESTfulClient``.

    Raises:
        ImportError: If ``RESTfulClient`` can be imported from neither
            ``xinference.client`` nor ``xinference_client``.
        ValueError: If ``server_url`` or ``model_uid`` is ``None``.
    """
    # Try the full package first, then fall back to the client-only package.
    try:
        from xinference.client import RESTfulClient
    except ImportError:
        try:
            from xinference_client import RESTfulClient
        except ImportError as import_error:
            message = (
                "Could not import RESTfulClient from xinference. Please install it"
                " with `pip install xinference` or `pip install xinference_client`."
            )
            raise ImportError(message) from import_error

    # Arguments are validated only after the import succeeded, server URL first.
    if server_url is None:
        raise ValueError("Please provide server URL")
    if model_uid is None:
        raise ValueError("Please provide the model UID")

    self.server_url, self.model_uid, self.api_key = server_url, model_uid, api_key
    # The client is created last, once the plain attributes are in place.
    self.client = RESTfulClient(server_url, api_key)
62
+
63
def embed_documents(self, texts: List[str]) -> List[List[float]]:
    """Embed a list of documents using Xinference.

    The model handle is fetched once via ``self.client.get_model`` and then
    ``create_embedding`` is called once per text.

    Args:
        texts: The list of texts to embed.

    Returns:
        List of embeddings, one for each text, with every component
        converted to ``float``.
    """
    handle = self.client.get_model(self.model_uid)

    # First collect the raw vectors for every text ...
    raw_vectors = []
    for document in texts:
        response = handle.create_embedding(document)
        raw_vectors.append(response["data"][0]["embedding"])

    # ... then coerce each component to float.
    return [[float(component) for component in vector] for vector in raw_vectors]
77
+
78
def embed_query(self, text: str) -> List[float]:
    """Embed a single query text using Xinference.

    Args:
        text: The text to embed.

    Returns:
        The embedding for the text, with every component converted to
        ``float``.
    """
    handle = self.client.get_model(self.model_uid)
    first_item = handle.create_embedding(text)["data"][0]
    return [float(component) for component in first_item["embedding"]]
0 commit comments