Skip to content

Commit 8a8305e

Browse files
committed
feat: 高级编排支持文件上传(WIP)
1 parent 72b91be commit 8a8305e

File tree

49 files changed

+1354
-240
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+1354
-240
lines changed

apps/application/flow/step_node/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,12 @@
1616
from .function_lib_node import *
1717
from .function_node import *
1818
from .reranker_node import *
19+
from .document_extract_node import *
20+
from .image_understand_step_node import *
1921

2022
node_list = [BaseStartStepNode, BaseChatNode, BaseSearchDatasetNode, BaseQuestionNode, BaseConditionNode, BaseReplyNode,
21-
BaseFunctionNodeNode, BaseFunctionLibNodeNode, BaseRerankerNode, BaseApplicationNode]
23+
BaseFunctionNodeNode, BaseFunctionLibNodeNode, BaseRerankerNode, BaseApplicationNode, BaseDocumentExtractNode,
24+
BaseImageUnderstandNode]
2225

2326

2427
def get_node(node_type):
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .impl import *
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# coding=utf-8
2+
3+
from typing import Type
4+
5+
from rest_framework import serializers
6+
7+
from application.flow.i_step_node import INode, NodeResult
8+
from common.util.field_message import ErrMessage
9+
10+
11+
class DocumentExtractNodeSerializer(serializers.Serializer):
    """Validate the node parameters of a document-extract workflow node."""

    # UUIDs of the files this node should process. The original comment and
    # error label said "dataset id list", copy-pasted from another node; the
    # label now names the field it actually validates.
    file_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True),
                                      error_messages=ErrMessage.list("文件id列表"))

    def is_valid(self, *, raise_exception=False):
        """Validate the incoming data, always raising on failure.

        ``raise_exception`` is accepted only for signature compatibility; the
        parent validation is always invoked with ``raise_exception=True``.

        Returns:
            The result of the parent ``is_valid`` call. Previously the result
            was discarded and ``None`` was returned, which reads as a failed
            validation to callers that test the return value.
        """
        return super().is_valid(raise_exception=True)
19+
20+
class IDocumentExtractNode(INode):
    """Interface of the document-extract workflow node.

    Concrete implementations override :meth:`execute`.
    """
    type = 'document-extract-node'

    def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
        """Return the serializer class used to validate this node's parameters."""
        return DocumentExtractNodeSerializer

    def _run(self):
        """Call :meth:`execute` with the validated node and flow parameters.

        ``file_list`` is declared on the node-params serializer, so the node
        parameters must be forwarded together with the flow parameters;
        forwarding only the flow parameters leaves the required ``file_list``
        argument of ``execute`` unset.
        """
        return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)

    def execute(self, file_list, **kwargs) -> NodeResult:
        """Process ``file_list``; to be implemented by subclasses.

        Args:
            file_list: validated ``file_list`` node parameter.
            **kwargs: remaining node and flow parameters.
        """
        pass
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .base_document_extract_node import BaseDocumentExtractNode
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# coding=utf-8
2+
3+
from application.flow.step_node.document_extract_node.i_document_extract_node import IDocumentExtractNode
4+
5+
6+
class BaseDocumentExtractNode(IDocumentExtractNode):
    """Default document-extract node; both methods are still placeholders."""

    def execute(self, file_list, **kwargs):
        """Placeholder: accepts the node arguments and returns ``None``."""
        return None

    def get_details(self, index: int, **kwargs):
        """Placeholder: accepts the step index and returns ``None``."""
        return None
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# coding=utf-8
2+
3+
from .impl import *
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# coding=utf-8
2+
3+
from typing import Type
4+
5+
from rest_framework import serializers
6+
7+
from application.flow.i_step_node import INode, NodeResult
8+
from common.util.field_message import ErrMessage
9+
10+
11+
class ImageUnderstandNodeSerializer(serializers.Serializer):
    """Validate the node parameters of an image-understand workflow node."""

    # Required id of the model to call.
    model_id = serializers.CharField(
        required=True,
        error_messages=ErrMessage.char("模型id"),
    )
    # Optional system prompt; may be blank or null.
    system = serializers.CharField(
        required=False,
        allow_blank=True,
        allow_null=True,
        error_messages=ErrMessage.char("角色设定"),
    )
    # Required user prompt.
    prompt = serializers.CharField(
        required=True,
        error_messages=ErrMessage.char("提示词"),
    )
    # Number of dialogue rounds taken from the chat history.
    dialogue_number = serializers.IntegerField(
        required=True,
        error_messages=ErrMessage.integer("多轮对话数量"),
    )

    # Optional flag: whether this node's content is returned.
    is_result = serializers.BooleanField(
        required=False,
        error_messages=ErrMessage.boolean('是否返回内容'),
    )

    # Optional list describing where the image comes from.
    image_list = serializers.ListField(
        required=False,
        error_messages=ErrMessage.list("图片仅1张"),
    )
