Skip to content

Commit ddad340

Browse files
committed
feat: 讯飞图片模型
1 parent f318f2d commit ddad340

File tree

9 files changed

+240
-4
lines changed

9 files changed

+240
-4
lines changed

apps/setting/models_provider/base_model_provider.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ class ModelTypeConst(Enum):
149149
EMBEDDING = {'code': 'EMBEDDING', 'message': '向量模型'}
150150
STT = {'code': 'STT', 'message': '语音识别'}
151151
TTS = {'code': 'TTS', 'message': '语音合成'}
152+
IMAGE = {'code': 'IMAGE', 'message': '图片理解'}
152153
RERANKER = {'code': 'RERANKER', 'message': '重排模型'}
153154

154155

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# coding=utf-8
2+
from abc import abstractmethod
3+
4+
from pydantic import BaseModel
5+
6+
7+
class BaseImage(BaseModel):
8+
@abstractmethod
9+
def check_auth(self):
10+
pass
11+
12+
@abstractmethod
13+
def image_understand(self, image_file, text):
14+
pass
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# coding=utf-8
2+
3+
from typing import Dict
4+
5+
from common import forms
6+
from common.exception.app_exception import AppApiException
7+
from common.forms import BaseForm
8+
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
9+
10+
11+
class XunFeiImageModelCredential(BaseForm, BaseModelCredential):
12+
spark_api_url = forms.TextInputField('API 域名', required=True, default_value='wss://spark-api.cn-huabei-1.xf-yun.com/v2.1/image')
13+
spark_app_id = forms.TextInputField('APP ID', required=True)
14+
spark_api_key = forms.PasswordInputField("API Key", required=True)
15+
spark_api_secret = forms.PasswordInputField('API Secret', required=True)
16+
17+
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
18+
raise_exception=False):
19+
model_type_list = provider.get_model_type_list()
20+
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
21+
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
22+
23+
for key in ['spark_api_url', 'spark_app_id', 'spark_api_key', 'spark_api_secret']:
24+
if key not in model_credential:
25+
if raise_exception:
26+
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
27+
else:
28+
return False
29+
try:
30+
model = provider.get_model(model_type, model_name, model_credential)
31+
model.check_auth()
32+
except Exception as e:
33+
if isinstance(e, AppApiException):
34+
raise e
35+
if raise_exception:
36+
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
37+
else:
38+
return False
39+
return True
40+
41+
def encryption_dict(self, model: Dict[str, object]):
42+
return {**model, 'spark_api_secret': super().encryption(model.get('spark_api_secret', ''))}
43+
44+
45+
def get_model_params_setting_form(self, model_name):
46+
pass
Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
# coding=utf-8
2+
3+
import asyncio
4+
import base64
5+
import datetime
6+
import hashlib
7+
import hmac
8+
import json
9+
import os
10+
import ssl
11+
from datetime import datetime, UTC
12+
from typing import Dict
13+
from urllib.parse import urlencode
14+
from urllib.parse import urlparse
15+
16+
import websockets
17+
18+
from setting.models_provider.base_model_provider import MaxKBBaseModel
19+
from setting.models_provider.impl.base_image import BaseImage
20+
21+
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
22+
ssl_context.check_hostname = False
23+
ssl_context.verify_mode = ssl.CERT_NONE
24+
25+
26+
class XFSparkImage(MaxKBBaseModel, BaseImage):
27+
spark_app_id: str
28+
spark_api_key: str
29+
spark_api_secret: str
30+
spark_api_url: str
31+
params: dict
32+
33+
# 初始化
34+
def __init__(self, **kwargs):
35+
super().__init__(**kwargs)
36+
self.spark_api_url = kwargs.get('spark_api_url')
37+
self.spark_app_id = kwargs.get('spark_app_id')
38+
self.spark_api_key = kwargs.get('spark_api_key')
39+
self.spark_api_secret = kwargs.get('spark_api_secret')
40+
self.params = kwargs.get('params')
41+
42+
@staticmethod
43+
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
44+
optional_params = {'params': {}}
45+
for key, value in model_kwargs.items():
46+
if key not in ['model_id', 'use_local', 'streaming']:
47+
optional_params['params'][key] = value
48+
return XFSparkImage(
49+
spark_app_id=model_credential.get('spark_app_id'),
50+
spark_api_key=model_credential.get('spark_api_key'),
51+
spark_api_secret=model_credential.get('spark_api_secret'),
52+
spark_api_url=model_credential.get('spark_api_url'),
53+
**optional_params
54+
)
55+
56+
def create_url(self):
57+
url = self.spark_api_url
58+
host = urlparse(url).hostname
59+
# 生成RFC1123格式的时间戳
60+
gmt_format = '%a, %d %b %Y %H:%M:%S GMT'
61+
date = datetime.now(UTC).strftime(gmt_format)
62+
63+
# 拼接字符串
64+
signature_origin = "host: " + host + "\n"
65+
signature_origin += "date: " + date + "\n"
66+
signature_origin += "GET " + "/v2.1/image " + "HTTP/1.1"
67+
# 进行hmac-sha256进行加密
68+
signature_sha = hmac.new(self.spark_api_secret.encode('utf-8'), signature_origin.encode('utf-8'),
69+
digestmod=hashlib.sha256).digest()
70+
signature_sha = base64.b64encode(signature_sha).decode(encoding='utf-8')
71+
72+
authorization_origin = "api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"" % (
73+
self.spark_api_key, "hmac-sha256", "host date request-line", signature_sha)
74+
authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
75+
# 将请求的鉴权参数组合为字典
76+
v = {
77+
"authorization": authorization,
78+
"date": date,
79+
"host": host
80+
}
81+
# 拼接鉴权参数,生成url
82+
url = url + '?' + urlencode(v)
83+
# print("date: ",date)
84+
# print("v: ",v)
85+
# 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致
86+
# print('websocket url :', url)
87+
return url
88+
89+
def check_auth(self):
90+
cwd = os.path.dirname(os.path.abspath(__file__))
91+
with open(f'{cwd}/img_1.png', 'rb') as f:
92+
self.image_understand(f,"一句话概述这个图片")
93+
94+
def image_understand(self, image_file, question):
95+
async def handle():
96+
async with websockets.connect(self.create_url(), max_size=1000000000, ssl=ssl_context) as ws:
97+
# 发送 full client request
98+
await self.send(ws, image_file, question)
99+
return await self.handle_message(ws)
100+
101+
return asyncio.run(handle())
102+
103+
# 收到websocket消息的处理
104+
@staticmethod
105+
async def handle_message(ws):
106+
# print(message)
107+
answer = ''
108+
while True:
109+
res = await ws.recv()
110+
data = json.loads(res)
111+
code = data['header']['code']
112+
if code != 0:
113+
return f'请求错误: {code}, {data}'
114+
else:
115+
choices = data["payload"]["choices"]
116+
status = choices["status"]
117+
content = choices["text"][0]["content"]
118+
# print(content, end="")
119+
answer += content
120+
# print(1)
121+
if status == 2:
122+
break
123+
return answer
124+
125+
async def send(self, ws, image_file, question):
126+
text = [
127+
{"role": "user", "content": str(base64.b64encode(image_file.read()), 'utf-8'), "content_type": "image"},
128+
{"role": "user", "content": question}
129+
]
130+
131+
data = {
132+
"header": {
133+
"app_id": self.spark_app_id
134+
},
135+
"parameter": {
136+
"chat": {
137+
"domain": "image",
138+
"temperature": 0.5,
139+
"top_k": 4,
140+
"max_tokens": 2028,
141+
"auditing": "default"
142+
}
143+
},
144+
"payload": {
145+
"message": {
146+
"text": text
147+
}
148+
}
149+
}
150+
151+
d = json.dumps(data)
152+
await ws.send(d)
153+
154+
def is_cache_model(self):
155+
return False
156+
157+
@staticmethod
158+
def get_len(text):
159+
length = 0
160+
for content in text:
161+
temp = content["content"]
162+
leng = len(temp)
163+
length += leng
164+
return length
165+
166+
def check_len(self, text):
167+
print("text-content-tokens:", self.get_len(text[1:]))
168+
while (self.get_len(text[1:]) > 8000):
169+
del text[1]
170+
return text
Loading

