Skip to content

Commit 1a1e932

Browse files
committed
feat: 工作流表单节点
1 parent 4e615db commit 1a1e932

File tree

53 files changed

+2031
-1015
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

53 files changed

+2031
-1015
lines changed

apps/application/flow/i_step_node.py

Lines changed: 37 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
@desc:
88
"""
99
import time
10+
import uuid
1011
from abc import abstractmethod
1112
from typing import Type, Dict, List
1213

@@ -31,7 +32,7 @@ def write_context(step_variable: Dict, global_variable: Dict, node, workflow):
3132
if workflow.is_result(node, NodeResult(step_variable, global_variable)) and 'answer' in step_variable:
3233
answer = step_variable['answer']
3334
yield answer
34-
workflow.answer += answer
35+
workflow.append_answer(answer)
3536
if global_variable is not None:
3637
for key in global_variable:
3738
workflow.context[key] = global_variable[key]
@@ -54,15 +55,27 @@ def handler(self, chat_id,
5455
'message_tokens' in row and row.get('message_tokens') is not None])
5556
answer_tokens = sum([row.get('answer_tokens') for row in details.values() if
5657
'answer_tokens' in row and row.get('answer_tokens') is not None])
57-
chat_record = ChatRecord(id=chat_record_id,
58-
chat_id=chat_id,
59-
problem_text=question,
60-
answer_text=answer,
61-
details=details,
62-
message_tokens=message_tokens,
63-
answer_tokens=answer_tokens,
64-
run_time=time.time() - workflow.context['start_time'],
65-
index=0)
58+
answer_text_list = workflow.get_answer_text_list()
59+
answer_text = '\n\n'.join(answer_text_list)
60+
if workflow.chat_record is not None:
61+
chat_record = workflow.chat_record
62+
chat_record.answer_text = answer_text
63+
chat_record.details = details
64+
chat_record.message_tokens = message_tokens
65+
chat_record.answer_tokens = answer_tokens
66+
chat_record.answer_text_list = answer_text_list
67+
chat_record.run_time = time.time() - workflow.context['start_time']
68+
else:
69+
chat_record = ChatRecord(id=chat_record_id,
70+
chat_id=chat_id,
71+
problem_text=question,
72+
answer_text=answer_text,
73+
details=details,
74+
message_tokens=message_tokens,
75+
answer_tokens=answer_tokens,
76+
answer_text_list=answer_text_list,
77+
run_time=time.time() - workflow.context['start_time'],
78+
index=0)
6679
self.chat_info.append_chat_record(chat_record, self.client_id)
6780
# 重新设置缓存
6881
chat_cache.set(chat_id,
@@ -118,7 +131,15 @@ class FlowParamsSerializer(serializers.Serializer):
118131

119132

120133
class INode:
121-
def __init__(self, node, workflow_params, workflow_manage):
134+
135+
@abstractmethod
136+
def save_context(self, details, workflow_manage):
137+
pass
138+
139+
def get_answer_text(self):
140+
return self.answer_text
141+
142+
def __init__(self, node, workflow_params, workflow_manage, runtime_node_id=None):
122143
# 当前步骤上下文,用于存储当前步骤信息
123144
self.status = 200
124145
self.err_message = ''
@@ -129,7 +150,12 @@ def __init__(self, node, workflow_params, workflow_manage):
129150
self.node_params_serializer = None
130151
self.flow_params_serializer = None
131152
self.context = {}
153+
self.answer_text = None
132154
self.id = node.id
155+
if runtime_node_id is None:
156+
self.runtime_node_id = str(uuid.uuid1())
157+
else:
158+
self.runtime_node_id = runtime_node_id
133159

134160
def valid_args(self, node_params, flow_params):
135161
flow_params_serializer_class = self.get_flow_params_serializer_class()

apps/application/flow/step_node/__init__.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,23 @@
99
from .ai_chat_step_node import *
1010
from .application_node import BaseApplicationNode
1111
from .condition_node import *
12-
from .question_node import *
13-
from .search_dataset_node import *
14-
from .start_node import *
1512
from .direct_reply_node import *
13+
from .form_node import *
1614
from .function_lib_node import *
1715
from .function_node import *
16+
from .question_node import *
1817
from .reranker_node import *
18+
1919
from .document_extract_node import *
2020
from .image_understand_step_node import *
2121

22+
from .search_dataset_node import *
23+
from .start_node import *
24+
2225
node_list = [BaseStartStepNode, BaseChatNode, BaseSearchDatasetNode, BaseQuestionNode, BaseConditionNode, BaseReplyNode,
23-
BaseFunctionNodeNode, BaseFunctionLibNodeNode, BaseRerankerNode, BaseApplicationNode, BaseDocumentExtractNode,
24-
BaseImageUnderstandNode]
26+
BaseFunctionNodeNode, BaseFunctionLibNodeNode, BaseRerankerNode, BaseApplicationNode,
27+
BaseDocumentExtractNode,
28+
BaseImageUnderstandNode, BaseFormNode]
2529

2630

2731
def get_node(node_type):

apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wo
3232
node.context['question'] = node_variable['question']
3333
node.context['run_time'] = time.time() - node.context['start_time']
3434
if workflow.is_result(node, NodeResult(node_variable, workflow_variable)):
35-
workflow.answer += answer
35+
node.answer_text = answer
3636

3737

3838
def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
@@ -73,6 +73,11 @@ def get_default_model_params_setting(model_id):
7373

7474

7575
class BaseChatNode(IChatNode):
76+
def save_context(self, details, workflow_manage):
77+
self.context['answer'] = details.get('answer')
78+
self.context['question'] = details.get('question')
79+
self.answer_text = details.get('answer')
80+
7681
def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
7782
model_params_setting=None,
7883
**kwargs) -> NodeResult:

apps/application/flow/step_node/application_node/impl/base_application_node.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wo
2121
node.context['question'] = node_variable['question']
2222
node.context['run_time'] = time.time() - node.context['start_time']
2323
if workflow.is_result(node, NodeResult(node_variable, workflow_variable)):
24-
workflow.answer += answer
24+
node.answer_text = answer
2525

2626

2727
def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
@@ -64,6 +64,12 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor
6464

6565
class BaseApplicationNode(IApplicationNode):
6666

67+
def save_context(self, details, workflow_manage):
68+
self.context['answer'] = details.get('answer')
69+
self.context['question'] = details.get('question')
70+
self.context['type'] = details.get('type')
71+
self.answer_text = details.get('answer')
72+
6773
def execute(self, application_id, message, chat_id, chat_record_id, stream, re_chat, client_id, client_type,
6874
**kwargs) -> NodeResult:
6975
from application.serializers.chat_message_serializers import ChatMessageSerializer

apps/application/flow/step_node/condition_node/impl/base_condition_node.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@
1414

1515

1616
class BaseConditionNode(IConditionNode):
17+
def save_context(self, details, workflow_manage):
18+
self.context['branch_id'] = details.get('branch_id')
19+
self.context['branch_name'] = details.get('branch_name')
20+
1721
def execute(self, **kwargs) -> NodeResult:
1822
branch_list = self.node_params_serializer.data['branch']
1923
branch = self._execute(branch_list)

apps/application/flow/step_node/direct_reply_node/impl/base_reply_node.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313

1414

1515
class BaseReplyNode(IReplyNode):
16+
def save_context(self, details, workflow_manage):
17+
self.context['answer'] = details.get('answer')
18+
self.answer_text = details.get('answer')
1619
def execute(self, reply_type, stream, fields=None, content=None, **kwargs) -> NodeResult:
1720
if reply_type == 'referencing':
1821
result = self.get_reference_content(fields)
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# coding=utf-8
2+
"""
3+
@project: MaxKB
4+
@Author:虎
5+
@file: __init__.py.py
6+
@date:2024/11/4 14:48
7+
@desc:
8+
"""
9+
from .impl import *
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# coding=utf-8
2+
"""
3+
@project: MaxKB
4+
@Author:虎
5+
@file: i_form_node.py
6+
@date:2024/11/4 14:48
7+
@desc:
8+
"""
9+
from typing import Type
10+
11+
from rest_framework import serializers
12+
13+
from application.flow.i_step_node import INode, NodeResult
14+
from common.util.field_message import ErrMessage
15+
16+
17+
class FormNodeParamsSerializer(serializers.Serializer):
18+
form_field_list = serializers.ListField(required=True, error_messages=ErrMessage.list("表单配置"))
19+
form_content_format = serializers.CharField(required=True, error_messages=ErrMessage.char('表单输出内容'))
20+
21+
22+
class IFormNode(INode):
23+
type = 'form-node'
24+
25+
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
26+
return FormNodeParamsSerializer
27+
28+
def _run(self):
29+
return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)
30+
31+
def execute(self, form_field_list, form_content_format, **kwargs) -> NodeResult:
32+
pass
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# coding=utf-8
2+
"""
3+
@project: MaxKB
4+
@Author:虎
5+
@file: __init__.py.py
6+
@date:2024/11/4 14:49
7+
@desc:
8+
"""
9+
from .base_form_node import BaseFormNode
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# coding=utf-8
2+
"""
3+
@project: MaxKB
4+
@Author:虎
5+
@file: base_form_node.py
6+
@date:2024/11/4 14:52
7+
@desc:
8+
"""
9+
import json
10+
import time
11+
from typing import Dict
12+
13+
from langchain_core.prompts import PromptTemplate
14+
15+
from application.flow.i_step_node import NodeResult
16+
from application.flow.step_node.form_node.i_form_node import IFormNode
17+
18+
19+
def write_context(step_variable: Dict, global_variable: Dict, node, workflow):
20+
if step_variable is not None:
21+
for key in step_variable:
22+
node.context[key] = step_variable[key]
23+
if workflow.is_result(node, NodeResult(step_variable, global_variable)) and 'result' in step_variable:
24+
result = step_variable['result']
25+
yield result
26+
node.answer_text = result
27+
node.context['run_time'] = time.time() - node.context['start_time']
28+
29+
30+
class BaseFormNode(IFormNode):
31+
def save_context(self, details, workflow_manage):
32+
self.context['result'] = details.get('result')
33+
self.context['form_content_format'] = details.get('form_content_format')
34+
self.context['form_field_list'] = details.get('form_field_list')
35+
self.context['run_time'] = details.get('run_time')
36+
self.context['start_time'] = details.get('start_time')
37+
self.answer_text = details.get('result')
38+
39+
def execute(self, form_field_list, form_content_format, **kwargs) -> NodeResult:
40+
form_setting = {"form_field_list": form_field_list, "runtime_node_id": self.runtime_node_id,
41+
"chat_record_id": self.flow_params_serializer.data.get("chat_record_id"),
42+
"is_submit": self.context.get("is_submit", False)}
43+
form = f'<form_rander>{json.dumps(form_setting)}</form_rander>'
44+
prompt_template = PromptTemplate.from_template(form_content_format, template_format='jinja2')
45+
value = prompt_template.format(form=form)
46+
return NodeResult(
47+
{'result': value, 'form_field_list': form_field_list, 'form_content_format': form_content_format}, {},
48+
_write_context=write_context)
49+
50+
def get_answer_text(self):
51+
form_content_format = self.context.get('form_content_format')
52+
form_field_list = self.context.get('form_field_list')
53+
form_setting = {"form_field_list": form_field_list, "runtime_node_id": self.runtime_node_id,
54+
"chat_record_id": self.flow_params_serializer.data.get("chat_record_id"),
55+
'form_data': self.context.get('form_data', {}),
56+
"is_submit": self.context.get("is_submit", False)}
57+
form = f'<form_rander>{json.dumps(form_setting)}</form_rander>'
58+
prompt_template = PromptTemplate.from_template(form_content_format, template_format='jinja2')
59+
value = prompt_template.format(form=form)
60+
return value
61+
62+
def get_details(self, index: int, **kwargs):
63+
form_content_format = self.context.get('form_content_format')
64+
form_field_list = self.context.get('form_field_list')
65+
form_setting = {"form_field_list": form_field_list, "runtime_node_id": self.runtime_node_id,
66+
"chat_record_id": self.flow_params_serializer.data.get("chat_record_id"),
67+
'form_data': self.context.get('form_data', {}),
68+
"is_submit": self.context.get("is_submit", False)}
69+
form = f'<form_rander>{json.dumps(form_setting)}</form_rander>'
70+
prompt_template = PromptTemplate.from_template(form_content_format, template_format='jinja2')
71+
value = prompt_template.format(form=form)
72+
return {
73+
'name': self.node.properties.get('stepName'),
74+
"index": index,
75+
"result": value,
76+
"form_content_format": self.context.get('form_content_format'),
77+
"form_field_list": self.context.get('form_field_list'),
78+
'form_data': self.context.get('form_data'),
79+
'start_time': self.context.get('start_time'),
80+
'run_time': self.context.get('run_time'),
81+
'type': self.node.type,
82+
'status': self.status,
83+
'err_message': self.err_message
84+
}

0 commit comments

Comments
 (0)