22+
23+
24+
class IImageUnderstandNode(INode):
    """Interface of the image-understand workflow node.

    Concrete implementations override :meth:`execute`.
    """
    type = 'image-understand-node'

    def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
        """Return the serializer class used to validate this node's parameters."""
        return ImageUnderstandNodeSerializer

    def _run(self):
        """Resolve the referenced image and call :meth:`execute`.

        ``image_list`` is optional in the serializer, so it may be missing or
        empty. In that case ``image`` is passed as ``None`` instead of failing
        on ``image_list[0]``.
        """
        node_params = self.node_params_serializer.data
        image_reference = node_params.get('image_list')
        image = None
        if image_reference:
            # First element and the remaining elements are handed to the workflow
            # manager separately (presumably node id + field path -- confirm
            # against workflow_manage.get_reference_field).
            image = self.workflow_manage.get_reference_field(image_reference[0], image_reference[1:])
        return self.execute(image=image, **node_params, **self.flow_params_serializer.data)

    def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id,
                chat_record_id,
                image,
                **kwargs) -> NodeResult:
        """Run the node; to be implemented by subclasses.

        Args:
            model_id, system, prompt, dialogue_number: validated node parameters.
            history_chat_record, stream, chat_id, chat_record_id: flow parameters.
            image: value resolved from ``image_list`` by :meth:`_run`, or ``None``.
            **kwargs: any remaining node and flow parameters.
        """
        pass
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# coding=utf-8
2+
3+
from .base_image_understand_node import BaseImageUnderstandNode
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
# coding=utf-8
2+
import base64
3+
import os
4+
import time
5+
from functools import reduce
6+
from typing import List, Dict
7+
8+
from django.db.models import QuerySet
9+
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
10+
11+
from application.flow.i_step_node import NodeResult, INode
12+
from application.flow.step_node.image_understand_step_node.i_image_understand_node import IImageUnderstandNode
13+
from dataset.models import File
14+
from setting.models_provider.tools import get_model_instance_by_model_user_id
15+
16+
17+
def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str):
    """Store token counts, answer, history, question and run time on the node.

    When the workflow reports this node as a result node, ``answer`` is also
    appended to ``workflow.answer``.

    @param node_variable: node data (reads 'chat_model', 'history_message', 'question'
                          and, when present, 'usage_metadata')
    @param workflow_variable: global data, only forwarded to ``NodeResult``
    @param node: node whose ``context`` is filled in
    @param workflow: workflow manager
    @param answer: full answer text produced by the model
    """
    model = node_variable.get('chat_model')
    # NOTE(review): the 'output_tokens' figure is stored as message_tokens -- confirm
    # this is the intended usage field.
    if 'usage_metadata' in node_variable:
        prompt_token_count = node_variable['usage_metadata']['output_tokens']
    else:
        prompt_token_count = 0
    answer_token_count = model.get_num_tokens(answer)

    context = node.context
    context['message_tokens'] = prompt_token_count
    context['answer_tokens'] = answer_token_count
    context['answer'] = answer
    context['history_message'] = node_variable['history_message']
    context['question'] = node_variable['question']
    context['run_time'] = time.time() - context['start_time']

    node_result = NodeResult(node_variable, workflow_variable)
    if workflow.is_result(node, node_result):
        workflow.answer += answer
29+
30+
31+
def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
    """
    Write context data (streaming).

    Yields the content of every chunk of ``node_variable['result']`` as it
    arrives; once the stream is exhausted the concatenated answer is written
    to the node context.

    @param node_variable: node data
    @param workflow_variable: global data
    @param node: node
    @param workflow: workflow manager
    """
    collected = ''
    for part in node_variable.get('result'):
        text = part.content
        collected = collected + text
        yield text
    _write_context(node_variable, workflow_variable, node, workflow, collected)
45+
46+
47+
def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
    """
    Write context data (non-streaming).

    Takes the ``content`` of ``node_variable['result']`` as the answer and
    writes it to the node context.

    @param node_variable: node data
    @param workflow_variable: global data
    @param node: node instance
    @param workflow: workflow manager
    """
    full_answer = node_variable.get('result').content
    _write_context(node_variable, workflow_variable, node, workflow, full_answer)