apps/setting/models_provider/impl/xf_model_provider/model/stt.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import json
1111
import logging
1212
import os
13-
from datetime import datetime
13+
from datetime import datetime, UTC
1414
from typing import Dict
1515
from urllib.parse import urlencode, urlparse
1616
import ssl
@@ -63,7 +63,7 @@ def create_url(self):
6363
host = urlparse(url).hostname
6464
# 生成RFC1123格式的时间戳
6565
gmt_format = '%a, %d %b %Y %H:%M:%S GMT'
66-
date = datetime.utcnow().strftime(gmt_format)
66+
date = datetime.now(UTC).strftime(gmt_format)
6767

6868
# 拼接字符串
6969
signature_origin = "host: " + host + "\n"

apps/setting/models_provider/impl/xf_model_provider/model/tts.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import json
1313
import logging
1414
import os
15-
from datetime import datetime
15+
from datetime import datetime, UTC
1616
from typing import Dict
1717
from urllib.parse import urlencode, urlparse
1818
import ssl
@@ -67,7 +67,7 @@ def create_url(self):
6767
host = urlparse(url).hostname
6868
# 生成RFC1123格式的时间戳
6969
gmt_format = '%a, %d %b %Y %H:%M:%S GMT'
70-
date = datetime.utcnow().strftime(gmt_format)
70+
date = datetime.now(UTC).strftime(gmt_format)
7171