58+
59+
60+
class BaseImageUnderstandNode(IImageUnderstandNode):
    """Image-understand node: sends the prompt, chat history and an optional image to a model."""

    def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
                image,
                **kwargs) -> NodeResult:
        """Build the message list and call the model.

        Args:
            model_id: id of the model to load for the current user.
            system: optional system prompt template.
            prompt: user prompt template.
            dialogue_number: how many trailing history records to include.
            history_chat_record: previous chat records.
            stream: when truthy the model is called with ``stream``, otherwise ``invoke``.
            chat_id, chat_record_id: accepted but unused here.
            image: resolved image reference (list of dicts with a 'file_id' key) or ``None``.

        Returns:
            A ``NodeResult`` whose context writer matches the streaming mode.
        """
        image_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'))
        history_message = self.get_history_message(history_chat_record, dialogue_number)
        self.context['history_message'] = history_message
        question = self.generate_prompt_question(prompt)
        self.context['question'] = question.content
        # TODO: handle uploaded images
        message_list = self.generate_message_list(image_model, system, prompt, history_message, image)
        self.context['message_list'] = message_list
        self.context['image_list'] = image
        # Only the model call and the context writer differ between the two modes.
        if stream:
            result, context_writer = image_model.stream(message_list), write_context_stream
        else:
            result, context_writer = image_model.invoke(message_list), write_context
        return NodeResult({'result': result, 'chat_model': image_model, 'message_list': message_list,
                           'history_message': history_message, 'question': question.content}, {},
                          _write_context=context_writer)

    @staticmethod
    def get_history_message(history_chat_record, dialogue_number):
        """Return the human/AI messages of the last ``dialogue_number`` records as one flat list.

        Built with a single comprehension instead of repeatedly copying the
        accumulated list, so the cost is linear in the number of records.
        """
        start_index = max(len(history_chat_record) - dialogue_number, 0)
        return [message
                for index in range(start_index, len(history_chat_record))
                for message in (history_chat_record[index].get_human_message(),
                                history_chat_record[index].get_ai_message())]

    def generate_prompt_question(self, prompt):
        """Render ``prompt`` through the workflow manager and wrap it in a ``HumanMessage``."""
        return HumanMessage(self.workflow_manage.generate_prompt(prompt))

    def generate_message_list(self, image_model, system: str, prompt: str, history_message, image):
        """Assemble ``[system?] + history + [question]`` for the model.

        When ``image`` is non-empty, the file referenced by its first entry is
        loaded and attached to the question as a base64 data URL.

        Raises:
            ValueError: if the referenced file does not exist. Previously this
                surfaced as an ``AttributeError`` on ``None``.
        """
        if image:
            file_id = image[0]['file_id']
            file = QuerySet(File).filter(id=file_id).first()
            if file is None:
                raise ValueError(f'Image file {file_id} does not exist')
            base64_image = base64.b64encode(file.get_byte()).decode("utf-8")
            # NOTE(review): the MIME type is hard-coded to image/jpeg regardless of the stored file.
            messages = [HumanMessage(
                content=[
                    {'type': 'text', 'text': self.workflow_manage.generate_prompt(prompt)},
                    {'type': 'image_url', 'image_url': {'url': f'data:image/jpeg;base64,{base64_image}'}},
                ])]
        else:
            messages = [HumanMessage(self.workflow_manage.generate_prompt(prompt))]

        if system:
            return [
                SystemMessage(self.workflow_manage.generate_prompt(system)),
                *history_message,
                *messages
            ]
        return [
            *history_message,
            *messages
        ]

    @staticmethod
    def reset_message_list(message_list: List[BaseMessage], answer_text):
        """Convert messages to ``{'role', 'content'}`` dicts and append the AI answer.

        Every message that is not a ``HumanMessage`` is labelled 'ai'.
        """
        result = [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content}
                  for message in message_list]
        result.append({'role': 'ai', 'content': answer_text})
        return result

    def get_details(self, index: int, **kwargs):
        """Return a dict describing this node's execution, built from ``self.context``."""
        history_message = self.context.get('history_message') or []
        return {
            'name': self.node.properties.get('stepName'),
            "index": index,
            'run_time': self.context.get('run_time'),
            'system': self.node_params.get('system'),
            'history_message': [{'content': message.content, 'role': message.type} for message in history_message],
            'question': self.context.get('question'),
            'answer': self.context.get('answer'),
            'type': self.node.type,
            'message_tokens': self.context.get('message_tokens'),
            'answer_tokens': self.context.get('answer_tokens'),
            'status': self.status,
            'err_message': self.err_message,
            'image_list': self.context.get('image_list')
        }

apps/application/flow/step_node/start_node/impl/base_start_node.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def execute(self, question, **kwargs) -> NodeResult:
4141
"""
4242
开始节点 初始化全局变量
4343
"""
44-
return NodeResult({'question': question},
44+
return NodeResult({'question': question, 'image': self.workflow_manage.image_list},
4545
workflow_variable)
4646

4747
def get_details(self, index: int, **kwargs):
@@ -61,5 +61,6 @@ def get_details(self, index: int, **kwargs):
6161
'type': self.node.type,
6262
'status': self.status,
6363
'err_message': self.err_message,
64+
'image_list': self.context.get('image'),
6465
'global_fields': global_fields
6566
}

0 commit comments

Comments
 (0)