7272
# 拼接字符串
7373
signature_origin = "host: " + host + "\n"

apps/setting/models_provider/impl/xf_model_provider/xf_model_provider.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,12 @@
1313
from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \
1414
ModelInfoManage
1515
from setting.models_provider.impl.xf_model_provider.credential.embedding import XFEmbeddingCredential
16+
from setting.models_provider.impl.xf_model_provider.credential.image import XunFeiImageModelCredential
1617
from setting.models_provider.impl.xf_model_provider.credential.llm import XunFeiLLMModelCredential
1718
from setting.models_provider.impl.xf_model_provider.credential.stt import XunFeiSTTModelCredential
1819
from setting.models_provider.impl.xf_model_provider.credential.tts import XunFeiTTSModelCredential
1920
from setting.models_provider.impl.xf_model_provider.model.embedding import XFEmbedding
21+
from setting.models_provider.impl.xf_model_provider.model.image import XFSparkImage
2022
from setting.models_provider.impl.xf_model_provider.model.llm import XFChatSparkLLM
2123
from setting.models_provider.impl.xf_model_provider.model.stt import XFSparkSpeechToText
2224
from setting.models_provider.impl.xf_model_provider.model.tts import XFSparkTextToSpeech
@@ -26,6 +28,7 @@
2628

2729
qwen_model_credential = XunFeiLLMModelCredential()
2830
stt_model_credential = XunFeiSTTModelCredential()
31+
image_model_credential = XunFeiImageModelCredential()
2932
tts_model_credential = XunFeiTTSModelCredential()
3033
embedding_model_credential = XFEmbeddingCredential()
3134
model_info_list = [
@@ -34,6 +37,7 @@
3437
ModelInfo('generalv2', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM),
3538
ModelInfo('iat', '中英文识别', ModelTypeConst.STT, stt_model_credential, XFSparkSpeechToText),
3639
ModelInfo('tts', '', ModelTypeConst.TTS, tts_model_credential, XFSparkTextToSpeech),
40+
ModelInfo('image', '', ModelTypeConst.IMAGE, image_model_credential, XFSparkImage),
3741
ModelInfo('embedding', '', ModelTypeConst.EMBEDDING, embedding_model_credential, XFEmbedding)
3842
]
3943

ui/src/views/template/index.vue

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@
132132
<el-option label="重排模型" value="RERANKER" />
133133
<el-option label="语音识别" value="STT" />
134134
<el-option label="语音合成" value="TTS" />
135+
<el-option label="图片理解" value="IMAGE" />
135136
</el-select>
136137
</div>
137138
</div>

0 commit comments

Comments
 (0)