From 06872c812c1a7df32073067e3352a62f37e30471 Mon Sep 17 00:00:00 2001 From: shaohuzhang1 Date: Tue, 27 May 2025 18:10:07 +0800 Subject: [PATCH] feat: application flow --- .../chat_pipeline/I_base_chat_pipeline.py | 157 ++++ apps/application/chat_pipeline/__init__.py | 8 + .../chat_pipeline/pipeline_manage.py | 57 ++ .../chat_pipeline/step/__init__.py | 8 + .../chat_pipeline/step/chat_step/__init__.py | 8 + .../step/chat_step/i_chat_step.py | 110 +++ .../step/chat_step/impl/base_chat_step.py | 334 +++++++ .../generate_human_message_step/__init__.py | 8 + .../i_generate_human_message_step.py | 81 ++ .../impl/base_generate_human_message_step.py | 73 ++ .../step/reset_problem_step/__init__.py | 8 + .../i_reset_problem_step.py | 57 ++ .../impl/base_reset_problem_step.py | 68 ++ .../step/search_dataset_step/__init__.py | 8 + .../i_search_dataset_step.py | 77 ++ .../impl/base_search_dataset_step.py | 138 +++ apps/application/flow/__init__.py | 8 + apps/application/flow/common.py | 44 + apps/application/flow/default_workflow.json | 451 ++++++++++ .../application/flow/default_workflow_en.json | 451 ++++++++++ .../application/flow/default_workflow_zh.json | 451 ++++++++++ .../flow/default_workflow_zh_Hant.json | 451 ++++++++++ apps/application/flow/i_step_node.py | 256 ++++++ apps/application/flow/step_node/__init__.py | 42 + .../step_node/ai_chat_step_node/__init__.py | 9 + .../ai_chat_step_node/i_chat_node.py | 58 ++ .../ai_chat_step_node/impl/__init__.py | 9 + .../ai_chat_step_node/impl/base_chat_node.py | 288 ++++++ .../step_node/application_node/__init__.py | 2 + .../application_node/i_application_node.py | 86 ++ .../application_node/impl/__init__.py | 2 + .../impl/base_application_node.py | 267 ++++++ .../flow/step_node/condition_node/__init__.py | 9 + .../condition_node/compare/__init__.py | 30 + .../condition_node/compare/compare.py | 20 + .../condition_node/compare/contain_compare.py | 23 + .../condition_node/compare/equal_compare.py | 21 + .../condition_node/compare/ge_compare.py | 24 + .../condition_node/compare/gt_compare.py | 24 + .../compare/is_not_null_compare.py | 21 + .../condition_node/compare/is_not_true.py | 24 + .../condition_node/compare/is_null_compare.py | 21 + .../condition_node/compare/is_true.py | 24 + .../condition_node/compare/le_compare.py | 24 + .../compare/len_equal_compare.py | 24 + .../condition_node/compare/len_ge_compare.py | 24 + .../condition_node/compare/len_gt_compare.py | 24 + .../condition_node/compare/len_le_compare.py | 24 + .../condition_node/compare/len_lt_compare.py | 24 + .../condition_node/compare/lt_compare.py | 24 + .../compare/not_contain_compare.py | 23 + .../condition_node/i_condition_node.py | 39 + .../step_node/condition_node/impl/__init__.py | 9 + .../impl/base_condition_node.py | 62 ++ .../step_node/direct_reply_node/__init__.py | 9 + .../direct_reply_node/i_reply_node.py | 48 + .../direct_reply_node/impl/__init__.py | 9 + .../direct_reply_node/impl/base_reply_node.py | 45 + .../document_extract_node/__init__.py | 1 + .../i_document_extract_node.py | 28 + .../document_extract_node/impl/__init__.py | 1 + .../impl/base_document_extract_node.py | 94 ++ .../flow/step_node/form_node/__init__.py | 9 + .../flow/step_node/form_node/i_form_node.py | 35 + .../flow/step_node/form_node/impl/__init__.py | 9 + .../form_node/impl/base_form_node.py | 107 +++ .../step_node/function_lib_node/__init__.py | 9 + .../function_lib_node/i_function_lib_node.py | 48 + .../function_lib_node/impl/__init__.py | 9 + .../impl/base_function_lib_node.py | 150 ++++ 
.../flow/step_node/function_node/__init__.py | 9 + .../function_node/i_function_node.py | 63 ++ .../step_node/function_node/impl/__init__.py | 9 + .../function_node/impl/base_function_node.py | 108 +++ .../image_generate_step_node/__init__.py | 3 + .../i_image_generate_node.py | 45 + .../image_generate_step_node/impl/__init__.py | 3 + .../impl/base_image_generate_node.py | 122 +++ .../image_understand_step_node/__init__.py | 3 + .../i_image_understand_node.py | 46 + .../impl/__init__.py | 3 + .../impl/base_image_understand_node.py | 224 +++++ .../flow/step_node/mcp_node/__init__.py | 3 + .../flow/step_node/mcp_node/i_mcp_node.py | 35 + .../flow/step_node/mcp_node/impl/__init__.py | 3 + .../step_node/mcp_node/impl/base_mcp_node.py | 61 ++ .../flow/step_node/question_node/__init__.py | 9 + .../question_node/i_question_node.py | 42 + .../step_node/question_node/impl/__init__.py | 9 + .../question_node/impl/base_question_node.py | 159 ++++ .../flow/step_node/reranker_node/__init__.py | 9 + .../reranker_node/i_reranker_node.py | 60 ++ .../step_node/reranker_node/impl/__init__.py | 9 + .../reranker_node/impl/base_reranker_node.py | 106 +++ .../step_node/search_dataset_node/__init__.py | 9 + .../i_search_dataset_node.py | 79 ++ .../search_dataset_node/impl/__init__.py | 9 + .../impl/base_search_dataset_node.py | 146 ++++ .../speech_to_text_step_node/__init__.py | 3 + .../i_speech_to_text_node.py | 38 + .../speech_to_text_step_node/impl/__init__.py | 3 + .../impl/base_speech_to_text_node.py | 72 ++ .../flow/step_node/start_node/__init__.py | 9 + .../flow/step_node/start_node/i_start_node.py | 20 + .../step_node/start_node/impl/__init__.py | 9 + .../start_node/impl/base_start_node.py | 92 ++ .../text_to_speech_step_node/__init__.py | 3 + .../i_text_to_speech_node.py | 36 + .../text_to_speech_step_node/impl/__init__.py | 3 + .../impl/base_text_to_speech_node.py | 76 ++ .../variable_assign_node/__init__.py | 3 + .../i_variable_assign_node.py | 27 + .../variable_assign_node/impl/__init__.py | 9 + .../impl/base_variable_assign_node.py | 65 ++ apps/application/flow/tools.py | 191 ++++ apps/application/flow/workflow_manage.py | 827 ++++++++++++++++++ apps/application/migrations/0001_initial.py | 94 +- apps/application/migrations/0002_initial.py | 22 - apps/application/models/application.py | 2 +- apps/application/serializers/application.py | 35 +- .../serializers/application_folder.py | 21 + apps/folders/serializers/folder.py | 5 +- 122 files changed, 8221 insertions(+), 58 deletions(-) create mode 100644 apps/application/chat_pipeline/I_base_chat_pipeline.py create mode 100644 apps/application/chat_pipeline/__init__.py create mode 100644 apps/application/chat_pipeline/pipeline_manage.py create mode 100644 apps/application/chat_pipeline/step/__init__.py create mode 100644 apps/application/chat_pipeline/step/chat_step/__init__.py create mode 100644 apps/application/chat_pipeline/step/chat_step/i_chat_step.py create mode 100644 apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py create mode 100644 apps/application/chat_pipeline/step/generate_human_message_step/__init__.py create mode 100644 apps/application/chat_pipeline/step/generate_human_message_step/i_generate_human_message_step.py create mode 100644 apps/application/chat_pipeline/step/generate_human_message_step/impl/base_generate_human_message_step.py create mode 100644 apps/application/chat_pipeline/step/reset_problem_step/__init__.py create mode 100644 apps/application/chat_pipeline/step/reset_problem_step/i_reset_problem_step.py 
create mode 100644 apps/application/chat_pipeline/step/reset_problem_step/impl/base_reset_problem_step.py create mode 100644 apps/application/chat_pipeline/step/search_dataset_step/__init__.py create mode 100644 apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py create mode 100644 apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py create mode 100644 apps/application/flow/__init__.py create mode 100644 apps/application/flow/common.py create mode 100644 apps/application/flow/default_workflow.json create mode 100644 apps/application/flow/default_workflow_en.json create mode 100644 apps/application/flow/default_workflow_zh.json create mode 100644 apps/application/flow/default_workflow_zh_Hant.json create mode 100644 apps/application/flow/i_step_node.py create mode 100644 apps/application/flow/step_node/__init__.py create mode 100644 apps/application/flow/step_node/ai_chat_step_node/__init__.py create mode 100644 apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py create mode 100644 apps/application/flow/step_node/ai_chat_step_node/impl/__init__.py create mode 100644 apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py create mode 100644 apps/application/flow/step_node/application_node/__init__.py create mode 100644 apps/application/flow/step_node/application_node/i_application_node.py create mode 100644 apps/application/flow/step_node/application_node/impl/__init__.py create mode 100644 apps/application/flow/step_node/application_node/impl/base_application_node.py create mode 100644 apps/application/flow/step_node/condition_node/__init__.py create mode 100644 apps/application/flow/step_node/condition_node/compare/__init__.py create mode 100644 apps/application/flow/step_node/condition_node/compare/compare.py create mode 100644 apps/application/flow/step_node/condition_node/compare/contain_compare.py create mode 100644 apps/application/flow/step_node/condition_node/compare/equal_compare.py create mode 100644 apps/application/flow/step_node/condition_node/compare/ge_compare.py create mode 100644 apps/application/flow/step_node/condition_node/compare/gt_compare.py create mode 100644 apps/application/flow/step_node/condition_node/compare/is_not_null_compare.py create mode 100644 apps/application/flow/step_node/condition_node/compare/is_not_true.py create mode 100644 apps/application/flow/step_node/condition_node/compare/is_null_compare.py create mode 100644 apps/application/flow/step_node/condition_node/compare/is_true.py create mode 100644 apps/application/flow/step_node/condition_node/compare/le_compare.py create mode 100644 apps/application/flow/step_node/condition_node/compare/len_equal_compare.py create mode 100644 apps/application/flow/step_node/condition_node/compare/len_ge_compare.py create mode 100644 apps/application/flow/step_node/condition_node/compare/len_gt_compare.py create mode 100644 apps/application/flow/step_node/condition_node/compare/len_le_compare.py create mode 100644 apps/application/flow/step_node/condition_node/compare/len_lt_compare.py create mode 100644 apps/application/flow/step_node/condition_node/compare/lt_compare.py create mode 100644 apps/application/flow/step_node/condition_node/compare/not_contain_compare.py create mode 100644 apps/application/flow/step_node/condition_node/i_condition_node.py create mode 100644 apps/application/flow/step_node/condition_node/impl/__init__.py create mode 100644 
apps/application/flow/step_node/condition_node/impl/base_condition_node.py create mode 100644 apps/application/flow/step_node/direct_reply_node/__init__.py create mode 100644 apps/application/flow/step_node/direct_reply_node/i_reply_node.py create mode 100644 apps/application/flow/step_node/direct_reply_node/impl/__init__.py create mode 100644 apps/application/flow/step_node/direct_reply_node/impl/base_reply_node.py create mode 100644 apps/application/flow/step_node/document_extract_node/__init__.py create mode 100644 apps/application/flow/step_node/document_extract_node/i_document_extract_node.py create mode 100644 apps/application/flow/step_node/document_extract_node/impl/__init__.py create mode 100644 apps/application/flow/step_node/document_extract_node/impl/base_document_extract_node.py create mode 100644 apps/application/flow/step_node/form_node/__init__.py create mode 100644 apps/application/flow/step_node/form_node/i_form_node.py create mode 100644 apps/application/flow/step_node/form_node/impl/__init__.py create mode 100644 apps/application/flow/step_node/form_node/impl/base_form_node.py create mode 100644 apps/application/flow/step_node/function_lib_node/__init__.py create mode 100644 apps/application/flow/step_node/function_lib_node/i_function_lib_node.py create mode 100644 apps/application/flow/step_node/function_lib_node/impl/__init__.py create mode 100644 apps/application/flow/step_node/function_lib_node/impl/base_function_lib_node.py create mode 100644 apps/application/flow/step_node/function_node/__init__.py create mode 100644 apps/application/flow/step_node/function_node/i_function_node.py create mode 100644 apps/application/flow/step_node/function_node/impl/__init__.py create mode 100644 apps/application/flow/step_node/function_node/impl/base_function_node.py create mode 100644 apps/application/flow/step_node/image_generate_step_node/__init__.py create mode 100644 apps/application/flow/step_node/image_generate_step_node/i_image_generate_node.py create mode 100644 apps/application/flow/step_node/image_generate_step_node/impl/__init__.py create mode 100644 apps/application/flow/step_node/image_generate_step_node/impl/base_image_generate_node.py create mode 100644 apps/application/flow/step_node/image_understand_step_node/__init__.py create mode 100644 apps/application/flow/step_node/image_understand_step_node/i_image_understand_node.py create mode 100644 apps/application/flow/step_node/image_understand_step_node/impl/__init__.py create mode 100644 apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py create mode 100644 apps/application/flow/step_node/mcp_node/__init__.py create mode 100644 apps/application/flow/step_node/mcp_node/i_mcp_node.py create mode 100644 apps/application/flow/step_node/mcp_node/impl/__init__.py create mode 100644 apps/application/flow/step_node/mcp_node/impl/base_mcp_node.py create mode 100644 apps/application/flow/step_node/question_node/__init__.py create mode 100644 apps/application/flow/step_node/question_node/i_question_node.py create mode 100644 apps/application/flow/step_node/question_node/impl/__init__.py create mode 100644 apps/application/flow/step_node/question_node/impl/base_question_node.py create mode 100644 apps/application/flow/step_node/reranker_node/__init__.py create mode 100644 apps/application/flow/step_node/reranker_node/i_reranker_node.py create mode 100644 apps/application/flow/step_node/reranker_node/impl/__init__.py create mode 100644 
apps/application/flow/step_node/reranker_node/impl/base_reranker_node.py create mode 100644 apps/application/flow/step_node/search_dataset_node/__init__.py create mode 100644 apps/application/flow/step_node/search_dataset_node/i_search_dataset_node.py create mode 100644 apps/application/flow/step_node/search_dataset_node/impl/__init__.py create mode 100644 apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py create mode 100644 apps/application/flow/step_node/speech_to_text_step_node/__init__.py create mode 100644 apps/application/flow/step_node/speech_to_text_step_node/i_speech_to_text_node.py create mode 100644 apps/application/flow/step_node/speech_to_text_step_node/impl/__init__.py create mode 100644 apps/application/flow/step_node/speech_to_text_step_node/impl/base_speech_to_text_node.py create mode 100644 apps/application/flow/step_node/start_node/__init__.py create mode 100644 apps/application/flow/step_node/start_node/i_start_node.py create mode 100644 apps/application/flow/step_node/start_node/impl/__init__.py create mode 100644 apps/application/flow/step_node/start_node/impl/base_start_node.py create mode 100644 apps/application/flow/step_node/text_to_speech_step_node/__init__.py create mode 100644 apps/application/flow/step_node/text_to_speech_step_node/i_text_to_speech_node.py create mode 100644 apps/application/flow/step_node/text_to_speech_step_node/impl/__init__.py create mode 100644 apps/application/flow/step_node/text_to_speech_step_node/impl/base_text_to_speech_node.py create mode 100644 apps/application/flow/step_node/variable_assign_node/__init__.py create mode 100644 apps/application/flow/step_node/variable_assign_node/i_variable_assign_node.py create mode 100644 apps/application/flow/step_node/variable_assign_node/impl/__init__.py create mode 100644 apps/application/flow/step_node/variable_assign_node/impl/base_variable_assign_node.py create mode 100644 apps/application/flow/tools.py create mode 100644 apps/application/flow/workflow_manage.py delete mode 100644 apps/application/migrations/0002_initial.py create mode 100644 apps/application/serializers/application_folder.py diff --git a/apps/application/chat_pipeline/I_base_chat_pipeline.py b/apps/application/chat_pipeline/I_base_chat_pipeline.py new file mode 100644 index 00000000000..a35bdc39c7f --- /dev/null +++ b/apps/application/chat_pipeline/I_base_chat_pipeline.py @@ -0,0 +1,157 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: I_base_chat_pipeline.py + @date:2024/1/9 17:25 + @desc: +""" +import time +from abc import abstractmethod +from typing import Type + +from rest_framework import serializers + +from dataset.models import Paragraph + + +class ParagraphPipelineModel: + + def __init__(self, _id: str, document_id: str, dataset_id: str, content: str, title: str, status: str, + is_active: bool, comprehensive_score: float, similarity: float, dataset_name: str, document_name: str, + hit_handling_method: str, directly_return_similarity: float, meta: dict = None): + self.id = _id + self.document_id = document_id + self.dataset_id = dataset_id + self.content = content + self.title = title + self.status = status + self.is_active = is_active + self.comprehensive_score = comprehensive_score + self.similarity = similarity + self.dataset_name = dataset_name + self.document_name = document_name + self.hit_handling_method = hit_handling_method + self.directly_return_similarity = directly_return_similarity + self.meta = meta + + def to_dict(self): + return { + 'id': self.id, +
'document_id': self.document_id, + 'dataset_id': self.dataset_id, + 'content': self.content, + 'title': self.title, + 'status': self.status, + 'is_active': self.is_active, + 'comprehensive_score': self.comprehensive_score, + 'similarity': self.similarity, + 'dataset_name': self.dataset_name, + 'document_name': self.document_name, + 'meta': self.meta, + } + + class builder: + def __init__(self): + self.similarity = None + self.paragraph = {} + self.comprehensive_score = None + self.document_name = None + self.dataset_name = None + self.hit_handling_method = None + self.directly_return_similarity = 0.9 + self.meta = {} + + def add_paragraph(self, paragraph): + if isinstance(paragraph, Paragraph): + self.paragraph = {'id': paragraph.id, + 'document_id': paragraph.document_id, + 'dataset_id': paragraph.dataset_id, + 'content': paragraph.content, + 'title': paragraph.title, + 'status': paragraph.status, + 'is_active': paragraph.is_active, + } + else: + self.paragraph = paragraph + return self + + def add_dataset_name(self, dataset_name): + self.dataset_name = dataset_name + return self + + def add_document_name(self, document_name): + self.document_name = document_name + return self + + def add_hit_handling_method(self, hit_handling_method): + self.hit_handling_method = hit_handling_method + return self + + def add_directly_return_similarity(self, directly_return_similarity): + self.directly_return_similarity = directly_return_similarity + return self + + def add_comprehensive_score(self, comprehensive_score: float): + self.comprehensive_score = comprehensive_score + return self + + def add_similarity(self, similarity: float): + self.similarity = similarity + return self + + def add_meta(self, meta: dict): + self.meta = meta + return self + + def build(self): + return ParagraphPipelineModel(str(self.paragraph.get('id')), str(self.paragraph.get('document_id')), + str(self.paragraph.get('dataset_id')), + self.paragraph.get('content'), self.paragraph.get('title'), + self.paragraph.get('status'), + self.paragraph.get('is_active'), + self.comprehensive_score, self.similarity, self.dataset_name, + self.document_name, self.hit_handling_method, self.directly_return_similarity, + self.meta) + + +class IBaseChatPipelineStep: + def __init__(self): + # Context of the current step, used to store information about this step + self.context = {} + + @abstractmethod + def get_step_serializer(self, manage) -> Type[serializers.Serializer]: + pass + + def valid_args(self, manage): + step_serializer_clazz = self.get_step_serializer(manage) + step_serializer = step_serializer_clazz(data=manage.context) + step_serializer.is_valid(raise_exception=True) + self.context['step_args'] = step_serializer.data + + def run(self, manage): + """ + + :param manage: step manager + :return: execution result + """ + start_time = time.time() + self.context['start_time'] = start_time + # Validate arguments + self.valid_args(manage) + self._run(manage) + self.context['run_time'] = time.time() - start_time + + def _run(self, manage): + pass + + def execute(self, **kwargs): + pass + + def get_details(self, manage, **kwargs): + """ + Run details + :return: step details + """ + return None diff --git a/apps/application/chat_pipeline/__init__.py b/apps/application/chat_pipeline/__init__.py new file mode 100644 index 00000000000..719a7e29c90 --- /dev/null +++ b/apps/application/chat_pipeline/__init__.py @@ -0,0 +1,8 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py + @date:2024/1/9 17:23 + @desc: +""" diff --git a/apps/application/chat_pipeline/pipeline_manage.py
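For orientation, a minimal sketch of how the fluent builder above is intended to be used; all field values here are illustrative, not taken from the patch:

from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel

# add_paragraph accepts either a Paragraph model instance or a plain dict
paragraph = (ParagraphPipelineModel.builder()
             .add_paragraph({'id': 'p-1', 'document_id': 'doc-1', 'dataset_id': 'ds-1',
                             'content': 'Reset your password from the admin console.',
                             'title': 'Password reset', 'status': 'ok', 'is_active': True})
             .add_similarity(0.87)
             .add_comprehensive_score(0.91)
             .add_dataset_name('demo-dataset')
             .add_document_name('faq.md')
             .add_hit_handling_method('optimization')
             .build())
print(paragraph.to_dict())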
b/apps/application/chat_pipeline/pipeline_manage.py new file mode 100644 index 00000000000..7c4acb3a34a --- /dev/null +++ b/apps/application/chat_pipeline/pipeline_manage.py @@ -0,0 +1,57 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: pipeline_manage.py + @date:2024/1/9 17:40 + @desc: +""" +import time +from functools import reduce +from typing import List, Type, Dict + +from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep +from common.handle.base_to_response import BaseToResponse +from common.handle.impl.response.system_to_response import SystemToResponse + + +class PipelineManage: + def __init__(self, step_list: List[Type[IBaseChatPipelineStep]], + base_to_response: BaseToResponse = SystemToResponse()): + # Step executors + self.step_list = [step() for step in step_list] + # Context + self.context = {'message_tokens': 0, 'answer_tokens': 0} + self.base_to_response = base_to_response + + def run(self, context: Dict = None): + self.context['start_time'] = time.time() + if context is not None: + for key, value in context.items(): + self.context[key] = value + for step in self.step_list: + step.run(self) + + def get_details(self): + return reduce(lambda x, y: {**x, **y}, [{item.get('step_type'): item} for item in + filter(lambda r: r is not None, + [row.get_details(self) for row in self.step_list])], {}) + + def get_base_to_response(self): + return self.base_to_response + + class builder: + def __init__(self): + self.step_list: List[Type[IBaseChatPipelineStep]] = [] + self.base_to_response = SystemToResponse() + + def append_step(self, step: Type[IBaseChatPipelineStep]): + self.step_list.append(step) + return self + + def add_base_to_response(self, base_to_response: BaseToResponse): + self.base_to_response = base_to_response + return self + + def build(self): + return PipelineManage(step_list=self.step_list, base_to_response=self.base_to_response) diff --git a/apps/application/chat_pipeline/step/__init__.py b/apps/application/chat_pipeline/step/__init__.py new file mode 100644 index 00000000000..5d9549cdc64 --- /dev/null +++ b/apps/application/chat_pipeline/step/__init__.py @@ -0,0 +1,8 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py + @date:2024/1/9 18:23 + @desc: +""" diff --git a/apps/application/chat_pipeline/step/chat_step/__init__.py b/apps/application/chat_pipeline/step/chat_step/__init__.py new file mode 100644 index 00000000000..5d9549cdc64 --- /dev/null +++ b/apps/application/chat_pipeline/step/chat_step/__init__.py @@ -0,0 +1,8 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py + @date:2024/1/9 18:23 + @desc: +""" diff --git a/apps/application/chat_pipeline/step/chat_step/i_chat_step.py b/apps/application/chat_pipeline/step/chat_step/i_chat_step.py new file mode 100644 index 00000000000..2673c6b7bbd --- /dev/null +++ b/apps/application/chat_pipeline/step/chat_step/i_chat_step.py @@ -0,0 +1,110 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: i_chat_step.py + @date:2024/1/9 18:17 + @desc: Chat +""" +from abc import abstractmethod +from typing import Type, List + +from django.utils.translation import gettext_lazy as _ +from langchain.chat_models.base import BaseChatModel +from langchain.schema import BaseMessage +from rest_framework import serializers + +from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep, ParagraphPipelineModel +from application.chat_pipeline.pipeline_manage import PipelineManage +from application.serializers.application_serializers
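A sketch of how the four concrete steps introduced later in this patch can be assembled through the builder above; the context key shown is an example, and a real run must supply every field that each step's InstanceSerializer declares:

from application.chat_pipeline.pipeline_manage import PipelineManage
from application.chat_pipeline.step.reset_problem_step.impl.base_reset_problem_step import BaseResetProblemStep
from application.chat_pipeline.step.search_dataset_step.impl.base_search_dataset_step import BaseSearchDatasetStep
from application.chat_pipeline.step.generate_human_message_step.impl.base_generate_human_message_step import \
    BaseGenerateHumanMessageStep
from application.chat_pipeline.step.chat_step.impl.base_chat_step import BaseChatStep

pipeline = (PipelineManage.builder()
            .append_step(BaseResetProblemStep)
            .append_step(BaseSearchDatasetStep)
            .append_step(BaseGenerateHumanMessageStep)
            .append_step(BaseChatStep)
            .build())
# run() merges the given dict into pipeline.context and executes the steps in order;
# each step validates its own slice of the context through its serializer.
pipeline.run(context={'problem_text': 'How do I reset my password?'})
print(pipeline.get_details())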
import NoReferencesSetting +from common.field.common import InstanceField +from common.util.field_message import ErrMessage + + +class ModelField(serializers.Field): + def to_internal_value(self, data): + if not isinstance(data, BaseChatModel): + self.fail(_('Model type error'), value=data) + return data + + def to_representation(self, value): + return value + + +class MessageField(serializers.Field): + def to_internal_value(self, data): + if not isinstance(data, BaseMessage): + self.fail(_('Message type error'), value=data) + return data + + def to_representation(self, value): + return value + + +class PostResponseHandler: + @abstractmethod + def handler(self, chat_id, chat_record_id, paragraph_list: List[ParagraphPipelineModel], problem_text: str, + answer_text, + manage, step, padding_problem_text: str = None, client_id=None, **kwargs): + pass + + +class IChatStep(IBaseChatPipelineStep): + class InstanceSerializer(serializers.Serializer): + # Message list + message_list = serializers.ListField(required=True, child=MessageField(required=True), + error_messages=ErrMessage.list(_("Conversation list"))) + model_id = serializers.UUIDField(required=False, allow_null=True, error_messages=ErrMessage.uuid(_("Model id"))) + # Paragraph list + paragraph_list = serializers.ListField(error_messages=ErrMessage.list(_("Paragraph List"))) + # Conversation id + chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid(_("Conversation ID"))) + # User question + problem_text = serializers.CharField(required=True, error_messages=ErrMessage.char(_("User Questions"))) + # Post-processor + post_response_handler = InstanceField(model_type=PostResponseHandler, + error_messages=ErrMessage.base(_("Post-processor"))) + # Completed question + padding_problem_text = serializers.CharField(required=False, + error_messages=ErrMessage.base(_("Completion Question"))) + # Whether to output as a stream + stream = serializers.BooleanField(required=False, error_messages=ErrMessage.base(_("Streaming Output"))) + client_id = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Client id"))) + client_type = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Client Type"))) + # Settings for when no reference segments are retrieved + no_references_setting = NoReferencesSetting(required=True, + error_messages=ErrMessage.base(_("No reference segment settings"))) + + user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid(_("User ID"))) + + model_setting = serializers.DictField(required=True, allow_null=True, + error_messages=ErrMessage.dict(_("Model settings"))) + + model_params_setting = serializers.DictField(required=False, allow_null=True, + error_messages=ErrMessage.dict(_("Model parameter settings"))) + + def is_valid(self, *, raise_exception=False): + super().is_valid(raise_exception=True) + message_list: List = self.initial_data.get('message_list') + for message in message_list: + if not isinstance(message, BaseMessage): + raise Exception(_("message type error")) + + def get_step_serializer(self, manage: PipelineManage) -> Type[serializers.Serializer]: + return self.InstanceSerializer + + def _run(self, manage: PipelineManage): + chat_result = self.execute(**self.context['step_args'], manage=manage) + manage.context['chat_result'] = chat_result + + @abstractmethod + def execute(self, message_list: List[BaseMessage], + chat_id, problem_text, + post_response_handler: PostResponseHandler, + model_id: str = None, + user_id: str = None, + paragraph_list=None, + manage: PipelineManage = None, + padding_problem_text: str = None, stream: bool = True, client_id=None,
client_type=None, + no_references_setting=None, model_params_setting=None, model_setting=None, **kwargs): + pass diff --git a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py new file mode 100644 index 00000000000..b03f06d80ad --- /dev/null +++ b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py @@ -0,0 +1,334 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: base_chat_step.py + @date:2024/1/9 18:25 + @desc: Base implementation of the chat step +""" +import logging +import time +import traceback +import uuid +from typing import List + +from django.db.models import QuerySet +from django.http import StreamingHttpResponse +from django.utils.translation import gettext as _ +from langchain.chat_models.base import BaseChatModel +from langchain.schema import BaseMessage +from langchain.schema.messages import HumanMessage, AIMessage +from langchain_core.messages import AIMessageChunk +from rest_framework import status + +from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel +from application.chat_pipeline.pipeline_manage import PipelineManage +from application.chat_pipeline.step.chat_step.i_chat_step import IChatStep, PostResponseHandler +from application.flow.tools import Reasoning +from application.models.api_key_model import ApplicationPublicAccessClient +from common.constants.authentication_type import AuthenticationType +from setting.models_provider.tools import get_model_instance_by_model_user_id + + +def add_access_num(client_id=None, client_type=None, application_id=None): + if client_type == AuthenticationType.APPLICATION_ACCESS_TOKEN.value and application_id is not None: + application_public_access_client = (QuerySet(ApplicationPublicAccessClient).filter(client_id=client_id, + application_id=application_id) + .first()) + if application_public_access_client is not None: + application_public_access_client.access_num = application_public_access_client.access_num + 1 + application_public_access_client.intraday_access_num = application_public_access_client.intraday_access_num + 1 + application_public_access_client.save() + + +def write_context(step, manage, request_token, response_token, all_text): + step.context['message_tokens'] = request_token + step.context['answer_tokens'] = response_token + current_time = time.time() + step.context['answer_text'] = all_text + step.context['run_time'] = current_time - step.context['start_time'] + manage.context['run_time'] = current_time - manage.context['start_time'] + manage.context['message_tokens'] = manage.context['message_tokens'] + request_token + manage.context['answer_tokens'] = manage.context['answer_tokens'] + response_token + + +def event_content(response, + chat_id, + chat_record_id, + paragraph_list: List[ParagraphPipelineModel], + post_response_handler: PostResponseHandler, + manage, + step, + chat_model, + message_list: List[BaseMessage], + problem_text: str, + padding_problem_text: str = None, + client_id=None, client_type=None, + is_ai_chat: bool = None, + model_setting=None): + if model_setting is None: + model_setting = {} + reasoning_content_enable = model_setting.get('reasoning_content_enable', False) + reasoning_content_start = model_setting.get('reasoning_content_start', '') + reasoning_content_end = model_setting.get('reasoning_content_end', '') + reasoning = Reasoning(reasoning_content_start, + reasoning_content_end) + all_text = '' + reasoning_content = '' + try: + response_reasoning_content = False +
for chunk in response: + reasoning_chunk = reasoning.get_reasoning_content(chunk) + content_chunk = reasoning_chunk.get('content') + if 'reasoning_content' in chunk.additional_kwargs: + response_reasoning_content = True + reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '') + else: + reasoning_content_chunk = reasoning_chunk.get('reasoning_content') + all_text += content_chunk + if reasoning_content_chunk is None: + reasoning_content_chunk = '' + reasoning_content += reasoning_content_chunk + yield manage.get_base_to_response().to_stream_chunk_response(chat_id, str(chat_record_id), 'ai-chat-node', + [], content_chunk, + False, + 0, 0, {'node_is_end': False, + 'view_type': 'many_view', + 'node_type': 'ai-chat-node', + 'real_node_id': 'ai-chat-node', + 'reasoning_content': reasoning_content_chunk if reasoning_content_enable else ''}) + reasoning_chunk = reasoning.get_end_reasoning_content() + all_text += reasoning_chunk.get('content') + reasoning_content_chunk = "" + if not response_reasoning_content: + reasoning_content_chunk = reasoning_chunk.get( + 'reasoning_content') + yield manage.get_base_to_response().to_stream_chunk_response(chat_id, str(chat_record_id), 'ai-chat-node', + [], reasoning_chunk.get('content'), + False, + 0, 0, {'node_is_end': False, + 'view_type': 'many_view', + 'node_type': 'ai-chat-node', + 'real_node_id': 'ai-chat-node', + 'reasoning_content': reasoning_content_chunk if reasoning_content_enable else ''}) + # Count tokens + if is_ai_chat: + try: + request_token = chat_model.get_num_tokens_from_messages(message_list) + response_token = chat_model.get_num_tokens(all_text) + except Exception as e: + request_token = 0 + response_token = 0 + else: + request_token = 0 + response_token = 0 + write_context(step, manage, request_token, response_token, all_text) + asker = manage.context.get('form_data', {}).get('asker', None) + post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text, + all_text, manage, step, padding_problem_text, client_id, + reasoning_content=reasoning_content if reasoning_content_enable else '', + asker=asker) + yield manage.get_base_to_response().to_stream_chunk_response(chat_id, str(chat_record_id), 'ai-chat-node', + [], '', True, + request_token, response_token, + {'node_is_end': True, 'view_type': 'many_view', + 'node_type': 'ai-chat-node'}) + add_access_num(client_id, client_type, manage.context.get('application_id')) + except Exception as e: + logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}') + all_text = 'Exception:' + str(e) + write_context(step, manage, 0, 0, all_text) + asker = manage.context.get('form_data', {}).get('asker', None) + post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text, + all_text, manage, step, padding_problem_text, client_id, reasoning_content='', + asker=asker) + add_access_num(client_id, client_type, manage.context.get('application_id')) + yield manage.get_base_to_response().to_stream_chunk_response(chat_id, str(chat_record_id), 'ai-chat-node', + [], all_text, + False, + 0, 0, {'node_is_end': False, + 'view_type': 'many_view', + 'node_type': 'ai-chat-node', + 'real_node_id': 'ai-chat-node', + 'reasoning_content': ''}) + + +class BaseChatStep(IChatStep): + def execute(self, message_list: List[BaseMessage], + chat_id, + problem_text, + post_response_handler: PostResponseHandler, + model_id: str = None, + user_id: str = None, + paragraph_list=None, + manage: PipelineManage = None, + padding_problem_text: str = None, +
stream: bool = True, + client_id=None, client_type=None, + no_references_setting=None, + model_params_setting=None, + model_setting=None, + **kwargs): + chat_model = get_model_instance_by_model_user_id(model_id, user_id, + **model_params_setting) if model_id is not None else None + if stream: + return self.execute_stream(message_list, chat_id, problem_text, post_response_handler, chat_model, + paragraph_list, + manage, padding_problem_text, client_id, client_type, no_references_setting, + model_setting) + else: + return self.execute_block(message_list, chat_id, problem_text, post_response_handler, chat_model, + paragraph_list, + manage, padding_problem_text, client_id, client_type, no_references_setting, + model_setting) + + def get_details(self, manage, **kwargs): + return { + 'step_type': 'chat_step', + 'run_time': self.context['run_time'], + 'model_id': str(manage.context['model_id']), + 'message_list': self.reset_message_list(self.context['step_args'].get('message_list'), + self.context['answer_text']), + 'message_tokens': self.context['message_tokens'], + 'answer_tokens': self.context['answer_tokens'], + 'cost': 0, + } + + @staticmethod + def reset_message_list(message_list: List[BaseMessage], answer_text): + result = [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content} for + message + in + message_list] + result.append({'role': 'ai', 'content': answer_text}) + return result + + @staticmethod + def get_stream_result(message_list: List[BaseMessage], + chat_model: BaseChatModel = None, + paragraph_list=None, + no_references_setting=None, + problem_text=None): + if paragraph_list is None: + paragraph_list = [] + directly_return_chunk_list = [AIMessageChunk(content=paragraph.content) + for paragraph in paragraph_list if ( + paragraph.hit_handling_method == 'directly_return' and paragraph.similarity >= paragraph.directly_return_similarity)] + if directly_return_chunk_list is not None and len(directly_return_chunk_list) > 0: + return iter(directly_return_chunk_list), False + elif len(paragraph_list) == 0 and no_references_setting.get( + 'status') == 'designated_answer': + return iter( + [AIMessageChunk(content=no_references_setting.get('value').replace('{question}', problem_text))]), False + if chat_model is None: + return iter([AIMessageChunk( + _('Sorry, the AI model is not configured. 
Please go to the application to set up the AI model first.'))]), False + else: + return chat_model.stream(message_list), True + + def execute_stream(self, message_list: List[BaseMessage], + chat_id, + problem_text, + post_response_handler: PostResponseHandler, + chat_model: BaseChatModel = None, + paragraph_list=None, + manage: PipelineManage = None, + padding_problem_text: str = None, + client_id=None, client_type=None, + no_references_setting=None, + model_setting=None): + chat_result, is_ai_chat = self.get_stream_result(message_list, chat_model, paragraph_list, + no_references_setting, problem_text) + chat_record_id = uuid.uuid1() + r = StreamingHttpResponse( + streaming_content=event_content(chat_result, chat_id, chat_record_id, paragraph_list, + post_response_handler, manage, self, chat_model, message_list, problem_text, + padding_problem_text, client_id, client_type, is_ai_chat, model_setting), + content_type='text/event-stream;charset=utf-8') + + r['Cache-Control'] = 'no-cache' + return r + + @staticmethod + def get_block_result(message_list: List[BaseMessage], + chat_model: BaseChatModel = None, + paragraph_list=None, + no_references_setting=None, + problem_text=None): + if paragraph_list is None: + paragraph_list = [] + directly_return_chunk_list = [AIMessageChunk(content=paragraph.content) + for paragraph in paragraph_list if ( + paragraph.hit_handling_method == 'directly_return' and paragraph.similarity >= paragraph.directly_return_similarity)] + if directly_return_chunk_list is not None and len(directly_return_chunk_list) > 0: + return directly_return_chunk_list[0], False + elif len(paragraph_list) == 0 and no_references_setting.get( + 'status') == 'designated_answer': + return AIMessage(no_references_setting.get('value').replace('{question}', problem_text)), False + if chat_model is None: + return AIMessage( + _('Sorry, the AI model is not configured. 
Please go to the application to set up the AI model first.')), False + else: + return chat_model.invoke(message_list), True + + def execute_block(self, message_list: List[BaseMessage], + chat_id, + problem_text, + post_response_handler: PostResponseHandler, + chat_model: BaseChatModel = None, + paragraph_list=None, + manage: PipelineManage = None, + padding_problem_text: str = None, + client_id=None, client_type=None, no_references_setting=None, + model_setting=None): + if model_setting is None: + model_setting = {} + reasoning_content_enable = model_setting.get('reasoning_content_enable', False) + reasoning_content_start = model_setting.get('reasoning_content_start', '') + reasoning_content_end = model_setting.get('reasoning_content_end', '') + reasoning = Reasoning(reasoning_content_start, + reasoning_content_end) + chat_record_id = uuid.uuid1() + # Invoke the model + try: + chat_result, is_ai_chat = self.get_block_result(message_list, chat_model, paragraph_list, + no_references_setting, problem_text) + if is_ai_chat: + request_token = chat_model.get_num_tokens_from_messages(message_list) + response_token = chat_model.get_num_tokens(chat_result.content) + else: + request_token = 0 + response_token = 0 + write_context(self, manage, request_token, response_token, chat_result.content) + reasoning_result = reasoning.get_reasoning_content(chat_result) + reasoning_result_end = reasoning.get_end_reasoning_content() + content = reasoning_result.get('content') + reasoning_result_end.get('content') + if 'reasoning_content' in chat_result.response_metadata: + reasoning_content = chat_result.response_metadata.get('reasoning_content', '') + else: + reasoning_content = reasoning_result.get('reasoning_content') + reasoning_result_end.get( + 'reasoning_content') + asker = manage.context.get('form_data', {}).get('asker', None) + post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text, + content, manage, self, padding_problem_text, client_id, + reasoning_content=reasoning_content if reasoning_content_enable else '', + asker=asker) + add_access_num(client_id, client_type, manage.context.get('application_id')) + return manage.get_base_to_response().to_block_response(str(chat_id), str(chat_record_id), + content, True, + request_token, response_token, + { + 'reasoning_content': reasoning_content if reasoning_content_enable else '', + 'answer_list': [{ + 'content': content, + 'reasoning_content': reasoning_content if reasoning_content_enable else '' + }]}) + except Exception as e: + all_text = 'Exception:' + str(e) + write_context(self, manage, 0, 0, all_text) + asker = manage.context.get('form_data', {}).get('asker', None) + post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text, + all_text, manage, self, padding_problem_text, client_id, reasoning_content='', + asker=asker) + add_access_num(client_id, client_type, manage.context.get('application_id')) + return manage.get_base_to_response().to_block_response(str(chat_id), str(chat_record_id), all_text, True, 0, + 0, _status=status.HTTP_500_INTERNAL_SERVER_ERROR) diff --git a/apps/application/chat_pipeline/step/generate_human_message_step/__init__.py b/apps/application/chat_pipeline/step/generate_human_message_step/__init__.py new file mode 100644 index 00000000000..5d9549cdc64 --- /dev/null +++ b/apps/application/chat_pipeline/step/generate_human_message_step/__init__.py @@ -0,0 +1,8 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py + @date:2024/1/9 18:23 + @desc: +""" diff --git
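The PostResponseHandler hook above is the extension point that receives the final answer in both the streaming and blocking paths; a minimal sketch of a concrete handler (the logging body is hypothetical):

from application.chat_pipeline.step.chat_step.i_chat_step import PostResponseHandler


class LogChatRecordHandler(PostResponseHandler):
    def handler(self, chat_id, chat_record_id, paragraph_list, problem_text,
                answer_text, manage, step, padding_problem_text=None, client_id=None, **kwargs):
        # base_chat_step passes reasoning_content and asker through **kwargs
        reasoning_content = kwargs.get('reasoning_content', '')
        print(f'chat={chat_id} record={chat_record_id} answer={answer_text[:80]!r} '
              f'reasoning={len(reasoning_content)} chars')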
a/apps/application/chat_pipeline/step/generate_human_message_step/i_generate_human_message_step.py b/apps/application/chat_pipeline/step/generate_human_message_step/i_generate_human_message_step.py new file mode 100644 index 00000000000..9e23f2d6c52 --- /dev/null +++ b/apps/application/chat_pipeline/step/generate_human_message_step/i_generate_human_message_step.py @@ -0,0 +1,81 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: i_generate_human_message_step.py + @date:2024/1/9 18:15 + @desc: Generate the conversation messages +""" +from abc import abstractmethod +from typing import Type, List + +from django.utils.translation import gettext_lazy as _ +from langchain.schema import BaseMessage +from rest_framework import serializers + +from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep, ParagraphPipelineModel +from application.chat_pipeline.pipeline_manage import PipelineManage +from application.models import ChatRecord +from application.serializers.application_serializers import NoReferencesSetting +from common.field.common import InstanceField +from common.util.field_message import ErrMessage + + +class IGenerateHumanMessageStep(IBaseChatPipelineStep): + class InstanceSerializer(serializers.Serializer): + # Question + problem_text = serializers.CharField(required=True, error_messages=ErrMessage.char(_("question"))) + # Paragraph list + paragraph_list = serializers.ListField(child=InstanceField(model_type=ParagraphPipelineModel, required=True), + error_messages=ErrMessage.list(_("Paragraph List"))) + # Historical Q&A + history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True), + error_messages=ErrMessage.list(_("History Questions"))) + # Number of multi-turn dialogue rounds + dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer(_("Number of multi-round conversations"))) + # Maximum total length of knowledge-base paragraphs to carry + max_paragraph_char_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer( + _("Maximum length of the knowledge base paragraph"))) + # Prompt template + prompt = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Prompt word"))) + system = serializers.CharField(required=False, allow_null=True, allow_blank=True, + error_messages=ErrMessage.char(_("System prompt words (role)"))) + # Completed question + padding_problem_text = serializers.CharField(required=False, error_messages=ErrMessage.char(_("Completion problem"))) + # Settings for when no reference segments are retrieved + no_references_setting = NoReferencesSetting(required=True, error_messages=ErrMessage.base(_("No reference segment settings"))) + + def get_step_serializer(self, manage: PipelineManage) -> Type[serializers.Serializer]: + return self.InstanceSerializer + + def _run(self, manage: PipelineManage): + message_list = self.execute(**self.context['step_args']) + manage.context['message_list'] = message_list + + @abstractmethod + def execute(self, + problem_text: str, + paragraph_list: List[ParagraphPipelineModel], + history_chat_record: List[ChatRecord], + dialogue_number: int, + max_paragraph_char_number: int, + prompt: str, + padding_problem_text: str = None, + no_references_setting=None, + system=None, + **kwargs) -> List[BaseMessage]: + """ + + :param problem_text: original question text + :param paragraph_list: paragraph list + :param history_chat_record: history of chat records + :param dialogue_number: number of multi-turn dialogue rounds + :param max_paragraph_char_number: maximum paragraph length + :param prompt: prompt template + :param padding_problem_text: user-revised question text + :param kwargs: other parameters + :param no_references_setting: settings for when no reference segments are retrieved + :param system: system prompt (role) + :return: + """ + pass diff --git
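The {data}/{question} contract of the prompt parameter is easiest to see through the static helper implemented in base_generate_human_message_step.py just below; a sketch with hypothetical values:

from application.chat_pipeline.step.generate_human_message_step.impl.base_generate_human_message_step import \
    BaseGenerateHumanMessageStep

message = BaseGenerateHumanMessageStep.to_human_message(
    prompt='Known information:\n{data}\nQuestion:\n{question}',
    problem='How do I reset my password?',
    max_paragraph_char_number=1024,
    paragraph_list=[paragraph],  # e.g. the ParagraphPipelineModel built in the earlier sketch
    no_references_setting={'status': 'designated_answer', 'value': '{question}'})
print(message.content)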
a/apps/application/chat_pipeline/step/generate_human_message_step/impl/base_generate_human_message_step.py b/apps/application/chat_pipeline/step/generate_human_message_step/impl/base_generate_human_message_step.py new file mode 100644 index 00000000000..68cfbbcb95d --- /dev/null +++ b/apps/application/chat_pipeline/step/generate_human_message_step/impl/base_generate_human_message_step.py @@ -0,0 +1,73 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: base_generate_human_message_step.py + @date:2024/1/10 17:50 + @desc: +""" +from typing import List, Dict + +from langchain.schema import BaseMessage, HumanMessage +from langchain_core.messages import SystemMessage + +from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel +from application.chat_pipeline.step.generate_human_message_step.i_generate_human_message_step import \ + IGenerateHumanMessageStep +from application.models import ChatRecord +from common.util.split_model import flat_map + + +class BaseGenerateHumanMessageStep(IGenerateHumanMessageStep): + + def execute(self, problem_text: str, + paragraph_list: List[ParagraphPipelineModel], + history_chat_record: List[ChatRecord], + dialogue_number: int, + max_paragraph_char_number: int, + prompt: str, + padding_problem_text: str = None, + no_references_setting=None, + system=None, + **kwargs) -> List[BaseMessage]: + prompt = prompt if (paragraph_list is not None and len(paragraph_list) > 0) else no_references_setting.get( + 'value') + exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text + start_index = len(history_chat_record) - dialogue_number + history_message = [[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()] + for index in + range(start_index if start_index > 0 else 0, len(history_chat_record))] + if system is not None and len(system) > 0: + return [SystemMessage(system), *flat_map(history_message), + self.to_human_message(prompt, exec_problem_text, max_paragraph_char_number, paragraph_list, + no_references_setting)] + + return [*flat_map(history_message), + self.to_human_message(prompt, exec_problem_text, max_paragraph_char_number, paragraph_list, + no_references_setting)] + + @staticmethod + def to_human_message(prompt: str, + problem: str, + max_paragraph_char_number: int, + paragraph_list: List[ParagraphPipelineModel], + no_references_setting: Dict): + if paragraph_list is None or len(paragraph_list) == 0: + if no_references_setting.get('status') == 'ai_questioning': + return HumanMessage( + content=no_references_setting.get('value').replace('{question}', problem)) + else: + return HumanMessage(content=prompt.replace('{data}', "").replace('{question}', problem)) + temp_data = "" + data_list = [] + for p in paragraph_list: + content = f"{p.title}:{p.content}" + temp_data += content + if len(temp_data) > max_paragraph_char_number: + row_data = content[0:max_paragraph_char_number - len(temp_data)] + data_list.append(f"{row_data}") + break + else: + data_list.append(f"{content}") + data = "\n".join(data_list) + return HumanMessage(content=prompt.replace('{data}', data).replace('{question}', problem)) diff --git a/apps/application/chat_pipeline/step/reset_problem_step/__init__.py b/apps/application/chat_pipeline/step/reset_problem_step/__init__.py new file mode 100644 index 00000000000..5d9549cdc64 --- /dev/null +++ b/apps/application/chat_pipeline/step/reset_problem_step/__init__.py @@ -0,0 +1,8 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file:
__init__.py + @date:2024/1/9 18:23 + @desc: +""" diff --git a/apps/application/chat_pipeline/step/reset_problem_step/i_reset_problem_step.py b/apps/application/chat_pipeline/step/reset_problem_step/i_reset_problem_step.py new file mode 100644 index 00000000000..f48f5c804fd --- /dev/null +++ b/apps/application/chat_pipeline/step/reset_problem_step/i_reset_problem_step.py @@ -0,0 +1,57 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: i_reset_problem_step.py + @date:2024/1/9 18:12 + @desc: Rewrite the user question +""" +from abc import abstractmethod +from typing import Type, List + +from django.utils.translation import gettext_lazy as _ +from rest_framework import serializers + +from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep +from application.chat_pipeline.pipeline_manage import PipelineManage +from application.models import ChatRecord +from common.field.common import InstanceField +from common.util.field_message import ErrMessage + + +class IResetProblemStep(IBaseChatPipelineStep): + class InstanceSerializer(serializers.Serializer): + # Question text + problem_text = serializers.CharField(required=True, error_messages=ErrMessage.char(_("question"))) + # Historical Q&A + history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True), + error_messages=ErrMessage.list(_("History Questions"))) + # Large language model + model_id = serializers.UUIDField(required=False, allow_null=True, error_messages=ErrMessage.uuid(_("Model id"))) + user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid(_("User ID"))) + problem_optimization_prompt = serializers.CharField(required=False, max_length=102400, + error_messages=ErrMessage.char( + _("Question completion prompt"))) + + def get_step_serializer(self, manage: PipelineManage) -> Type[serializers.Serializer]: + return self.InstanceSerializer + + def _run(self, manage: PipelineManage): + padding_problem = self.execute(**self.context.get('step_args')) + # The question entered by the user + source_problem_text = self.context.get('step_args').get('problem_text') + self.context['problem_text'] = source_problem_text + self.context['padding_problem_text'] = padding_problem + manage.context['problem_text'] = source_problem_text + manage.context['padding_problem_text'] = padding_problem + # Accumulate tokens + manage.context['message_tokens'] = manage.context.get('message_tokens', 0) + self.context.get('message_tokens', + 0) + manage.context['answer_tokens'] = manage.context.get('answer_tokens', 0) + self.context.get('answer_tokens', 0) + + @abstractmethod + def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, model_id: str = None, + problem_optimization_prompt=None, + user_id=None, + **kwargs): + pass diff --git a/apps/application/chat_pipeline/step/reset_problem_step/impl/base_reset_problem_step.py b/apps/application/chat_pipeline/step/reset_problem_step/impl/base_reset_problem_step.py new file mode 100644 index 00000000000..ec01daa3444 --- /dev/null +++ b/apps/application/chat_pipeline/step/reset_problem_step/impl/base_reset_problem_step.py @@ -0,0 +1,68 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: base_reset_problem_step.py + @date:2024/1/10 14:35 + @desc: +""" +from typing import List + +from django.utils.translation import gettext as _ +from langchain.schema import HumanMessage + +from application.chat_pipeline.step.reset_problem_step.i_reset_problem_step import IResetProblemStep +from application.models import ChatRecord +from common.util.split_model import flat_map +from
setting.models_provider.tools import get_model_instance_by_model_user_id + +prompt = _( + "() contains the user's question. Answer the guessed user's question based on the context ({question}) Requirement: Output a complete question and put it in the <data></data> tag") + + +class BaseResetProblemStep(IResetProblemStep): + def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, model_id: str = None, + problem_optimization_prompt=None, + user_id=None, + **kwargs) -> str: + chat_model = get_model_instance_by_model_user_id(model_id, user_id) if model_id is not None else None + if chat_model is None: + return problem_text + start_index = len(history_chat_record) - 3 + history_message = [[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()] + for index in + range(start_index if start_index > 0 else 0, len(history_chat_record))] + reset_prompt = problem_optimization_prompt if problem_optimization_prompt else prompt + message_list = [*flat_map(history_message), + HumanMessage(content=reset_prompt.replace('{question}', problem_text))] + response = chat_model.invoke(message_list) + padding_problem = problem_text + if response.content.__contains__("<data>") and response.content.__contains__('</data>'): + padding_problem_data = response.content[ + response.content.index('<data>') + 6:response.content.index('</data>')] + if padding_problem_data is not None and len(padding_problem_data.strip()) > 0: + padding_problem = padding_problem_data + elif len(response.content) > 0: + padding_problem = response.content + + try: + request_token = chat_model.get_num_tokens_from_messages(message_list) + response_token = chat_model.get_num_tokens(padding_problem) + except Exception as e: + request_token = 0 + response_token = 0 + self.context['message_tokens'] = request_token + self.context['answer_tokens'] = response_token + return padding_problem + + def get_details(self, manage, **kwargs): + return { + 'step_type': 'problem_padding', + 'run_time': self.context['run_time'], + 'model_id': str(manage.context['model_id']) if 'model_id' in manage.context else None, + 'message_tokens': self.context.get('message_tokens', 0), + 'answer_tokens': self.context.get('answer_tokens', 0), + 'cost': 0, + 'padding_problem_text': self.context.get('padding_problem_text'), + 'problem_text': self.context.get("step_args").get('problem_text'), + } diff --git a/apps/application/chat_pipeline/step/search_dataset_step/__init__.py b/apps/application/chat_pipeline/step/search_dataset_step/__init__.py new file mode 100644 index 00000000000..023c4bc387d --- /dev/null +++ b/apps/application/chat_pipeline/step/search_dataset_step/__init__.py @@ -0,0 +1,8 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py + @date:2024/1/9 18:24 + @desc: +""" diff --git a/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py b/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py new file mode 100644 index 00000000000..7b222cbc279 --- /dev/null +++ b/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py @@ -0,0 +1,77 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: i_search_dataset_step.py + @date:2024/1/9 18:10 + @desc: Retrieve the knowledge base +""" +import re +from abc import abstractmethod +from typing import List, Type + +from django.core import validators +from django.utils.translation import gettext_lazy as _ +from rest_framework import serializers + +from application.chat_pipeline.I_base_chat_pipeline import
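The rewrite step above asks the model to wrap the completed question in <data></data> and then slices it back out; a toy reproduction of that parse on a hypothetical model reply:

content = 'Guess: <data>How do I reset my MaxKB admin password?</data>'
if '<data>' in content and '</data>' in content:
    # index('<data>') + 6 skips past the 6-character opening tag
    padding_problem = content[content.index('<data>') + 6:content.index('</data>')]
    print(padding_problem)  # -> How do I reset my MaxKB admin password?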
diff --git a/apps/application/chat_pipeline/step/search_dataset_step/__init__.py b/apps/application/chat_pipeline/step/search_dataset_step/__init__.py new file mode 100644 index 00000000000..023c4bc387d --- /dev/null +++ b/apps/application/chat_pipeline/step/search_dataset_step/__init__.py @@ -0,0 +1,8 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py + @date:2024/1/9 18:24 + @desc: +""" diff --git a/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py b/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py new file mode 100644 index 00000000000..7b222cbc279 --- /dev/null +++ b/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py @@ -0,0 +1,77 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: i_search_dataset_step.py + @date:2024/1/9 18:10 + @desc: Retrieve the knowledge base +""" +import re +from abc import abstractmethod +from typing import List, Type + +from django.core import validators +from django.utils.translation import gettext_lazy as _ +from rest_framework import serializers + +from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep, ParagraphPipelineModel +from application.chat_pipeline.pipeline_manage import PipelineManage +from common.util.field_message import ErrMessage + + +class ISearchDatasetStep(IBaseChatPipelineStep): + class InstanceSerializer(serializers.Serializer): + # Original question text + problem_text = serializers.CharField(required=True, error_messages=ErrMessage.char(_("question"))) + # Question text completed by the system + padding_problem_text = serializers.CharField(required=False, + error_messages=ErrMessage.char(_("System completes question text"))) + # List of dataset ids to query + dataset_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True), + error_messages=ErrMessage.list(_("Dataset id list"))) + # Document ids to exclude + exclude_document_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True), + error_messages=ErrMessage.list(_("List of document ids to exclude"))) + # Vector ids to exclude + exclude_paragraph_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True), + error_messages=ErrMessage.list(_("List of exclusion vector ids"))) + # Number of records to retrieve + top_n = serializers.IntegerField(required=True, + error_messages=ErrMessage.integer(_("Reference segment number"))) + # Similarity, between 0 and 1 + similarity = serializers.FloatField(required=True, max_value=1, min_value=0, + error_messages=ErrMessage.float(_("Similarity"))) + search_mode = serializers.CharField(required=True, validators=[ + validators.RegexValidator(regex=re.compile("^(embedding|keywords|blend)$"), + message=_("The type only supports embedding|keywords|blend"), code=500) + ], error_messages=ErrMessage.char(_("Retrieval Mode"))) + user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid(_("User ID"))) + + def get_step_serializer(self, manage: PipelineManage) -> Type[InstanceSerializer]: + return self.InstanceSerializer + + def _run(self, manage: PipelineManage): + paragraph_list = self.execute(**self.context['step_args']) + manage.context['paragraph_list'] = paragraph_list + self.context['paragraph_list'] = paragraph_list + + @abstractmethod + def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str], + exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None, + search_mode: str = None, + user_id=None, + **kwargs) -> List[ParagraphPipelineModel]: + """ + Note on the user question vs. the completed question: if a completed question exists, it is used for retrieval; otherwise the user's original question is used. + :param similarity: similarity threshold + :param top_n: number of records to retrieve + :param problem_text: user question + :param dataset_id_list: list of dataset ids to query + :param exclude_document_id_list: document ids to exclude + :param exclude_paragraph_id_list: paragraph ids to exclude + :param padding_problem_text: completed question + :param search_mode: retrieval mode + :param user_id: user id + :return: list of paragraphs + """ + pass
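One note on the `search_mode` validator: the alternation is wrapped in a group so that `^` and `$` anchor the whole value. Without the parentheses, `^embedding|keywords|blend$` anchors only the first and last alternatives, and `RegexValidator` (which searches the string) would accept any value that merely contains `keywords`. A quick standalone check of the difference:

```python
import re

loose = re.compile("^embedding|keywords|blend$")
strict = re.compile("^(embedding|keywords|blend)$")

# The middle alternative is unanchored in the loose pattern, so junk passes:
assert loose.search("invalid-keywords-mode") is not None
assert strict.search("invalid-keywords-mode") is None
assert strict.search("keywords") is not None
```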
diff --git a/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py b/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py new file mode 100644 index 00000000000..6591f6d246a --- /dev/null +++ b/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py @@ -0,0 +1,138 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: base_search_dataset_step.py + @date:2024/1/10 10:33 + @desc: +""" +import os +from typing import List, Dict + +from django.db.models import QuerySet +from django.utils.translation import gettext_lazy as _ +from rest_framework.utils.formatting import lazy_format + +from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel +from application.chat_pipeline.step.search_dataset_step.i_search_dataset_step import ISearchDatasetStep +from common.config.embedding_config import VectorStore, ModelManage +from common.db.search import native_search +from common.util.file_util import get_file_content +from dataset.models import Paragraph, DataSet +from embedding.models import SearchMode +from setting.models import Model +from setting.models_provider import get_model +from smartdoc.conf import PROJECT_DIR + + +def get_model_by_id(_id, user_id): + model = QuerySet(Model).filter(id=_id).first() + if model is None: + raise Exception(_("Model does not exist")) + if model.permission_type == 'PRIVATE' and str(model.user_id) != str(user_id): + message = lazy_format(_('No permission to use this model {model_name}'), model_name=model.name) + raise Exception(message) + return model + + +def get_embedding_id(dataset_id_list): + dataset_list = QuerySet(DataSet).filter(id__in=dataset_id_list) + if len(set([dataset.embedding_mode_id for dataset in dataset_list])) > 1: + raise Exception(_("The vector model of the associated knowledge base is inconsistent and the segmentation cannot be recalled.")) + if len(dataset_list) == 0: + raise Exception(_("The knowledge base setting is wrong, please reset the knowledge base")) + return dataset_list[0].embedding_mode_id + + +class BaseSearchDatasetStep(ISearchDatasetStep): + + def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str], + exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None, + search_mode: str = None, + user_id=None, + **kwargs) -> List[ParagraphPipelineModel]: + if len(dataset_id_list) == 0: + return [] + exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text + model_id = get_embedding_id(dataset_id_list) + model = get_model_by_id(model_id, user_id) + self.context['model_name'] = model.name + embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model)) + embedding_value = embedding_model.embed_query(exec_problem_text) + vector = VectorStore.get_embedding_vector() + embedding_list = vector.query(exec_problem_text, embedding_value, dataset_id_list, exclude_document_id_list, + exclude_paragraph_id_list, True, top_n, similarity, SearchMode(search_mode)) + if embedding_list is None: + return [] + paragraph_list = self.list_paragraph(embedding_list, vector) + result = [self.reset_paragraph(paragraph, embedding_list) for paragraph in paragraph_list] + return result + + @staticmethod + def reset_paragraph(paragraph: Dict, embedding_list: List) -> ParagraphPipelineModel: + filter_embedding_list = [embedding for embedding in embedding_list if + str(embedding.get('paragraph_id')) == str(paragraph.get('id'))] + if filter_embedding_list is not None and len(filter_embedding_list) > 0: + find_embedding = filter_embedding_list[-1] + return (ParagraphPipelineModel.builder() + .add_paragraph(paragraph) + .add_similarity(find_embedding.get('similarity')) + .add_comprehensive_score(find_embedding.get('comprehensive_score')) + .add_dataset_name(paragraph.get('dataset_name')) + .add_document_name(paragraph.get('document_name')) + .add_hit_handling_method(paragraph.get('hit_handling_method')) + .add_directly_return_similarity(paragraph.get('directly_return_similarity')) + .add_meta(paragraph.get('meta')) + .build()) + + @staticmethod + def get_similarity(paragraph, embedding_list: List): + filter_embedding_list = [embedding for embedding in embedding_list if + 
str(embedding.get('paragraph_id')) == str(paragraph.get('id'))] + if filter_embedding_list is not None and len(filter_embedding_list) > 0: + find_embedding = filter_embedding_list[-1] + return find_embedding.get('comprehensive_score') + return 0 + + @staticmethod + def list_paragraph(embedding_list: List, vector): + paragraph_id_list = [row.get('paragraph_id') for row in embedding_list] + if paragraph_id_list is None or len(paragraph_id_list) == 0: + return [] + paragraph_list = native_search(QuerySet(Paragraph).filter(id__in=paragraph_id_list), + get_file_content( + os.path.join(PROJECT_DIR, "apps", "application", 'sql', + 'list_dataset_paragraph_by_paragraph_id.sql')), + with_table_name=True) + # If there is dirty data in the vector store, delete it directly + if len(paragraph_list) != len(paragraph_id_list): + exist_paragraph_list = [row.get('id') for row in paragraph_list] + for paragraph_id in paragraph_id_list: + if not exist_paragraph_list.__contains__(paragraph_id): + vector.delete_by_paragraph_id(paragraph_id) + # If paragraphs marked for direct return qualify, take those paragraphs + hit_handling_method_paragraph = [paragraph for paragraph in paragraph_list if + (paragraph.get( + 'hit_handling_method') == 'directly_return' and BaseSearchDatasetStep.get_similarity( + paragraph, embedding_list) >= paragraph.get( + 'directly_return_similarity'))] + if len(hit_handling_method_paragraph) > 0: + # Pick the one with the highest score + return [sorted(hit_handling_method_paragraph, + key=lambda p: BaseSearchDatasetStep.get_similarity(p, embedding_list))[-1]] + return paragraph_list + + def get_details(self, manage, **kwargs): + step_args = self.context['step_args'] + + return { + 'step_type': 'search_step', + 'paragraph_list': [row.to_dict() for row in self.context['paragraph_list']], + 'run_time': self.context['run_time'], + 'problem_text': step_args.get( + 'padding_problem_text') if 'padding_problem_text' in step_args else step_args.get('problem_text'), + 'model_name': self.context.get('model_name'), + 'message_tokens': 0, + 'answer_tokens': 0, + 'cost': 0 + } diff --git a/apps/application/flow/__init__.py b/apps/application/flow/__init__.py new file mode 100644 index 00000000000..328e8f8ec5f --- /dev/null +++ b/apps/application/flow/__init__.py @@ -0,0 +1,8 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py + @date:2024/6/7 14:43 + @desc: +""" diff --git a/apps/application/flow/common.py b/apps/application/flow/common.py new file mode 100644 index 00000000000..f5d4cb9b0f7 --- /dev/null +++ b/apps/application/flow/common.py @@ -0,0 +1,44 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: common.py + @date:2024/12/11 17:57 + @desc: +""" + + +class Answer: + def __init__(self, content, view_type, runtime_node_id, chat_record_id, child_node, real_node_id, + reasoning_content): + self.view_type = view_type + self.content = content + self.reasoning_content = reasoning_content + self.runtime_node_id = runtime_node_id + self.chat_record_id = chat_record_id + self.child_node = child_node + self.real_node_id = real_node_id + + def to_dict(self): + return {'view_type': self.view_type, 'content': self.content, 'runtime_node_id': self.runtime_node_id, + 'chat_record_id': self.chat_record_id, + 'child_node': self.child_node, + 'reasoning_content': self.reasoning_content, + 'real_node_id': self.real_node_id} + + +class NodeChunk: + def __init__(self): + self.status = 0 + self.chunk_list = [] + + def add_chunk(self, chunk): + self.chunk_list.append(chunk) + + def end(self, chunk=None): + if chunk is not None: + self.add_chunk(chunk) + self.status = 200 + + def is_end(self): + return 
self.status == 200 diff --git a/apps/application/flow/default_workflow.json b/apps/application/flow/default_workflow.json new file mode 100644 index 00000000000..48ac23c4dc6 --- /dev/null +++ b/apps/application/flow/default_workflow.json @@ -0,0 +1,451 @@ +{ + "nodes": [ + { + "id": "base-node", + "type": "base-node", + "x": 360, + "y": 2810, + "properties": { + "config": { + + }, + "height": 825.6, + "stepName": "基本信息", + "node_data": { + "desc": "", + "name": "maxkbapplication", + "prologue": "您好,我是 MaxKB 小助手,您可以向我提出 MaxKB 使用问题。\n- MaxKB 主要功能有什么?\n- MaxKB 支持哪些大语言模型?\n- MaxKB 支持哪些文档类型?" + }, + "input_field_list": [ + + ] + } + }, + { + "id": "start-node", + "type": "start-node", + "x": 430, + "y": 3660, + "properties": { + "config": { + "fields": [ + { + "label": "用户问题", + "value": "question" + } + ], + "globalFields": [ + { + "label": "当前时间", + "value": "time" + } + ] + }, + "fields": [ + { + "label": "用户问题", + "value": "question" + } + ], + "height": 276, + "stepName": "开始", + "globalFields": [ + { + "label": "当前时间", + "value": "time" + } + ] + } + }, + { + "id": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5", + "type": "search-dataset-node", + "x": 840, + "y": 3210, + "properties": { + "config": { + "fields": [ + { + "label": "检索结果的分段列表", + "value": "paragraph_list" + }, + { + "label": "满足直接回答的分段列表", + "value": "is_hit_handling_method_list" + }, + { + "label": "检索结果", + "value": "data" + }, + { + "label": "满足直接回答的分段内容", + "value": "directly_return" + } + ] + }, + "height": 794, + "stepName": "知识库检索", + "node_data": { + "dataset_id_list": [ + + ], + "dataset_setting": { + "top_n": 3, + "similarity": 0.6, + "search_mode": "embedding", + "max_paragraph_char_number": 5000 + }, + "question_reference_address": [ + "start-node", + "question" + ], + "source_dataset_id_list": [ + + ] + } + } + }, + { + "id": "fc60863a-dec2-4854-9e5a-7a44b7187a2b", + "type": "condition-node", + "x": 1490, + "y": 3210, + "properties": { + "width": 600, + "config": { + "fields": [ + { + "label": "分支名称", + "value": "branch_name" + } + ] + }, + "height": 543.675, + "stepName": "判断器", + "node_data": { + "branch": [ + { + "id": "1009", + "type": "IF", + "condition": "and", + "conditions": [ + { + "field": [ + "b931efe5-5b66-46e0-ae3b-0160cb18eeb5", + "is_hit_handling_method_list" + ], + "value": "1", + "compare": "len_ge" + } + ] + }, + { + "id": "4908", + "type": "ELSE IF 1", + "condition": "and", + "conditions": [ + { + "field": [ + "b931efe5-5b66-46e0-ae3b-0160cb18eeb5", + "paragraph_list" + ], + "value": "1", + "compare": "len_ge" + } + ] + }, + { + "id": "161", + "type": "ELSE", + "condition": "and", + "conditions": [ + + ] + } + ] + }, + "branch_condition_list": [ + { + "index": 0, + "height": 121.225, + "id": "1009" + }, + { + "index": 1, + "height": 121.225, + "id": "4908" + }, + { + "index": 2, + "height": 44, + "id": "161" + } + ] + } + }, + { + "id": "4ffe1086-25df-4c85-b168-979b5bbf0a26", + "type": "reply-node", + "x": 2170, + "y": 2480, + "properties": { + "config": { + "fields": [ + { + "label": "内容", + "value": "answer" + } + ] + }, + "height": 378, + "stepName": "指定回复", + "node_data": { + "fields": [ + "b931efe5-5b66-46e0-ae3b-0160cb18eeb5", + "directly_return" + ], + "content": "", + "reply_type": "referencing", + "is_result": true + } + } + }, + { + "id": "f1f1ee18-5a02-46f6-b4e6-226253cdffbb", + "type": "ai-chat-node", + "x": 2160, + "y": 3200, + "properties": { + "config": { + "fields": [ + { + "label": "AI 回答内容", + "value": "answer" + } + ] + }, + "height": 763, + "stepName": "AI 对话", + "node_data": { + 
"prompt": "已知信息:\n{{知识库检索.data}}\n问题:\n{{开始.question}}", + "system": "", + "model_id": "", + "dialogue_number": 0, + "is_result": true + } + } + }, + { + "id": "309d0eef-c597-46b5-8d51-b9a28aaef4c7", + "type": "ai-chat-node", + "x": 2160, + "y": 3970, + "properties": { + "config": { + "fields": [ + { + "label": "AI 回答内容", + "value": "answer" + } + ] + }, + "height": 763, + "stepName": "AI 对话1", + "node_data": { + "prompt": "{{开始.question}}", + "system": "", + "model_id": "", + "dialogue_number": 0, + "is_result": true + } + } + } + ], + "edges": [ + { + "id": "7d0f166f-c472-41b2-b9a2-c294f4c83d73", + "type": "app-edge", + "sourceNodeId": "start-node", + "targetNodeId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5", + "startPoint": { + "x": 590, + "y": 3660 + }, + "endPoint": { + "x": 680, + "y": 3210 + }, + "properties": { + + }, + "pointsList": [ + { + "x": 590, + "y": 3660 + }, + { + "x": 700, + "y": 3660 + }, + { + "x": 570, + "y": 3210 + }, + { + "x": 680, + "y": 3210 + } + ], + "sourceAnchorId": "start-node_right", + "targetAnchorId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5_left" + }, + { + "id": "35cb86dd-f328-429e-a973-12fd7218b696", + "type": "app-edge", + "sourceNodeId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5", + "targetNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b", + "startPoint": { + "x": 1000, + "y": 3210 + }, + "endPoint": { + "x": 1200, + "y": 3210 + }, + "properties": { + + }, + "pointsList": [ + { + "x": 1000, + "y": 3210 + }, + { + "x": 1110, + "y": 3210 + }, + { + "x": 1090, + "y": 3210 + }, + { + "x": 1200, + "y": 3210 + } + ], + "sourceAnchorId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5_right", + "targetAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_left" + }, + { + "id": "e8f6cfe6-7e48-41cd-abd3-abfb5304d0d8", + "type": "app-edge", + "sourceNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b", + "targetNodeId": "4ffe1086-25df-4c85-b168-979b5bbf0a26", + "startPoint": { + "x": 1780, + "y": 3073.775 + }, + "endPoint": { + "x": 2010, + "y": 2480 + }, + "properties": { + + }, + "pointsList": [ + { + "x": 1780, + "y": 3073.775 + }, + { + "x": 1890, + "y": 3073.775 + }, + { + "x": 1900, + "y": 2480 + }, + { + "x": 2010, + "y": 2480 + } + ], + "sourceAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_1009_right", + "targetAnchorId": "4ffe1086-25df-4c85-b168-979b5bbf0a26_left" + }, + { + "id": "994ff325-6f7a-4ebc-b61b-10e15519d6d2", + "type": "app-edge", + "sourceNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b", + "targetNodeId": "f1f1ee18-5a02-46f6-b4e6-226253cdffbb", + "startPoint": { + "x": 1780, + "y": 3203 + }, + "endPoint": { + "x": 2000, + "y": 3200 + }, + "properties": { + + }, + "pointsList": [ + { + "x": 1780, + "y": 3203 + }, + { + "x": 1890, + "y": 3203 + }, + { + "x": 1890, + "y": 3200 + }, + { + "x": 2000, + "y": 3200 + } + ], + "sourceAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_4908_right", + "targetAnchorId": "f1f1ee18-5a02-46f6-b4e6-226253cdffbb_left" + }, + { + "id": "19270caf-bb9f-4ba7-9bf8-200aa70fecd5", + "type": "app-edge", + "sourceNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b", + "targetNodeId": "309d0eef-c597-46b5-8d51-b9a28aaef4c7", + "startPoint": { + "x": 1780, + "y": 3293.6124999999997 + }, + "endPoint": { + "x": 2000, + "y": 3970 + }, + "properties": { + + }, + "pointsList": [ + { + "x": 1780, + "y": 3293.6124999999997 + }, + { + "x": 1890, + "y": 3293.6124999999997 + }, + { + "x": 1890, + "y": 3970 + }, + { + "x": 2000, + "y": 3970 + } + ], + "sourceAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_161_right", + "targetAnchorId": 
"309d0eef-c597-46b5-8d51-b9a28aaef4c7_left" + } + ] +} \ No newline at end of file diff --git a/apps/application/flow/default_workflow_en.json b/apps/application/flow/default_workflow_en.json new file mode 100644 index 00000000000..7c0194be676 --- /dev/null +++ b/apps/application/flow/default_workflow_en.json @@ -0,0 +1,451 @@ +{ + "nodes": [ + { + "id": "base-node", + "type": "base-node", + "x": 360, + "y": 2810, + "properties": { + "config": { + + }, + "height": 825.6, + "stepName": "Base", + "node_data": { + "desc": "", + "name": "maxkbapplication", + "prologue": "Hello, I am the MaxKB assistant. You can ask me about MaxKB usage issues.\n-What are the main functions of MaxKB?\n-What major language models does MaxKB support?\n-What document types does MaxKB support?" + }, + "input_field_list": [ + + ] + } + }, + { + "id": "start-node", + "type": "start-node", + "x": 430, + "y": 3660, + "properties": { + "config": { + "fields": [ + { + "label": "用户问题", + "value": "question" + } + ], + "globalFields": [ + { + "label": "当前时间", + "value": "time" + } + ] + }, + "fields": [ + { + "label": "用户问题", + "value": "question" + } + ], + "height": 276, + "stepName": "Start", + "globalFields": [ + { + "label": "当前时间", + "value": "time" + } + ] + } + }, + { + "id": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5", + "type": "search-dataset-node", + "x": 840, + "y": 3210, + "properties": { + "config": { + "fields": [ + { + "label": "检索结果的分段列表", + "value": "paragraph_list" + }, + { + "label": "满足直接回答的分段列表", + "value": "is_hit_handling_method_list" + }, + { + "label": "检索结果", + "value": "data" + }, + { + "label": "满足直接回答的分段内容", + "value": "directly_return" + } + ] + }, + "height": 794, + "stepName": "Knowledge Search", + "node_data": { + "dataset_id_list": [ + + ], + "dataset_setting": { + "top_n": 3, + "similarity": 0.6, + "search_mode": "embedding", + "max_paragraph_char_number": 5000 + }, + "question_reference_address": [ + "start-node", + "question" + ], + "source_dataset_id_list": [ + + ] + } + } + }, + { + "id": "fc60863a-dec2-4854-9e5a-7a44b7187a2b", + "type": "condition-node", + "x": 1490, + "y": 3210, + "properties": { + "width": 600, + "config": { + "fields": [ + { + "label": "分支名称", + "value": "branch_name" + } + ] + }, + "height": 543.675, + "stepName": "Conditional Branch", + "node_data": { + "branch": [ + { + "id": "1009", + "type": "IF", + "condition": "and", + "conditions": [ + { + "field": [ + "b931efe5-5b66-46e0-ae3b-0160cb18eeb5", + "is_hit_handling_method_list" + ], + "value": "1", + "compare": "len_ge" + } + ] + }, + { + "id": "4908", + "type": "ELSE IF 1", + "condition": "and", + "conditions": [ + { + "field": [ + "b931efe5-5b66-46e0-ae3b-0160cb18eeb5", + "paragraph_list" + ], + "value": "1", + "compare": "len_ge" + } + ] + }, + { + "id": "161", + "type": "ELSE", + "condition": "and", + "conditions": [ + + ] + } + ] + }, + "branch_condition_list": [ + { + "index": 0, + "height": 121.225, + "id": "1009" + }, + { + "index": 1, + "height": 121.225, + "id": "4908" + }, + { + "index": 2, + "height": 44, + "id": "161" + } + ] + } + }, + { + "id": "4ffe1086-25df-4c85-b168-979b5bbf0a26", + "type": "reply-node", + "x": 2170, + "y": 2480, + "properties": { + "config": { + "fields": [ + { + "label": "内容", + "value": "answer" + } + ] + }, + "height": 378, + "stepName": "Specified Reply", + "node_data": { + "fields": [ + "b931efe5-5b66-46e0-ae3b-0160cb18eeb5", + "directly_return" + ], + "content": "", + "reply_type": "referencing", + "is_result": true + } + } + }, + { + "id": 
"f1f1ee18-5a02-46f6-b4e6-226253cdffbb", + "type": "ai-chat-node", + "x": 2160, + "y": 3200, + "properties": { + "config": { + "fields": [ + { + "label": "AI 回答内容", + "value": "answer" + } + ] + }, + "height": 763, + "stepName": "AI Chat", + "node_data": { + "prompt": "Known information:\n{{Knowledge Search.data}}\nQuestion:\n{{Start.question}}", + "system": "", + "model_id": "", + "dialogue_number": 0, + "is_result": true + } + } + }, + { + "id": "309d0eef-c597-46b5-8d51-b9a28aaef4c7", + "type": "ai-chat-node", + "x": 2160, + "y": 3970, + "properties": { + "config": { + "fields": [ + { + "label": "AI 回答内容", + "value": "answer" + } + ] + }, + "height": 763, + "stepName": "AI Chat1", + "node_data": { + "prompt": "{{Start.question}}", + "system": "", + "model_id": "", + "dialogue_number": 0, + "is_result": true + } + } + } + ], + "edges": [ + { + "id": "7d0f166f-c472-41b2-b9a2-c294f4c83d73", + "type": "app-edge", + "sourceNodeId": "start-node", + "targetNodeId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5", + "startPoint": { + "x": 590, + "y": 3660 + }, + "endPoint": { + "x": 680, + "y": 3210 + }, + "properties": { + + }, + "pointsList": [ + { + "x": 590, + "y": 3660 + }, + { + "x": 700, + "y": 3660 + }, + { + "x": 570, + "y": 3210 + }, + { + "x": 680, + "y": 3210 + } + ], + "sourceAnchorId": "start-node_right", + "targetAnchorId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5_left" + }, + { + "id": "35cb86dd-f328-429e-a973-12fd7218b696", + "type": "app-edge", + "sourceNodeId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5", + "targetNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b", + "startPoint": { + "x": 1000, + "y": 3210 + }, + "endPoint": { + "x": 1200, + "y": 3210 + }, + "properties": { + + }, + "pointsList": [ + { + "x": 1000, + "y": 3210 + }, + { + "x": 1110, + "y": 3210 + }, + { + "x": 1090, + "y": 3210 + }, + { + "x": 1200, + "y": 3210 + } + ], + "sourceAnchorId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5_right", + "targetAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_left" + }, + { + "id": "e8f6cfe6-7e48-41cd-abd3-abfb5304d0d8", + "type": "app-edge", + "sourceNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b", + "targetNodeId": "4ffe1086-25df-4c85-b168-979b5bbf0a26", + "startPoint": { + "x": 1780, + "y": 3073.775 + }, + "endPoint": { + "x": 2010, + "y": 2480 + }, + "properties": { + + }, + "pointsList": [ + { + "x": 1780, + "y": 3073.775 + }, + { + "x": 1890, + "y": 3073.775 + }, + { + "x": 1900, + "y": 2480 + }, + { + "x": 2010, + "y": 2480 + } + ], + "sourceAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_1009_right", + "targetAnchorId": "4ffe1086-25df-4c85-b168-979b5bbf0a26_left" + }, + { + "id": "994ff325-6f7a-4ebc-b61b-10e15519d6d2", + "type": "app-edge", + "sourceNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b", + "targetNodeId": "f1f1ee18-5a02-46f6-b4e6-226253cdffbb", + "startPoint": { + "x": 1780, + "y": 3203 + }, + "endPoint": { + "x": 2000, + "y": 3200 + }, + "properties": { + + }, + "pointsList": [ + { + "x": 1780, + "y": 3203 + }, + { + "x": 1890, + "y": 3203 + }, + { + "x": 1890, + "y": 3200 + }, + { + "x": 2000, + "y": 3200 + } + ], + "sourceAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_4908_right", + "targetAnchorId": "f1f1ee18-5a02-46f6-b4e6-226253cdffbb_left" + }, + { + "id": "19270caf-bb9f-4ba7-9bf8-200aa70fecd5", + "type": "app-edge", + "sourceNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b", + "targetNodeId": "309d0eef-c597-46b5-8d51-b9a28aaef4c7", + "startPoint": { + "x": 1780, + "y": 3293.6124999999997 + }, + "endPoint": { + "x": 2000, + "y": 3970 + }, + "properties": { + + }, + 
"pointsList": [ + { + "x": 1780, + "y": 3293.6124999999997 + }, + { + "x": 1890, + "y": 3293.6124999999997 + }, + { + "x": 1890, + "y": 3970 + }, + { + "x": 2000, + "y": 3970 + } + ], + "sourceAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_161_right", + "targetAnchorId": "309d0eef-c597-46b5-8d51-b9a28aaef4c7_left" + } + ] +} \ No newline at end of file diff --git a/apps/application/flow/default_workflow_zh.json b/apps/application/flow/default_workflow_zh.json new file mode 100644 index 00000000000..48ac23c4dc6 --- /dev/null +++ b/apps/application/flow/default_workflow_zh.json @@ -0,0 +1,451 @@ +{ + "nodes": [ + { + "id": "base-node", + "type": "base-node", + "x": 360, + "y": 2810, + "properties": { + "config": { + + }, + "height": 825.6, + "stepName": "基本信息", + "node_data": { + "desc": "", + "name": "maxkbapplication", + "prologue": "您好,我是 MaxKB 小助手,您可以向我提出 MaxKB 使用问题。\n- MaxKB 主要功能有什么?\n- MaxKB 支持哪些大语言模型?\n- MaxKB 支持哪些文档类型?" + }, + "input_field_list": [ + + ] + } + }, + { + "id": "start-node", + "type": "start-node", + "x": 430, + "y": 3660, + "properties": { + "config": { + "fields": [ + { + "label": "用户问题", + "value": "question" + } + ], + "globalFields": [ + { + "label": "当前时间", + "value": "time" + } + ] + }, + "fields": [ + { + "label": "用户问题", + "value": "question" + } + ], + "height": 276, + "stepName": "开始", + "globalFields": [ + { + "label": "当前时间", + "value": "time" + } + ] + } + }, + { + "id": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5", + "type": "search-dataset-node", + "x": 840, + "y": 3210, + "properties": { + "config": { + "fields": [ + { + "label": "检索结果的分段列表", + "value": "paragraph_list" + }, + { + "label": "满足直接回答的分段列表", + "value": "is_hit_handling_method_list" + }, + { + "label": "检索结果", + "value": "data" + }, + { + "label": "满足直接回答的分段内容", + "value": "directly_return" + } + ] + }, + "height": 794, + "stepName": "知识库检索", + "node_data": { + "dataset_id_list": [ + + ], + "dataset_setting": { + "top_n": 3, + "similarity": 0.6, + "search_mode": "embedding", + "max_paragraph_char_number": 5000 + }, + "question_reference_address": [ + "start-node", + "question" + ], + "source_dataset_id_list": [ + + ] + } + } + }, + { + "id": "fc60863a-dec2-4854-9e5a-7a44b7187a2b", + "type": "condition-node", + "x": 1490, + "y": 3210, + "properties": { + "width": 600, + "config": { + "fields": [ + { + "label": "分支名称", + "value": "branch_name" + } + ] + }, + "height": 543.675, + "stepName": "判断器", + "node_data": { + "branch": [ + { + "id": "1009", + "type": "IF", + "condition": "and", + "conditions": [ + { + "field": [ + "b931efe5-5b66-46e0-ae3b-0160cb18eeb5", + "is_hit_handling_method_list" + ], + "value": "1", + "compare": "len_ge" + } + ] + }, + { + "id": "4908", + "type": "ELSE IF 1", + "condition": "and", + "conditions": [ + { + "field": [ + "b931efe5-5b66-46e0-ae3b-0160cb18eeb5", + "paragraph_list" + ], + "value": "1", + "compare": "len_ge" + } + ] + }, + { + "id": "161", + "type": "ELSE", + "condition": "and", + "conditions": [ + + ] + } + ] + }, + "branch_condition_list": [ + { + "index": 0, + "height": 121.225, + "id": "1009" + }, + { + "index": 1, + "height": 121.225, + "id": "4908" + }, + { + "index": 2, + "height": 44, + "id": "161" + } + ] + } + }, + { + "id": "4ffe1086-25df-4c85-b168-979b5bbf0a26", + "type": "reply-node", + "x": 2170, + "y": 2480, + "properties": { + "config": { + "fields": [ + { + "label": "内容", + "value": "answer" + } + ] + }, + "height": 378, + "stepName": "指定回复", + "node_data": { + "fields": [ + "b931efe5-5b66-46e0-ae3b-0160cb18eeb5", + "directly_return" + ], + 
"content": "", + "reply_type": "referencing", + "is_result": true + } + } + }, + { + "id": "f1f1ee18-5a02-46f6-b4e6-226253cdffbb", + "type": "ai-chat-node", + "x": 2160, + "y": 3200, + "properties": { + "config": { + "fields": [ + { + "label": "AI 回答内容", + "value": "answer" + } + ] + }, + "height": 763, + "stepName": "AI 对话", + "node_data": { + "prompt": "已知信息:\n{{知识库检索.data}}\n问题:\n{{开始.question}}", + "system": "", + "model_id": "", + "dialogue_number": 0, + "is_result": true + } + } + }, + { + "id": "309d0eef-c597-46b5-8d51-b9a28aaef4c7", + "type": "ai-chat-node", + "x": 2160, + "y": 3970, + "properties": { + "config": { + "fields": [ + { + "label": "AI 回答内容", + "value": "answer" + } + ] + }, + "height": 763, + "stepName": "AI 对话1", + "node_data": { + "prompt": "{{开始.question}}", + "system": "", + "model_id": "", + "dialogue_number": 0, + "is_result": true + } + } + } + ], + "edges": [ + { + "id": "7d0f166f-c472-41b2-b9a2-c294f4c83d73", + "type": "app-edge", + "sourceNodeId": "start-node", + "targetNodeId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5", + "startPoint": { + "x": 590, + "y": 3660 + }, + "endPoint": { + "x": 680, + "y": 3210 + }, + "properties": { + + }, + "pointsList": [ + { + "x": 590, + "y": 3660 + }, + { + "x": 700, + "y": 3660 + }, + { + "x": 570, + "y": 3210 + }, + { + "x": 680, + "y": 3210 + } + ], + "sourceAnchorId": "start-node_right", + "targetAnchorId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5_left" + }, + { + "id": "35cb86dd-f328-429e-a973-12fd7218b696", + "type": "app-edge", + "sourceNodeId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5", + "targetNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b", + "startPoint": { + "x": 1000, + "y": 3210 + }, + "endPoint": { + "x": 1200, + "y": 3210 + }, + "properties": { + + }, + "pointsList": [ + { + "x": 1000, + "y": 3210 + }, + { + "x": 1110, + "y": 3210 + }, + { + "x": 1090, + "y": 3210 + }, + { + "x": 1200, + "y": 3210 + } + ], + "sourceAnchorId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5_right", + "targetAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_left" + }, + { + "id": "e8f6cfe6-7e48-41cd-abd3-abfb5304d0d8", + "type": "app-edge", + "sourceNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b", + "targetNodeId": "4ffe1086-25df-4c85-b168-979b5bbf0a26", + "startPoint": { + "x": 1780, + "y": 3073.775 + }, + "endPoint": { + "x": 2010, + "y": 2480 + }, + "properties": { + + }, + "pointsList": [ + { + "x": 1780, + "y": 3073.775 + }, + { + "x": 1890, + "y": 3073.775 + }, + { + "x": 1900, + "y": 2480 + }, + { + "x": 2010, + "y": 2480 + } + ], + "sourceAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_1009_right", + "targetAnchorId": "4ffe1086-25df-4c85-b168-979b5bbf0a26_left" + }, + { + "id": "994ff325-6f7a-4ebc-b61b-10e15519d6d2", + "type": "app-edge", + "sourceNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b", + "targetNodeId": "f1f1ee18-5a02-46f6-b4e6-226253cdffbb", + "startPoint": { + "x": 1780, + "y": 3203 + }, + "endPoint": { + "x": 2000, + "y": 3200 + }, + "properties": { + + }, + "pointsList": [ + { + "x": 1780, + "y": 3203 + }, + { + "x": 1890, + "y": 3203 + }, + { + "x": 1890, + "y": 3200 + }, + { + "x": 2000, + "y": 3200 + } + ], + "sourceAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_4908_right", + "targetAnchorId": "f1f1ee18-5a02-46f6-b4e6-226253cdffbb_left" + }, + { + "id": "19270caf-bb9f-4ba7-9bf8-200aa70fecd5", + "type": "app-edge", + "sourceNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b", + "targetNodeId": "309d0eef-c597-46b5-8d51-b9a28aaef4c7", + "startPoint": { + "x": 1780, + "y": 3293.6124999999997 + }, + "endPoint": { + "x": 2000, 
+ "y": 3970 + }, + "properties": { + + }, + "pointsList": [ + { + "x": 1780, + "y": 3293.6124999999997 + }, + { + "x": 1890, + "y": 3293.6124999999997 + }, + { + "x": 1890, + "y": 3970 + }, + { + "x": 2000, + "y": 3970 + } + ], + "sourceAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_161_right", + "targetAnchorId": "309d0eef-c597-46b5-8d51-b9a28aaef4c7_left" + } + ] +} \ No newline at end of file diff --git a/apps/application/flow/default_workflow_zh_Hant.json b/apps/application/flow/default_workflow_zh_Hant.json new file mode 100644 index 00000000000..b06301533d2 --- /dev/null +++ b/apps/application/flow/default_workflow_zh_Hant.json @@ -0,0 +1,451 @@ +{ + "nodes": [ + { + "id": "base-node", + "type": "base-node", + "x": 360, + "y": 2810, + "properties": { + "config": { + + }, + "height": 825.6, + "stepName": "基本資訊", + "node_data": { + "desc": "", + "name": "maxkbapplication", + "prologue": "您好,我是MaxKB小助手,您可以向我提出MaxKB使用問題。\n- MaxKB主要功能有什麼?\n- MaxKB支持哪些大語言模型?\n- MaxKB支持哪些文檔類型?" + }, + "input_field_list": [ + + ] + } + }, + { + "id": "start-node", + "type": "start-node", + "x": 430, + "y": 3660, + "properties": { + "config": { + "fields": [ + { + "label": "用户问题", + "value": "question" + } + ], + "globalFields": [ + { + "label": "当前时间", + "value": "time" + } + ] + }, + "fields": [ + { + "label": "用户问题", + "value": "question" + } + ], + "height": 276, + "stepName": "開始", + "globalFields": [ + { + "label": "当前时间", + "value": "time" + } + ] + } + }, + { + "id": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5", + "type": "search-dataset-node", + "x": 840, + "y": 3210, + "properties": { + "config": { + "fields": [ + { + "label": "检索结果的分段列表", + "value": "paragraph_list" + }, + { + "label": "满足直接回答的分段列表", + "value": "is_hit_handling_method_list" + }, + { + "label": "检索结果", + "value": "data" + }, + { + "label": "满足直接回答的分段内容", + "value": "directly_return" + } + ] + }, + "height": 794, + "stepName": "知識庫檢索", + "node_data": { + "dataset_id_list": [ + + ], + "dataset_setting": { + "top_n": 3, + "similarity": 0.6, + "search_mode": "embedding", + "max_paragraph_char_number": 5000 + }, + "question_reference_address": [ + "start-node", + "question" + ], + "source_dataset_id_list": [ + + ] + } + } + }, + { + "id": "fc60863a-dec2-4854-9e5a-7a44b7187a2b", + "type": "condition-node", + "x": 1490, + "y": 3210, + "properties": { + "width": 600, + "config": { + "fields": [ + { + "label": "分支名称", + "value": "branch_name" + } + ] + }, + "height": 543.675, + "stepName": "判斷器", + "node_data": { + "branch": [ + { + "id": "1009", + "type": "IF", + "condition": "and", + "conditions": [ + { + "field": [ + "b931efe5-5b66-46e0-ae3b-0160cb18eeb5", + "is_hit_handling_method_list" + ], + "value": "1", + "compare": "len_ge" + } + ] + }, + { + "id": "4908", + "type": "ELSE IF 1", + "condition": "and", + "conditions": [ + { + "field": [ + "b931efe5-5b66-46e0-ae3b-0160cb18eeb5", + "paragraph_list" + ], + "value": "1", + "compare": "len_ge" + } + ] + }, + { + "id": "161", + "type": "ELSE", + "condition": "and", + "conditions": [ + + ] + } + ] + }, + "branch_condition_list": [ + { + "index": 0, + "height": 121.225, + "id": "1009" + }, + { + "index": 1, + "height": 121.225, + "id": "4908" + }, + { + "index": 2, + "height": 44, + "id": "161" + } + ] + } + }, + { + "id": "4ffe1086-25df-4c85-b168-979b5bbf0a26", + "type": "reply-node", + "x": 2170, + "y": 2480, + "properties": { + "config": { + "fields": [ + { + "label": "内容", + "value": "answer" + } + ] + }, + "height": 378, + "stepName": "指定回覆", + "node_data": { + "fields": [ + 
"b931efe5-5b66-46e0-ae3b-0160cb18eeb5", + "directly_return" + ], + "content": "", + "reply_type": "referencing", + "is_result": true + } + } + }, + { + "id": "f1f1ee18-5a02-46f6-b4e6-226253cdffbb", + "type": "ai-chat-node", + "x": 2160, + "y": 3200, + "properties": { + "config": { + "fields": [ + { + "label": "AI 回答内容", + "value": "answer" + } + ] + }, + "height": 763, + "stepName": "AI 對話", + "node_data": { + "prompt": "已知資訊:\n{{知識庫檢索.data}}\n問題:\n{{開始.question}}", + "system": "", + "model_id": "", + "dialogue_number": 0, + "is_result": true + } + } + }, + { + "id": "309d0eef-c597-46b5-8d51-b9a28aaef4c7", + "type": "ai-chat-node", + "x": 2160, + "y": 3970, + "properties": { + "config": { + "fields": [ + { + "label": "AI 回答内容", + "value": "answer" + } + ] + }, + "height": 763, + "stepName": "AI 對話1", + "node_data": { + "prompt": "{{開始.question}}", + "system": "", + "model_id": "", + "dialogue_number": 0, + "is_result": true + } + } + } + ], + "edges": [ + { + "id": "7d0f166f-c472-41b2-b9a2-c294f4c83d73", + "type": "app-edge", + "sourceNodeId": "start-node", + "targetNodeId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5", + "startPoint": { + "x": 590, + "y": 3660 + }, + "endPoint": { + "x": 680, + "y": 3210 + }, + "properties": { + + }, + "pointsList": [ + { + "x": 590, + "y": 3660 + }, + { + "x": 700, + "y": 3660 + }, + { + "x": 570, + "y": 3210 + }, + { + "x": 680, + "y": 3210 + } + ], + "sourceAnchorId": "start-node_right", + "targetAnchorId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5_left" + }, + { + "id": "35cb86dd-f328-429e-a973-12fd7218b696", + "type": "app-edge", + "sourceNodeId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5", + "targetNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b", + "startPoint": { + "x": 1000, + "y": 3210 + }, + "endPoint": { + "x": 1200, + "y": 3210 + }, + "properties": { + + }, + "pointsList": [ + { + "x": 1000, + "y": 3210 + }, + { + "x": 1110, + "y": 3210 + }, + { + "x": 1090, + "y": 3210 + }, + { + "x": 1200, + "y": 3210 + } + ], + "sourceAnchorId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5_right", + "targetAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_left" + }, + { + "id": "e8f6cfe6-7e48-41cd-abd3-abfb5304d0d8", + "type": "app-edge", + "sourceNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b", + "targetNodeId": "4ffe1086-25df-4c85-b168-979b5bbf0a26", + "startPoint": { + "x": 1780, + "y": 3073.775 + }, + "endPoint": { + "x": 2010, + "y": 2480 + }, + "properties": { + + }, + "pointsList": [ + { + "x": 1780, + "y": 3073.775 + }, + { + "x": 1890, + "y": 3073.775 + }, + { + "x": 1900, + "y": 2480 + }, + { + "x": 2010, + "y": 2480 + } + ], + "sourceAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_1009_right", + "targetAnchorId": "4ffe1086-25df-4c85-b168-979b5bbf0a26_left" + }, + { + "id": "994ff325-6f7a-4ebc-b61b-10e15519d6d2", + "type": "app-edge", + "sourceNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b", + "targetNodeId": "f1f1ee18-5a02-46f6-b4e6-226253cdffbb", + "startPoint": { + "x": 1780, + "y": 3203 + }, + "endPoint": { + "x": 2000, + "y": 3200 + }, + "properties": { + + }, + "pointsList": [ + { + "x": 1780, + "y": 3203 + }, + { + "x": 1890, + "y": 3203 + }, + { + "x": 1890, + "y": 3200 + }, + { + "x": 2000, + "y": 3200 + } + ], + "sourceAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_4908_right", + "targetAnchorId": "f1f1ee18-5a02-46f6-b4e6-226253cdffbb_left" + }, + { + "id": "19270caf-bb9f-4ba7-9bf8-200aa70fecd5", + "type": "app-edge", + "sourceNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b", + "targetNodeId": "309d0eef-c597-46b5-8d51-b9a28aaef4c7", + "startPoint": { + "x": 
1780, + "y": 3293.6124999999997 + }, + "endPoint": { + "x": 2000, + "y": 3970 + }, + "properties": { + + }, + "pointsList": [ + { + "x": 1780, + "y": 3293.6124999999997 + }, + { + "x": 1890, + "y": 3293.6124999999997 + }, + { + "x": 1890, + "y": 3970 + }, + { + "x": 2000, + "y": 3970 + } + ], + "sourceAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_161_right", + "targetAnchorId": "309d0eef-c597-46b5-8d51-b9a28aaef4c7_left" + } + ] +} \ No newline at end of file diff --git a/apps/application/flow/i_step_node.py b/apps/application/flow/i_step_node.py new file mode 100644 index 00000000000..fcead7a40ad --- /dev/null +++ b/apps/application/flow/i_step_node.py @@ -0,0 +1,256 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: i_step_node.py + @date:2024/6/3 14:57 + @desc: +""" +import time +import uuid +from abc import abstractmethod +from hashlib import sha1 +from typing import Type, Dict, List + +from django.core import cache +from django.db.models import QuerySet +from rest_framework import serializers +from rest_framework.exceptions import ValidationError, ErrorDetail + +from application.flow.common import Answer, NodeChunk +from application.models import ChatRecord +from application.models.api_key_model import ApplicationPublicAccessClient +from common.constants.authentication_type import AuthenticationType +from common.field.common import InstanceField +from common.util.field_message import ErrMessage + +chat_cache = cache.caches['chat_cache'] + + +def write_context(step_variable: Dict, global_variable: Dict, node, workflow): + if step_variable is not None: + for key in step_variable: + node.context[key] = step_variable[key] + if workflow.is_result(node, NodeResult(step_variable, global_variable)) and 'answer' in step_variable: + answer = step_variable['answer'] + yield answer + node.answer_text = answer + if global_variable is not None: + for key in global_variable: + workflow.context[key] = global_variable[key] + node.context['run_time'] = time.time() - node.context['start_time'] + + +def is_interrupt(node, step_variable: Dict, global_variable: Dict): + return node.type == 'form-node' and not node.context.get('is_submit', False) + + +class WorkFlowPostHandler: + def __init__(self, chat_info, client_id, client_type): + self.chat_info = chat_info + self.client_id = client_id + self.client_type = client_type + + def handler(self, chat_id, + chat_record_id, + answer, + workflow): + question = workflow.params['question'] + details = workflow.get_runtime_details() + message_tokens = sum([row.get('message_tokens') for row in details.values() if + 'message_tokens' in row and row.get('message_tokens') is not None]) + answer_tokens = sum([row.get('answer_tokens') for row in details.values() if + 'answer_tokens' in row and row.get('answer_tokens') is not None]) + answer_text_list = workflow.get_answer_text_list() + answer_text = '\n\n'.join( + '\n\n'.join([a.get('content') for a in answer]) for answer in + answer_text_list) + if workflow.chat_record is not None: + chat_record = workflow.chat_record + chat_record.answer_text = answer_text + chat_record.details = details + chat_record.message_tokens = message_tokens + chat_record.answer_tokens = answer_tokens + chat_record.answer_text_list = answer_text_list + chat_record.run_time = time.time() - workflow.context['start_time'] + else: + chat_record = ChatRecord(id=chat_record_id, + chat_id=chat_id, + problem_text=question, + answer_text=answer_text, + details=details, + message_tokens=message_tokens, + answer_tokens=answer_tokens, + 
answer_text_list=answer_text_list, + run_time=time.time() - workflow.context['start_time'], + index=0) + asker = workflow.context.get('asker', None) + self.chat_info.append_chat_record(chat_record, self.client_id, asker) + # Reset the cache + chat_cache.set(chat_id, + self.chat_info, timeout=60 * 30) + if self.client_type == AuthenticationType.APPLICATION_ACCESS_TOKEN.value: + application_public_access_client = (QuerySet(ApplicationPublicAccessClient) + .filter(client_id=self.client_id, + application_id=self.chat_info.application.id).first()) + if application_public_access_client is not None: + application_public_access_client.access_num = application_public_access_client.access_num + 1 + application_public_access_client.intraday_access_num = application_public_access_client.intraday_access_num + 1 + application_public_access_client.save() + + +class NodeResult: + def __init__(self, node_variable: Dict, workflow_variable: Dict, + _write_context=write_context, _is_interrupt=is_interrupt): + self._write_context = _write_context + self.node_variable = node_variable + self.workflow_variable = workflow_variable + self._is_interrupt = _is_interrupt + + def write_context(self, node, workflow): + return self._write_context(self.node_variable, self.workflow_variable, node, workflow) + + def is_assertion_result(self): + return 'branch_id' in self.node_variable + + def is_interrupt_exec(self, current_node): + """ + Whether to interrupt execution + @param current_node: + @return: + """ + return self._is_interrupt(current_node, self.node_variable, self.workflow_variable) + + +class ReferenceAddressSerializer(serializers.Serializer): + node_id = serializers.CharField(required=True, error_messages=ErrMessage.char("Node id")) + fields = serializers.ListField( + child=serializers.CharField(required=True, error_messages=ErrMessage.char("Node field")), required=True, + error_messages=ErrMessage.list("Node field list")) + + +class FlowParamsSerializer(serializers.Serializer): + # Historical chat records + history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True), + error_messages=ErrMessage.list("Historical chat records")) + + question = serializers.CharField(required=True, error_messages=ErrMessage.char("User question")) + + chat_id = serializers.CharField(required=True, error_messages=ErrMessage.char("Chat id")) + + chat_record_id = serializers.CharField(required=True, error_messages=ErrMessage.char("Chat record id")) + + stream = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean("Streaming output")) + + client_id = serializers.CharField(required=False, error_messages=ErrMessage.char("Client id")) + + client_type = serializers.CharField(required=False, error_messages=ErrMessage.char("Client type")) + + user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("User id")) + re_chat = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean("Regenerate answer"))
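The `INode` base class below derives a deterministic `runtime_node_id` from the node's own id plus the sorted ids of its upstream nodes, so the same node reached over the same set of incoming edges always hashes to the same value. A standalone sketch of that derivation (the function name is illustrative):

```python
import uuid
from hashlib import sha1

def runtime_node_id(up_node_id_list, node_id):
    # Stable digest: uuid5 of the sorted upstream ids plus the node id,
    # prefixed with the DNS namespace bytes, mirroring INode.__init__ below.
    name = "".join([*sorted(up_node_id_list), node_id])
    return sha1(uuid.NAMESPACE_DNS.bytes
                + bytes(str(uuid.uuid5(uuid.NAMESPACE_DNS, name)), "utf-8")).hexdigest()

# Order of the upstream ids does not matter, only their set:
assert runtime_node_id(['a', 'b'], 'n1') == runtime_node_id(['b', 'a'], 'n1')
assert runtime_node_id(['a'], 'n1') != runtime_node_id(['a'], 'n2')
```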
+ + +class INode: + view_type = 'many_view' + + @abstractmethod + def save_context(self, details, workflow_manage): + pass + + def get_answer_list(self) -> List[Answer] | None: + if self.answer_text is None: + return None + reasoning_content_enable = self.context.get('model_setting', {}).get('reasoning_content_enable', False) + return [ + Answer(self.answer_text, self.view_type, self.runtime_node_id, self.workflow_params['chat_record_id'], {}, + self.runtime_node_id, self.context.get('reasoning_content', '') if reasoning_content_enable else '')] + + def __init__(self, node, workflow_params, workflow_manage, up_node_id_list=None, + get_node_params=lambda node: node.properties.get('node_data')): + # Context of the current step, used to store information about the step + self.status = 200 + self.err_message = '' + self.node = node + self.node_params = get_node_params(node) + self.workflow_params = workflow_params + self.workflow_manage = workflow_manage + self.node_params_serializer = None + self.flow_params_serializer = None + self.context = {} + self.answer_text = None + self.id = node.id + if up_node_id_list is None: + up_node_id_list = [] + self.up_node_id_list = up_node_id_list + self.node_chunk = NodeChunk() + self.runtime_node_id = sha1(uuid.NAMESPACE_DNS.bytes + bytes(str(uuid.uuid5(uuid.NAMESPACE_DNS, + "".join([*sorted(up_node_id_list), + node.id]))), + "utf-8")).hexdigest() + + def valid_args(self, node_params, flow_params): + flow_params_serializer_class = self.get_flow_params_serializer_class() + node_params_serializer_class = self.get_node_params_serializer_class() + if flow_params_serializer_class is not None and flow_params is not None: + self.flow_params_serializer = flow_params_serializer_class(data=flow_params) + self.flow_params_serializer.is_valid(raise_exception=True) + if node_params_serializer_class is not None: + self.node_params_serializer = node_params_serializer_class(data=node_params) + self.node_params_serializer.is_valid(raise_exception=True) + if self.node.properties.get('status', 200) != 200: + raise ValidationError(ErrorDetail(f'Node {self.node.properties.get("stepName")} is not available')) + + def get_reference_field(self, fields: List[str]): + return self.get_field(self.context, fields) + + @staticmethod + def get_field(obj, fields: List[str]): + for field in fields: + value = obj.get(field) + if value is None: + return None + else: + obj = value + return obj + + @abstractmethod + def get_node_params_serializer_class(self) -> Type[serializers.Serializer]: + pass + + def get_flow_params_serializer_class(self) -> Type[serializers.Serializer]: + return FlowParamsSerializer + + def get_write_error_context(self, e): + self.status = 500 + self.answer_text = str(e) + self.err_message = str(e) + self.context['run_time'] = time.time() - self.context['start_time'] + + def write_error_context(answer, status=200): + pass + + return write_error_context + + def run(self) -> NodeResult: + """ + :return: execution result + """ + start_time = time.time() + self.context['start_time'] = start_time + result = self._run() + self.context['run_time'] = time.time() - start_time + return result + + def _run(self): + result = self.execute() + return result + + def execute(self, **kwargs) -> NodeResult: + pass + + def get_details(self, index: int, **kwargs): + """ + Run details + :return: step details + """ + return {} diff --git a/apps/application/flow/step_node/__init__.py b/apps/application/flow/step_node/__init__.py new file mode 100644 index 00000000000..0ce1d5fedd1 --- /dev/null +++ b/apps/application/flow/step_node/__init__.py @@ -0,0 +1,42 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py + @date:2024/6/7 14:43 + @desc: +""" +from .ai_chat_step_node import * +from .application_node import BaseApplicationNode +from .condition_node import * +from .direct_reply_node import * +from .form_node import * +from .function_lib_node import * +from .function_node import * +from .question_node import * +from .reranker_node import * + +from .document_extract_node import * +from .image_understand_step_node import * +from .image_generate_step_node import * + +from .search_dataset_node import * +from .speech_to_text_step_node import BaseSpeechToTextNode +from .start_node import * +from 
.text_to_speech_step_node.impl.base_text_to_speech_node import BaseTextToSpeechNode +from .variable_assign_node import BaseVariableAssignNode +from .mcp_node import BaseMcpNode + +node_list = [BaseStartStepNode, BaseChatNode, BaseSearchDatasetNode, BaseQuestionNode, + BaseConditionNode, BaseReplyNode, + BaseFunctionNodeNode, BaseFunctionLibNodeNode, BaseRerankerNode, BaseApplicationNode, + BaseDocumentExtractNode, + BaseImageUnderstandNode, BaseFormNode, BaseSpeechToTextNode, BaseTextToSpeechNode, + BaseImageGenerateNode, BaseVariableAssignNode, BaseMcpNode] + + +def get_node(node_type): + find_list = [node for node in node_list if node.type == node_type] + if len(find_list) > 0: + return find_list[0] + return None
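`get_node` is a linear scan of `node_list` keyed on each class's `type` attribute; a self-contained sketch of the same registry pattern (the two stub classes are stand-ins, not real node implementations):

```python
class _StartNode:
    type = 'start-node'

class _ChatNode:
    type = 'ai-chat-node'

_node_list = [_StartNode, _ChatNode]

def _get_node(node_type):
    # First registered class whose `type` matches wins; None for unknown types.
    find_list = [node for node in _node_list if node.type == node_type]
    return find_list[0] if len(find_list) > 0 else None

assert _get_node('ai-chat-node') is _ChatNode
assert _get_node('no-such-type') is None
```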
diff --git a/apps/application/flow/step_node/ai_chat_step_node/__init__.py b/apps/application/flow/step_node/ai_chat_step_node/__init__.py new file mode 100644 index 00000000000..1929ae2af49 --- /dev/null +++ b/apps/application/flow/step_node/ai_chat_step_node/__init__.py @@ -0,0 +1,9 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py + @date:2024/6/11 15:29 + @desc: +""" +from .impl import * diff --git a/apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py b/apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py new file mode 100644 index 00000000000..a83d2ef5771 --- /dev/null +++ b/apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py @@ -0,0 +1,58 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: i_chat_node.py + @date:2024/6/4 13:58 + @desc: +""" +from typing import Type + +from django.utils.translation import gettext_lazy as _ +from rest_framework import serializers + +from application.flow.i_step_node import INode, NodeResult +from common.util.field_message import ErrMessage + + +class ChatNodeSerializer(serializers.Serializer): + model_id = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Model id"))) + system = serializers.CharField(required=False, allow_blank=True, allow_null=True, + error_messages=ErrMessage.char(_("Role Setting"))) + prompt = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Prompt word"))) + # Number of multi-turn dialogue rounds to carry as context + dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer( + _("Number of multi-round conversations"))) + + is_result = serializers.BooleanField(required=False, + error_messages=ErrMessage.boolean(_('Whether to return content'))) + + model_params_setting = serializers.DictField(required=False, + error_messages=ErrMessage.dict(_("Model parameter settings"))) + model_setting = serializers.DictField(required=False, + error_messages=ErrMessage.dict(_('Model settings'))) + dialogue_type = serializers.CharField(required=False, allow_blank=True, allow_null=True, + error_messages=ErrMessage.char(_("Context Type"))) + mcp_enable = serializers.BooleanField(required=False, + error_messages=ErrMessage.boolean(_("Whether to enable MCP"))) + mcp_servers = serializers.JSONField(required=False, error_messages=ErrMessage.list(_("MCP Server"))) + + +class IChatNode(INode): + type = 'ai-chat-node' + + def get_node_params_serializer_class(self) -> Type[serializers.Serializer]: + return ChatNodeSerializer + + def _run(self): + return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data) + + def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, + chat_record_id, + model_params_setting=None, + dialogue_type=None, + model_setting=None, + mcp_enable=False, + mcp_servers=None, + **kwargs) -> NodeResult: + pass diff --git a/apps/application/flow/step_node/ai_chat_step_node/impl/__init__.py b/apps/application/flow/step_node/ai_chat_step_node/impl/__init__.py new file mode 100644 index 00000000000..79051a999fb --- /dev/null +++ b/apps/application/flow/step_node/ai_chat_step_node/impl/__init__.py @@ -0,0 +1,9 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py + @date:2024/6/11 15:34 + @desc: +""" +from .base_chat_node import BaseChatNode diff --git a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py new file mode 100644 index 00000000000..8d576d416ad --- /dev/null +++ b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py @@ -0,0 +1,288 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: base_chat_node.py + @date:2024/6/4 14:30 + @desc: +""" +import asyncio +import json +import re +import time +from functools import reduce +from types import AsyncGeneratorType +from typing import List, Dict + +from django.db.models import QuerySet +from langchain.schema import HumanMessage, SystemMessage +from langchain_core.messages import BaseMessage, AIMessage, AIMessageChunk, ToolMessage +from langchain_mcp_adapters.client import MultiServerMCPClient +from langgraph.prebuilt import create_react_agent + +from application.flow.i_step_node import NodeResult, INode +from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode +from application.flow.tools import Reasoning +from setting.models import Model +from setting.models_provider import get_model_credential +from setting.models_provider.tools import get_model_instance_by_model_user_id + +tool_message_template = """ +<details>
+ <summary>Called MCP Tool: %s</summary> + + +```json +%s +``` +
+ +""" + + +def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str, + reasoning_content: str): + chat_model = node_variable.get('chat_model') + message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list')) + answer_tokens = chat_model.get_num_tokens(answer) + node.context['message_tokens'] = message_tokens + node.context['answer_tokens'] = answer_tokens + node.context['answer'] = answer + node.context['history_message'] = node_variable['history_message'] + node.context['question'] = node_variable['question'] + node.context['run_time'] = time.time() - node.context['start_time'] + node.context['reasoning_content'] = reasoning_content + if workflow.is_result(node, NodeResult(node_variable, workflow_variable)): + node.answer_text = answer + + +def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow): + """ + 写入上下文数据 (流式) + @param node_variable: 节点数据 + @param workflow_variable: 全局数据 + @param node: 节点 + @param workflow: 工作流管理器 + """ + response = node_variable.get('result') + answer = '' + reasoning_content = '' + model_setting = node.context.get('model_setting', + {'reasoning_content_enable': False, 'reasoning_content_end': '', + 'reasoning_content_start': ''}) + reasoning = Reasoning(model_setting.get('reasoning_content_start', ''), + model_setting.get('reasoning_content_end', '')) + response_reasoning_content = False + + for chunk in response: + reasoning_chunk = reasoning.get_reasoning_content(chunk) + content_chunk = reasoning_chunk.get('content') + if 'reasoning_content' in chunk.additional_kwargs: + response_reasoning_content = True + reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '') + else: + reasoning_content_chunk = reasoning_chunk.get('reasoning_content') + answer += content_chunk + if reasoning_content_chunk is None: + reasoning_content_chunk = '' + reasoning_content += reasoning_content_chunk + yield {'content': content_chunk, + 'reasoning_content': reasoning_content_chunk if model_setting.get('reasoning_content_enable', + False) else ''} + + reasoning_chunk = reasoning.get_end_reasoning_content() + answer += reasoning_chunk.get('content') + reasoning_content_chunk = "" + if not response_reasoning_content: + reasoning_content_chunk = reasoning_chunk.get( + 'reasoning_content') + yield {'content': reasoning_chunk.get('content'), + 'reasoning_content': reasoning_content_chunk if model_setting.get('reasoning_content_enable', + False) else ''} + _write_context(node_variable, workflow_variable, node, workflow, answer, reasoning_content) + + +async def _yield_mcp_response(chat_model, message_list, mcp_servers): + async with MultiServerMCPClient(json.loads(mcp_servers)) as client: + agent = create_react_agent(chat_model, client.get_tools()) + response = agent.astream({"messages": message_list}, stream_mode='messages') + async for chunk in response: + if isinstance(chunk[0], ToolMessage): + content = tool_message_template % (chunk[0].name, chunk[0].content) + chunk[0].content = content + yield chunk[0] + if isinstance(chunk[0], AIMessageChunk): + yield chunk[0] + + +def mcp_response_generator(chat_model, message_list, mcp_servers): + loop = asyncio.new_event_loop() + try: + async_gen = _yield_mcp_response(chat_model, message_list, mcp_servers) + while True: + try: + chunk = loop.run_until_complete(anext_async(async_gen)) + yield chunk + except StopAsyncIteration: + break + except Exception as e: + print(f'exception: {e}') + finally: + loop.close() + + 
+async def anext_async(agen): + return await agen.__anext__() + + +def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow): + """ + 写入上下文数据 + @param node_variable: 节点数据 + @param workflow_variable: 全局数据 + @param node: 节点实例对象 + @param workflow: 工作流管理器 + """ + response = node_variable.get('result') + model_setting = node.context.get('model_setting', + {'reasoning_content_enable': False, 'reasoning_content_end': '', + 'reasoning_content_start': ''}) + reasoning = Reasoning(model_setting.get('reasoning_content_start'), model_setting.get('reasoning_content_end')) + reasoning_result = reasoning.get_reasoning_content(response) + reasoning_result_end = reasoning.get_end_reasoning_content() + content = reasoning_result.get('content') + reasoning_result_end.get('content') + if 'reasoning_content' in response.response_metadata: + reasoning_content = response.response_metadata.get('reasoning_content', '') + else: + reasoning_content = reasoning_result.get('reasoning_content') + reasoning_result_end.get('reasoning_content') + _write_context(node_variable, workflow_variable, node, workflow, content, reasoning_content) + + +def get_default_model_params_setting(model_id): + model = QuerySet(Model).filter(id=model_id).first() + credential = get_model_credential(model.provider, model.model_type, model.model_name) + model_params_setting = credential.get_model_params_setting_form( + model.model_name).get_default_form_data() + return model_params_setting + + +def get_node_message(chat_record, runtime_node_id): + node_details = chat_record.get_node_details_runtime_node_id(runtime_node_id) + if node_details is None: + return [] + return [HumanMessage(node_details.get('question')), AIMessage(node_details.get('answer'))] + + +def get_workflow_message(chat_record): + return [chat_record.get_human_message(), chat_record.get_ai_message()] + + +def get_message(chat_record, dialogue_type, runtime_node_id): + return get_node_message(chat_record, runtime_node_id) if dialogue_type == 'NODE' else get_workflow_message( + chat_record) + + +class BaseChatNode(IChatNode): + def save_context(self, details, workflow_manage): + self.context['answer'] = details.get('answer') + self.context['question'] = details.get('question') + self.context['reasoning_content'] = details.get('reasoning_content') + if self.node_params.get('is_result', False): + self.answer_text = details.get('answer') + + def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id, + model_params_setting=None, + dialogue_type=None, + model_setting=None, + mcp_enable=False, + mcp_servers=None, + **kwargs) -> NodeResult: + if dialogue_type is None: + dialogue_type = 'WORKFLOW' + + if model_params_setting is None: + model_params_setting = get_default_model_params_setting(model_id) + if model_setting is None: + model_setting = {'reasoning_content_enable': False, 'reasoning_content_end': '', + 'reasoning_content_start': ''} + self.context['model_setting'] = model_setting + chat_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'), + **model_params_setting) + history_message = self.get_history_message(history_chat_record, dialogue_number, dialogue_type, + self.runtime_node_id) + self.context['history_message'] = history_message + question = self.generate_prompt_question(prompt) + self.context['question'] = question.content + system = self.workflow_manage.generate_prompt(system) + self.context['system'] = system + message_list = 
self.generate_message_list(system, prompt, history_message) + self.context['message_list'] = message_list + + if mcp_enable and mcp_servers is not None: + r = mcp_response_generator(chat_model, message_list, mcp_servers) + return NodeResult( + {'result': r, 'chat_model': chat_model, 'message_list': message_list, + 'history_message': history_message, 'question': question.content}, {}, + _write_context=write_context_stream) + + if stream: + r = chat_model.stream(message_list) + return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list, + 'history_message': history_message, 'question': question.content}, {}, + _write_context=write_context_stream) + else: + r = chat_model.invoke(message_list) + return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list, + 'history_message': history_message, 'question': question.content}, {}, + _write_context=write_context) + + @staticmethod + def get_history_message(history_chat_record, dialogue_number, dialogue_type, runtime_node_id): + start_index = len(history_chat_record) - dialogue_number + history_message = reduce(lambda x, y: [*x, *y], [ + get_message(history_chat_record[index], dialogue_type, runtime_node_id) + for index in + range(start_index if start_index > 0 else 0, len(history_chat_record))], []) + for message in history_message: + if isinstance(message.content, str): + message.content = re.sub('<form_rander>[\d\D]*?<\/form_rander>', '', message.content) + return history_message + + def generate_prompt_question(self, prompt): + return HumanMessage(self.workflow_manage.generate_prompt(prompt)) + + def generate_message_list(self, system: str, prompt: str, history_message): + if system is not None and len(system) > 0: + return [SystemMessage(self.workflow_manage.generate_prompt(system)), *history_message, + HumanMessage(self.workflow_manage.generate_prompt(prompt))] + else: + return [*history_message, HumanMessage(self.workflow_manage.generate_prompt(prompt))] + + @staticmethod + def reset_message_list(message_list: List[BaseMessage], answer_text): + result = [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content} for + message + in + message_list] + result.append({'role': 'ai', 'content': answer_text}) + return result + + def get_details(self, index: int, **kwargs): + return { + 'name': self.node.properties.get('stepName'), + "index": index, + 'run_time': self.context.get('run_time'), + 'system': self.context.get('system'), + 'history_message': [{'content': message.content, 'role': message.type} for message in + (self.context.get('history_message') if self.context.get( + 'history_message') is not None else [])], + 'question': self.context.get('question'), + 'answer': self.context.get('answer'), + 'reasoning_content': self.context.get('reasoning_content'), + 'type': self.node.type, + 'message_tokens': self.context.get('message_tokens'), + 'answer_tokens': self.context.get('answer_tokens'), + 'status': self.status, + 'err_message': self.err_message + } diff --git a/apps/application/flow/step_node/application_node/__init__.py b/apps/application/flow/step_node/application_node/__init__.py new file mode 100644 index 00000000000..d1ea91ca7f8 --- /dev/null +++ b/apps/application/flow/step_node/application_node/__init__.py @@ -0,0 +1,2 @@ +# coding=utf-8 +from .impl import * diff --git a/apps/application/flow/step_node/application_node/i_application_node.py b/apps/application/flow/step_node/application_node/i_application_node.py new file mode 100644 index 
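`get_history_message` windows the history to the last `dialogue_number` records, flattens each record into a Human/AI pair, and strips rendered `<form_rander>` blocks so form markup never re-enters the prompt. A reduced sketch of that windowing, using plain dicts in place of the chat-record model instances the real method receives:

```python
import re
from functools import reduce

from langchain_core.messages import AIMessage, HumanMessage

FORM_TAG = re.compile(r'<form_rander>[\d\D]*?</form_rander>')


def window_history(records, dialogue_number):
    # Keep only the last `dialogue_number` records (the WORKFLOW dialogue type),
    # flatten to [Human, AI, Human, AI, ...], then scrub rendered form markup.
    start = max(len(records) - dialogue_number, 0)
    pairs = [[HumanMessage(r['question']), AIMessage(r['answer'])] for r in records[start:]]
    messages = reduce(lambda acc, pair: [*acc, *pair], pairs, [])
    for message in messages:
        if isinstance(message.content, str):
            message.content = FORM_TAG.sub('', message.content)
    return messages


records = [{'question': f'q{i}', 'answer': f'a{i}'} for i in range(5)]
print(len(window_history(records, 2)))  # 4: two records, two messages each
```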
00000000000..6394fa49c7b --- /dev/null +++ b/apps/application/flow/step_node/application_node/i_application_node.py @@ -0,0 +1,86 @@ +# coding=utf-8 +from typing import Type + +from rest_framework import serializers + +from application.flow.i_step_node import INode, NodeResult +from common.util.field_message import ErrMessage + +from django.utils.translation import gettext_lazy as _ + + +class ApplicationNodeSerializer(serializers.Serializer): + application_id = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Application ID"))) + question_reference_address = serializers.ListField(required=True, + error_messages=ErrMessage.list(_("User Questions"))) + api_input_field_list = serializers.ListField(required=False, error_messages=ErrMessage.list(_("API Input Fields"))) + user_input_field_list = serializers.ListField(required=False, + error_messages=ErrMessage.uuid(_("User Input Fields"))) + image_list = serializers.ListField(required=False, error_messages=ErrMessage.list(_("picture"))) + document_list = serializers.ListField(required=False, error_messages=ErrMessage.list(_("document"))) + audio_list = serializers.ListField(required=False, error_messages=ErrMessage.list(_("Audio"))) + child_node = serializers.DictField(required=False, allow_null=True, + error_messages=ErrMessage.dict(_("Child Nodes"))) + node_data = serializers.DictField(required=False, allow_null=True, error_messages=ErrMessage.dict(_("Form Data"))) + + +class IApplicationNode(INode): + type = 'application-node' + + def get_node_params_serializer_class(self) -> Type[serializers.Serializer]: + return ApplicationNodeSerializer + + def _run(self): + question = self.workflow_manage.get_reference_field( + self.node_params_serializer.data.get('question_reference_address')[0], + self.node_params_serializer.data.get('question_reference_address')[1:]) + kwargs = {} + for api_input_field in self.node_params_serializer.data.get('api_input_field_list', []): + value = api_input_field.get('value', [''])[0] if api_input_field.get('value') else '' + kwargs[api_input_field['variable']] = self.workflow_manage.get_reference_field(value, + api_input_field['value'][ + 1:]) if value != '' else '' + + for user_input_field in self.node_params_serializer.data.get('user_input_field_list', []): + value = user_input_field.get('value', [''])[0] if user_input_field.get('value') else '' + kwargs[user_input_field['field']] = self.workflow_manage.get_reference_field(value, + user_input_field['value'][ + 1:]) if value != '' else '' + # 判断是否包含这个属性 + app_document_list = self.node_params_serializer.data.get('document_list', []) + if app_document_list and len(app_document_list) > 0: + app_document_list = self.workflow_manage.get_reference_field( + app_document_list[0], + app_document_list[1:]) + for document in app_document_list: + if 'file_id' not in document: + raise ValueError( + _("Parameter value error: The uploaded document lacks file_id, and the document upload fails")) + app_image_list = self.node_params_serializer.data.get('image_list', []) + if app_image_list and len(app_image_list) > 0: + app_image_list = self.workflow_manage.get_reference_field( + app_image_list[0], + app_image_list[1:]) + for image in app_image_list: + if 'file_id' not in image: + raise ValueError( + _("Parameter value error: The uploaded image lacks file_id, and the image upload fails")) + + app_audio_list = self.node_params_serializer.data.get('audio_list', []) + if app_audio_list and len(app_audio_list) > 0: + app_audio_list = 
self.workflow_manage.get_reference_field( + app_audio_list[0], + app_audio_list[1:]) + for audio in app_audio_list: + if 'file_id' not in audio: + raise ValueError( + _("Parameter value error: The uploaded audio lacks file_id, and the audio upload fails.")) + return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data, + app_document_list=app_document_list, app_image_list=app_image_list, + app_audio_list=app_audio_list, + message=str(question), **kwargs) + + def execute(self, application_id, message, chat_id, chat_record_id, stream, re_chat, client_id, client_type, + app_document_list=None, app_image_list=None, app_audio_list=None, child_node=None, node_data=None, + **kwargs) -> NodeResult: + pass diff --git a/apps/application/flow/step_node/application_node/impl/__init__.py b/apps/application/flow/step_node/application_node/impl/__init__.py new file mode 100644 index 00000000000..e31a8d885cd --- /dev/null +++ b/apps/application/flow/step_node/application_node/impl/__init__.py @@ -0,0 +1,2 @@ +# coding=utf-8 +from .base_application_node import BaseApplicationNode diff --git a/apps/application/flow/step_node/application_node/impl/base_application_node.py b/apps/application/flow/step_node/application_node/impl/base_application_node.py new file mode 100644 index 00000000000..95445f45612 --- /dev/null +++ b/apps/application/flow/step_node/application_node/impl/base_application_node.py @@ -0,0 +1,267 @@ +# coding=utf-8 +import json +import re +import time +import uuid +from typing import Dict, List + +from application.flow.common import Answer +from application.flow.i_step_node import NodeResult, INode +from application.flow.step_node.application_node.i_application_node import IApplicationNode +from application.models import Chat + + +def string_to_uuid(input_str): + return str(uuid.uuid5(uuid.NAMESPACE_DNS, input_str)) + + +def _is_interrupt_exec(node, node_variable: Dict, workflow_variable: Dict): + return node_variable.get('is_interrupt_exec', False) + + +def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str, + reasoning_content: str): + result = node_variable.get('result') + node.context['application_node_dict'] = node_variable.get('application_node_dict') + node.context['node_dict'] = node_variable.get('node_dict', {}) + node.context['is_interrupt_exec'] = node_variable.get('is_interrupt_exec') + node.context['message_tokens'] = result.get('usage', {}).get('prompt_tokens', 0) + node.context['answer_tokens'] = result.get('usage', {}).get('completion_tokens', 0) + node.context['answer'] = answer + node.context['result'] = answer + node.context['reasoning_content'] = reasoning_content + node.context['question'] = node_variable['question'] + node.context['run_time'] = time.time() - node.context['start_time'] + if workflow.is_result(node, NodeResult(node_variable, workflow_variable)): + node.answer_text = answer + + +def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow): + """ + 写入上下文数据 (流式) + @param node_variable: 节点数据 + @param workflow_variable: 全局数据 + @param node: 节点 + @param workflow: 工作流管理器 + """ + response = node_variable.get('result') + answer = '' + reasoning_content = '' + usage = {} + node_child_node = {} + application_node_dict = node.context.get('application_node_dict', {}) + is_interrupt_exec = False + for chunk in response: + # 先把流转成字符串 + response_content = chunk.decode('utf-8')[6:] + response_content = json.loads(response_content) + content = 
response_content.get('content', '') + runtime_node_id = response_content.get('runtime_node_id', '') + chat_record_id = response_content.get('chat_record_id', '') + child_node = response_content.get('child_node') + view_type = response_content.get('view_type') + node_type = response_content.get('node_type') + real_node_id = response_content.get('real_node_id') + node_is_end = response_content.get('node_is_end', False) + _reasoning_content = response_content.get('reasoning_content', '') + if node_type == 'form-node': + is_interrupt_exec = True + answer += content + reasoning_content += _reasoning_content + node_child_node = {'runtime_node_id': runtime_node_id, 'chat_record_id': chat_record_id, + 'child_node': child_node} + + if real_node_id is not None: + application_node = application_node_dict.get(real_node_id, None) + if application_node is None: + + application_node_dict[real_node_id] = {'content': content, + 'runtime_node_id': runtime_node_id, + 'chat_record_id': chat_record_id, + 'child_node': child_node, + 'index': len(application_node_dict), + 'view_type': view_type, + 'reasoning_content': _reasoning_content} + else: + application_node['content'] += content + application_node['reasoning_content'] += _reasoning_content + + yield {'content': content, + 'node_type': node_type, + 'runtime_node_id': runtime_node_id, 'chat_record_id': chat_record_id, + 'reasoning_content': _reasoning_content, + 'child_node': child_node, + 'real_node_id': real_node_id, + 'node_is_end': node_is_end, + 'view_type': view_type} + usage = response_content.get('usage', {}) + node_variable['result'] = {'usage': usage} + node_variable['is_interrupt_exec'] = is_interrupt_exec + node_variable['child_node'] = node_child_node + node_variable['application_node_dict'] = application_node_dict + _write_context(node_variable, workflow_variable, node, workflow, answer, reasoning_content) + + +def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow): + """ + Write context data + @param node_variable: node data + @param workflow_variable: global data + @param node: node instance object + @param workflow: workflow manager + """ + response = node_variable.get('result', {}).get('data', {}) + node_variable['result'] = {'usage': {'completion_tokens': response.get('completion_tokens'), + 'prompt_tokens': response.get('prompt_tokens')}} + answer = response.get('content', '') or "Sorry, no relevant content was found. Please rephrase your question or provide more information." + reasoning_content = response.get('reasoning_content', '') + answer_list = response.get('answer_list', []) + node_variable['application_node_dict'] = {answer.get('real_node_id'): {**answer, 'index': index} for answer, index + in + zip(answer_list, range(len(answer_list)))} + _write_context(node_variable, workflow_variable, node, workflow, answer, reasoning_content) + + +def reset_application_node_dict(application_node_dict, runtime_node_id, node_data): + try: + if application_node_dict is None: + return + for key in application_node_dict: + application_node = application_node_dict[key] + if application_node.get('runtime_node_id') == runtime_node_id: + content: str = application_node.get('content') + match = re.search('<form_rander>.*?</form_rander>', content) + if match: + form_setting_str = match.group().replace('<form_rander>', '').replace('</form_rander>', '') + form_setting = json.loads(form_setting_str) + form_setting['is_submit'] = True + form_setting['form_data'] = node_data + value = f'<form_rander>{json.dumps(form_setting)}</form_rander>' + res = re.sub('<form_rander>.*?</form_rander>', + '${value}', content) + application_node['content'] = res.replace('${value}', value) + except Exception as e: + pass + + +class 
BaseApplicationNode(IApplicationNode): + def get_answer_list(self) -> List[Answer] | None: + if self.answer_text is None: + return None + application_node_dict = self.context.get('application_node_dict') + if application_node_dict is None or len(application_node_dict) == 0: + return [ + Answer(self.answer_text, self.view_type, self.runtime_node_id, self.workflow_params['chat_record_id'], + self.context.get('child_node'), self.runtime_node_id, '')] + else: + return [Answer(n.get('content'), n.get('view_type'), self.runtime_node_id, + self.workflow_params['chat_record_id'], {'runtime_node_id': n.get('runtime_node_id'), + 'chat_record_id': n.get('chat_record_id') + , 'child_node': n.get('child_node')}, n.get('real_node_id'), + n.get('reasoning_content', '')) + for n in + sorted(application_node_dict.values(), key=lambda item: item.get('index'))] + + def save_context(self, details, workflow_manage): + self.context['answer'] = details.get('answer') + self.context['result'] = details.get('answer') + self.context['question'] = details.get('question') + self.context['type'] = details.get('type') + self.context['reasoning_content'] = details.get('reasoning_content') + if self.node_params.get('is_result', False): + self.answer_text = details.get('answer') + + def execute(self, application_id, message, chat_id, chat_record_id, stream, re_chat, client_id, client_type, + app_document_list=None, app_image_list=None, app_audio_list=None, child_node=None, node_data=None, + **kwargs) -> NodeResult: + from application.serializers.chat_message_serializers import ChatMessageSerializer + # 生成嵌入应用的chat_id + current_chat_id = string_to_uuid(chat_id + application_id) + Chat.objects.get_or_create(id=current_chat_id, defaults={ + 'application_id': application_id, + 'abstract': message[0:1024], + 'client_id': client_id, + }) + if app_document_list is None: + app_document_list = [] + if app_image_list is None: + app_image_list = [] + if app_audio_list is None: + app_audio_list = [] + runtime_node_id = None + record_id = None + child_node_value = None + if child_node is not None: + runtime_node_id = child_node.get('runtime_node_id') + record_id = child_node.get('chat_record_id') + child_node_value = child_node.get('child_node') + application_node_dict = self.context.get('application_node_dict') + reset_application_node_dict(application_node_dict, runtime_node_id, node_data) + + response = ChatMessageSerializer( + data={'chat_id': current_chat_id, 'message': message, + 're_chat': re_chat, + 'stream': stream, + 'application_id': application_id, + 'client_id': client_id, + 'client_type': client_type, + 'document_list': app_document_list, + 'image_list': app_image_list, + 'audio_list': app_audio_list, + 'runtime_node_id': runtime_node_id, + 'chat_record_id': record_id, + 'child_node': child_node_value, + 'node_data': node_data, + 'form_data': kwargs}).chat() + if response.status_code == 200: + if stream: + content_generator = response.streaming_content + return NodeResult({'result': content_generator, 'question': message}, {}, + _write_context=write_context_stream, _is_interrupt=_is_interrupt_exec) + else: + data = json.loads(response.content) + return NodeResult({'result': data, 'question': message}, {}, + _write_context=write_context, _is_interrupt=_is_interrupt_exec) + + def get_details(self, index: int, **kwargs): + global_fields = [] + for api_input_field in self.node_params_serializer.data.get('api_input_field_list', []): + value = api_input_field.get('value', [''])[0] if api_input_field.get('value') else '' + 
global_fields.append({ + 'label': api_input_field['variable'], + 'key': api_input_field['variable'], + 'value': self.workflow_manage.get_reference_field( + value, + api_input_field['value'][1:] + ) if value != '' else '' + }) + + for user_input_field in self.node_params_serializer.data.get('user_input_field_list', []): + value = user_input_field.get('value', [''])[0] if user_input_field.get('value') else '' + global_fields.append({ + 'label': user_input_field['label'], + 'key': user_input_field['field'], + 'value': self.workflow_manage.get_reference_field( + value, + user_input_field['value'][1:] + ) if value != '' else '' + }) + return { + 'name': self.node.properties.get('stepName'), + "index": index, + "info": self.node.properties.get('node_data'), + 'run_time': self.context.get('run_time'), + 'question': self.context.get('question'), + 'answer': self.context.get('answer'), + 'reasoning_content': self.context.get('reasoning_content'), + 'type': self.node.type, + 'message_tokens': self.context.get('message_tokens'), + 'answer_tokens': self.context.get('answer_tokens'), + 'status': self.status, + 'err_message': self.err_message, + 'global_fields': global_fields, + 'document_list': self.workflow_manage.document_list, + 'image_list': self.workflow_manage.image_list, + 'audio_list': self.workflow_manage.audio_list, + 'application_node_dict': self.context.get('application_node_dict') + } diff --git a/apps/application/flow/step_node/condition_node/__init__.py b/apps/application/flow/step_node/condition_node/__init__.py new file mode 100644 index 00000000000..57638504c9e --- /dev/null +++ b/apps/application/flow/step_node/condition_node/__init__.py @@ -0,0 +1,9 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py.py + @date:2024/6/7 14:43 + @desc: +""" +from .impl import * diff --git a/apps/application/flow/step_node/condition_node/compare/__init__.py b/apps/application/flow/step_node/condition_node/compare/__init__.py new file mode 100644 index 00000000000..c015f6fea45 --- /dev/null +++ b/apps/application/flow/step_node/condition_node/compare/__init__.py @@ -0,0 +1,30 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py.py + @date:2024/6/7 14:43 + @desc: +""" + +from .contain_compare import * +from .equal_compare import * +from .ge_compare import * +from .gt_compare import * +from .is_not_null_compare import * +from .is_not_true import IsNotTrueCompare +from .is_null_compare import * +from .is_true import IsTrueCompare +from .le_compare import * +from .len_equal_compare import * +from .len_ge_compare import * +from .len_gt_compare import * +from .len_le_compare import * +from .len_lt_compare import * +from .lt_compare import * +from .not_contain_compare import * + +compare_handle_list = [GECompare(), GTCompare(), ContainCompare(), EqualCompare(), LTCompare(), LECompare(), + LenLECompare(), LenGECompare(), LenEqualCompare(), LenGTCompare(), LenLTCompare(), + IsNullCompare(), + IsNotNullCompare(), NotContainCompare(), IsTrueCompare(), IsNotTrueCompare()] diff --git a/apps/application/flow/step_node/condition_node/compare/compare.py b/apps/application/flow/step_node/condition_node/compare/compare.py new file mode 100644 index 00000000000..6cbb4af0732 --- /dev/null +++ b/apps/application/flow/step_node/condition_node/compare/compare.py @@ -0,0 +1,20 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: compare.py + @date:2024/6/7 14:37 + @desc: +""" +from abc import abstractmethod +from typing import List + + +class Compare: + @abstractmethod + 
def support(self, node_id, fields: List[str], source_value, compare, target_value): + pass + + @abstractmethod + def compare(self, source_value, compare, target_value): + pass diff --git a/apps/application/flow/step_node/condition_node/compare/contain_compare.py b/apps/application/flow/step_node/condition_node/compare/contain_compare.py new file mode 100644 index 00000000000..6073131a54d --- /dev/null +++ b/apps/application/flow/step_node/condition_node/compare/contain_compare.py @@ -0,0 +1,23 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: contain_compare.py + @date:2024/6/11 10:02 + @desc: +""" +from typing import List + +from application.flow.step_node.condition_node.compare.compare import Compare + + +class ContainCompare(Compare): + + def support(self, node_id, fields: List[str], source_value, compare, target_value): + if compare == 'contain': + return True + + def compare(self, source_value, compare, target_value): + if isinstance(source_value, str): + return str(target_value) in source_value + return any([str(item) == str(target_value) for item in source_value]) diff --git a/apps/application/flow/step_node/condition_node/compare/equal_compare.py b/apps/application/flow/step_node/condition_node/compare/equal_compare.py new file mode 100644 index 00000000000..0061a82f6e6 --- /dev/null +++ b/apps/application/flow/step_node/condition_node/compare/equal_compare.py @@ -0,0 +1,21 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: equal_compare.py + @date:2024/6/7 14:44 + @desc: +""" +from typing import List + +from application.flow.step_node.condition_node.compare.compare import Compare + + +class EqualCompare(Compare): + + def support(self, node_id, fields: List[str], source_value, compare, target_value): + if compare == 'eq': + return True + + def compare(self, source_value, compare, target_value): + return str(source_value) == str(target_value) diff --git a/apps/application/flow/step_node/condition_node/compare/ge_compare.py b/apps/application/flow/step_node/condition_node/compare/ge_compare.py new file mode 100644 index 00000000000..d4e22cbd696 --- /dev/null +++ b/apps/application/flow/step_node/condition_node/compare/ge_compare.py @@ -0,0 +1,24 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: lt_compare.py + @date:2024/6/11 9:52 + @desc: 大于比较器 +""" +from typing import List + +from application.flow.step_node.condition_node.compare.compare import Compare + + +class GECompare(Compare): + + def support(self, node_id, fields: List[str], source_value, compare, target_value): + if compare == 'ge': + return True + + def compare(self, source_value, compare, target_value): + try: + return float(source_value) >= float(target_value) + except Exception as e: + return False diff --git a/apps/application/flow/step_node/condition_node/compare/gt_compare.py b/apps/application/flow/step_node/condition_node/compare/gt_compare.py new file mode 100644 index 00000000000..80942abb2f2 --- /dev/null +++ b/apps/application/flow/step_node/condition_node/compare/gt_compare.py @@ -0,0 +1,24 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: lt_compare.py + @date:2024/6/11 9:52 + @desc: 大于比较器 +""" +from typing import List + +from application.flow.step_node.condition_node.compare.compare import Compare + + +class GTCompare(Compare): + + def support(self, node_id, fields: List[str], source_value, compare, target_value): + if compare == 'gt': + return True + + def compare(self, source_value, compare, target_value): + try: + return float(source_value) > 
float(target_value) + except Exception as e: + return False diff --git a/apps/application/flow/step_node/condition_node/compare/is_not_null_compare.py b/apps/application/flow/step_node/condition_node/compare/is_not_null_compare.py new file mode 100644 index 00000000000..5dec267135b --- /dev/null +++ b/apps/application/flow/step_node/condition_node/compare/is_not_null_compare.py @@ -0,0 +1,21 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: is_not_null_compare.py + @date:2024/6/28 10:45 + @desc: +""" +from typing import List + +from application.flow.step_node.condition_node.compare import Compare + + +class IsNotNullCompare(Compare): + + def support(self, node_id, fields: List[str], source_value, compare, target_value): + if compare == 'is_not_null': + return True + + def compare(self, source_value, compare, target_value): + return source_value is not None and len(source_value) > 0 diff --git a/apps/application/flow/step_node/condition_node/compare/is_not_true.py b/apps/application/flow/step_node/condition_node/compare/is_not_true.py new file mode 100644 index 00000000000..f8a29f5a126 --- /dev/null +++ b/apps/application/flow/step_node/condition_node/compare/is_not_true.py @@ -0,0 +1,24 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: is_not_true.py + @date:2025/4/7 13:44 + @desc: +""" +from typing import List + +from application.flow.step_node.condition_node.compare import Compare + + +class IsNotTrueCompare(Compare): + + def support(self, node_id, fields: List[str], source_value, compare, target_value): + if compare == 'is_not_true': + return True + + def compare(self, source_value, compare, target_value): + try: + return source_value is False + except Exception as e: + return False diff --git a/apps/application/flow/step_node/condition_node/compare/is_null_compare.py b/apps/application/flow/step_node/condition_node/compare/is_null_compare.py new file mode 100644 index 00000000000..c463f3fda28 --- /dev/null +++ b/apps/application/flow/step_node/condition_node/compare/is_null_compare.py @@ -0,0 +1,21 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: is_null_compare.py + @date:2024/6/28 10:45 + @desc: +""" +from typing import List + +from application.flow.step_node.condition_node.compare import Compare + + +class IsNullCompare(Compare): + + def support(self, node_id, fields: List[str], source_value, compare, target_value): + if compare == 'is_null': + return True + + def compare(self, source_value, compare, target_value): + return source_value is None or len(source_value) == 0 diff --git a/apps/application/flow/step_node/condition_node/compare/is_true.py b/apps/application/flow/step_node/condition_node/compare/is_true.py new file mode 100644 index 00000000000..166e0993ac0 --- /dev/null +++ b/apps/application/flow/step_node/condition_node/compare/is_true.py @@ -0,0 +1,24 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: IsTrue.py + @date:2025/4/7 13:38 + @desc: +""" +from typing import List + +from application.flow.step_node.condition_node.compare import Compare + + +class IsTrueCompare(Compare): + + def support(self, node_id, fields: List[str], source_value, compare, target_value): + if compare == 'is_true': + return True + + def compare(self, source_value, compare, target_value): + try: + return source_value is True + except Exception as e: + return False diff --git a/apps/application/flow/step_node/condition_node/compare/le_compare.py b/apps/application/flow/step_node/condition_node/compare/le_compare.py new file mode 100644 index 
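Every comparator follows the same two-method contract: `support()` claims an operator string and `compare()` evaluates it, returning `False` on conversion errors so a malformed value fails the branch instead of raising. A condensed sketch of the dispatch that `compare_handle_list` enables:

```python
from typing import List


class Compare:
    def support(self, node_id, fields: List[str], source_value, compare, target_value):
        raise NotImplementedError

    def compare(self, source_value, compare, target_value):
        raise NotImplementedError


class GTCompare(Compare):
    def support(self, node_id, fields, source_value, compare, target_value):
        return compare == 'gt'

    def compare(self, source_value, compare, target_value):
        try:
            return float(source_value) > float(target_value)
        except Exception:
            return False  # non-numeric input fails the condition, it does not raise


def evaluate(handlers, node_id, fields, source_value, compare, target_value):
    # First handler whose support() is truthy wins, as in BaseConditionNode.assertion.
    for handler in handlers:
        if handler.support(node_id, fields, source_value, compare, target_value):
            return handler.compare(source_value, compare, target_value)


print(evaluate([GTCompare()], 'n1', [], '3', 'gt', '2'))    # True
print(evaluate([GTCompare()], 'n1', [], 'abc', 'gt', '2'))  # False
```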
00000000000..77a0bca0f5b --- /dev/null +++ b/apps/application/flow/step_node/condition_node/compare/le_compare.py @@ -0,0 +1,24 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: lt_compare.py + @date:2024/6/11 9:52 + @desc: 小于比较器 +""" +from typing import List + +from application.flow.step_node.condition_node.compare.compare import Compare + + +class LECompare(Compare): + + def support(self, node_id, fields: List[str], source_value, compare, target_value): + if compare == 'le': + return True + + def compare(self, source_value, compare, target_value): + try: + return float(source_value) <= float(target_value) + except Exception as e: + return False diff --git a/apps/application/flow/step_node/condition_node/compare/len_equal_compare.py b/apps/application/flow/step_node/condition_node/compare/len_equal_compare.py new file mode 100644 index 00000000000..f2b0764c551 --- /dev/null +++ b/apps/application/flow/step_node/condition_node/compare/len_equal_compare.py @@ -0,0 +1,24 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: equal_compare.py + @date:2024/6/7 14:44 + @desc: +""" +from typing import List + +from application.flow.step_node.condition_node.compare.compare import Compare + + +class LenEqualCompare(Compare): + + def support(self, node_id, fields: List[str], source_value, compare, target_value): + if compare == 'len_eq': + return True + + def compare(self, source_value, compare, target_value): + try: + return len(source_value) == int(target_value) + except Exception as e: + return False diff --git a/apps/application/flow/step_node/condition_node/compare/len_ge_compare.py b/apps/application/flow/step_node/condition_node/compare/len_ge_compare.py new file mode 100644 index 00000000000..87f11eb2cc5 --- /dev/null +++ b/apps/application/flow/step_node/condition_node/compare/len_ge_compare.py @@ -0,0 +1,24 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: lt_compare.py + @date:2024/6/11 9:52 + @desc: 大于比较器 +""" +from typing import List + +from application.flow.step_node.condition_node.compare.compare import Compare + + +class LenGECompare(Compare): + + def support(self, node_id, fields: List[str], source_value, compare, target_value): + if compare == 'len_ge': + return True + + def compare(self, source_value, compare, target_value): + try: + return len(source_value) >= int(target_value) + except Exception as e: + return False diff --git a/apps/application/flow/step_node/condition_node/compare/len_gt_compare.py b/apps/application/flow/step_node/condition_node/compare/len_gt_compare.py new file mode 100644 index 00000000000..0532d353d74 --- /dev/null +++ b/apps/application/flow/step_node/condition_node/compare/len_gt_compare.py @@ -0,0 +1,24 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: lt_compare.py + @date:2024/6/11 9:52 + @desc: 大于比较器 +""" +from typing import List + +from application.flow.step_node.condition_node.compare.compare import Compare + + +class LenGTCompare(Compare): + + def support(self, node_id, fields: List[str], source_value, compare, target_value): + if compare == 'len_gt': + return True + + def compare(self, source_value, compare, target_value): + try: + return len(source_value) > int(target_value) + except Exception as e: + return False diff --git a/apps/application/flow/step_node/condition_node/compare/len_le_compare.py b/apps/application/flow/step_node/condition_node/compare/len_le_compare.py new file mode 100644 index 00000000000..d315a754aa6 --- /dev/null +++ 
b/apps/application/flow/step_node/condition_node/compare/len_le_compare.py @@ -0,0 +1,24 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: lt_compare.py + @date:2024/6/11 9:52 + @desc: 小于比较器 +""" +from typing import List + +from application.flow.step_node.condition_node.compare.compare import Compare + + +class LenLECompare(Compare): + + def support(self, node_id, fields: List[str], source_value, compare, target_value): + if compare == 'len_le': + return True + + def compare(self, source_value, compare, target_value): + try: + return len(source_value) <= int(target_value) + except Exception as e: + return False diff --git a/apps/application/flow/step_node/condition_node/compare/len_lt_compare.py b/apps/application/flow/step_node/condition_node/compare/len_lt_compare.py new file mode 100644 index 00000000000..c89638cd721 --- /dev/null +++ b/apps/application/flow/step_node/condition_node/compare/len_lt_compare.py @@ -0,0 +1,24 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: lt_compare.py + @date:2024/6/11 9:52 + @desc: 小于比较器 +""" +from typing import List + +from application.flow.step_node.condition_node.compare.compare import Compare + + +class LenLTCompare(Compare): + + def support(self, node_id, fields: List[str], source_value, compare, target_value): + if compare == 'len_lt': + return True + + def compare(self, source_value, compare, target_value): + try: + return len(source_value) < int(target_value) + except Exception as e: + return False diff --git a/apps/application/flow/step_node/condition_node/compare/lt_compare.py b/apps/application/flow/step_node/condition_node/compare/lt_compare.py new file mode 100644 index 00000000000..d2d5be74823 --- /dev/null +++ b/apps/application/flow/step_node/condition_node/compare/lt_compare.py @@ -0,0 +1,24 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: lt_compare.py + @date:2024/6/11 9:52 + @desc: 小于比较器 +""" +from typing import List + +from application.flow.step_node.condition_node.compare.compare import Compare + + +class LTCompare(Compare): + + def support(self, node_id, fields: List[str], source_value, compare, target_value): + if compare == 'lt': + return True + + def compare(self, source_value, compare, target_value): + try: + return float(source_value) < float(target_value) + except Exception as e: + return False diff --git a/apps/application/flow/step_node/condition_node/compare/not_contain_compare.py b/apps/application/flow/step_node/condition_node/compare/not_contain_compare.py new file mode 100644 index 00000000000..f95b237ddf6 --- /dev/null +++ b/apps/application/flow/step_node/condition_node/compare/not_contain_compare.py @@ -0,0 +1,23 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: contain_compare.py + @date:2024/6/11 10:02 + @desc: +""" +from typing import List + +from application.flow.step_node.condition_node.compare.compare import Compare + + +class NotContainCompare(Compare): + + def support(self, node_id, fields: List[str], source_value, compare, target_value): + if compare == 'not_contain': + return True + + def compare(self, source_value, compare, target_value): + if isinstance(source_value, str): + return str(target_value) not in source_value + return not any([str(item) == str(target_value) for item in source_value]) diff --git a/apps/application/flow/step_node/condition_node/i_condition_node.py b/apps/application/flow/step_node/condition_node/i_condition_node.py new file mode 100644 index 00000000000..a0e9814ff69 --- /dev/null +++ 
b/apps/application/flow/step_node/condition_node/i_condition_node.py @@ -0,0 +1,39 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: i_condition_node.py + @date:2024/6/7 9:54 + @desc: +""" +from typing import Type + +from django.utils.translation import gettext_lazy as _ +from rest_framework import serializers + +from application.flow.i_step_node import INode +from common.util.field_message import ErrMessage + + +class ConditionSerializer(serializers.Serializer): + compare = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Comparator"))) + value = serializers.CharField(required=True, error_messages=ErrMessage.char(_("value"))) + field = serializers.ListField(required=True, error_messages=ErrMessage.char(_("Fields"))) + + +class ConditionBranchSerializer(serializers.Serializer): + id = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Branch id"))) + type = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Branch Type"))) + condition = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Condition or|and"))) + conditions = ConditionSerializer(many=True) + + +class ConditionNodeParamsSerializer(serializers.Serializer): + branch = ConditionBranchSerializer(many=True) + + +class IConditionNode(INode): + def get_node_params_serializer_class(self) -> Type[serializers.Serializer]: + return ConditionNodeParamsSerializer + + type = 'condition-node' diff --git a/apps/application/flow/step_node/condition_node/impl/__init__.py b/apps/application/flow/step_node/condition_node/impl/__init__.py new file mode 100644 index 00000000000..c21cd3ebb37 --- /dev/null +++ b/apps/application/flow/step_node/condition_node/impl/__init__.py @@ -0,0 +1,9 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py + @date:2024/6/11 15:35 + @desc: +""" +from .base_condition_node import BaseConditionNode diff --git a/apps/application/flow/step_node/condition_node/impl/base_condition_node.py b/apps/application/flow/step_node/condition_node/impl/base_condition_node.py new file mode 100644 index 00000000000..109029be211 --- /dev/null +++ b/apps/application/flow/step_node/condition_node/impl/base_condition_node.py @@ -0,0 +1,62 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: base_condition_node.py + @date:2024/6/7 11:29 + @desc: +""" +from typing import List + +from application.flow.i_step_node import NodeResult +from application.flow.step_node.condition_node.compare import compare_handle_list +from application.flow.step_node.condition_node.i_condition_node import IConditionNode + + +class BaseConditionNode(IConditionNode): + def save_context(self, details, workflow_manage): + self.context['branch_id'] = details.get('branch_id') + self.context['branch_name'] = details.get('branch_name') + + def execute(self, **kwargs) -> NodeResult: + branch_list = self.node_params_serializer.data['branch'] + branch = self._execute(branch_list) + r = NodeResult({'branch_id': branch.get('id'), 'branch_name': branch.get('type')}, {}) + return r + + def _execute(self, branch_list: List): + for branch in branch_list: + if self.branch_assertion(branch): + return branch + + def branch_assertion(self, branch): + condition_list = [self.assertion(row.get('field'), row.get('compare'), row.get('value')) for row in + branch.get('conditions')] + condition = branch.get('condition') + return all(condition_list) if condition == 'and' else any(condition_list) + + def assertion(self, field_list: List[str], compare: str, value): + try: + 
value = self.workflow_manage.generate_prompt(value) + except Exception as e: + pass + field_value = None + try: + field_value = self.workflow_manage.get_reference_field(field_list[0], field_list[1:]) + except Exception as e: + pass + for compare_handler in compare_handle_list: + if compare_handler.support(field_list[0], field_list[1:], field_value, compare, value): + return compare_handler.compare(field_value, compare, value) + + def get_details(self, index: int, **kwargs): + return { + 'name': self.node.properties.get('stepName'), + "index": index, + 'run_time': self.context.get('run_time'), + 'branch_id': self.context.get('branch_id'), + 'branch_name': self.context.get('branch_name'), + 'type': self.node.type, + 'status': self.status, + 'err_message': self.err_message + } diff --git a/apps/application/flow/step_node/direct_reply_node/__init__.py b/apps/application/flow/step_node/direct_reply_node/__init__.py new file mode 100644 index 00000000000..cf360f95685 --- /dev/null +++ b/apps/application/flow/step_node/direct_reply_node/__init__.py @@ -0,0 +1,9 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py + @date:2024/6/11 17:50 + @desc: +""" +from .impl import * \ No newline at end of file diff --git a/apps/application/flow/step_node/direct_reply_node/i_reply_node.py b/apps/application/flow/step_node/direct_reply_node/i_reply_node.py new file mode 100644 index 00000000000..d60541b18fb --- /dev/null +++ b/apps/application/flow/step_node/direct_reply_node/i_reply_node.py @@ -0,0 +1,48 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: i_reply_node.py + @date:2024/6/11 16:25 + @desc: +""" +from typing import Type + +from rest_framework import serializers + +from application.flow.i_step_node import INode, NodeResult +from common.exception.app_exception import AppApiException +from common.util.field_message import ErrMessage +from django.utils.translation import gettext_lazy as _ + + +class ReplyNodeParamsSerializer(serializers.Serializer): + reply_type = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Response Type"))) + fields = serializers.ListField(required=False, error_messages=ErrMessage.list(_("Reference Field"))) + content = serializers.CharField(required=False, allow_blank=True, allow_null=True, + error_messages=ErrMessage.char(_("Direct answer content"))) + is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean(_('Whether to return content'))) + + def is_valid(self, *, raise_exception=False): + super().is_valid(raise_exception=True) + if self.data.get('reply_type') == 'referencing': + if 'fields' not in self.data: + raise AppApiException(500, _("Reference field cannot be empty")) + if len(self.data.get('fields')) < 2: + raise AppApiException(500, _("Reference field error")) + else: + if 'content' not in self.data or self.data.get('content') is None: + raise AppApiException(500, _("Content cannot be empty")) + + +class IReplyNode(INode): + type = 'reply-node' + + def get_node_params_serializer_class(self) -> Type[serializers.Serializer]: + return ReplyNodeParamsSerializer + + def _run(self): + return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data) + + def execute(self, reply_type, stream, fields=None, content=None, **kwargs) -> NodeResult: + pass diff --git a/apps/application/flow/step_node/direct_reply_node/impl/__init__.py b/apps/application/flow/step_node/direct_reply_node/impl/__init__.py new file mode 100644 index 00000000000..3307e90899e --- /dev/null 
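`BaseConditionNode` combines the per-row results with `all()` for an `and` branch and `any()` otherwise, and `_execute` returns the first branch whose assertion passes. A small usage sketch; the `assertion` callable here is a stand-in for the real method, which first resolves the field reference:

```python
def branch_assertion(branch, assertion):
    # all() for 'and' branches, any() for 'or' — as in BaseConditionNode.
    results = [assertion(row.get('field'), row.get('compare'), row.get('value'))
               for row in branch.get('conditions')]
    return all(results) if branch.get('condition') == 'and' else any(results)


def pick_branch(branch_list, assertion):
    for branch in branch_list:
        if branch_assertion(branch, assertion):
            return branch


def is_yes(field, compare, value):
    return compare == 'eq' and value == 'yes'


branches = [
    {'id': 'b1', 'type': 'IF', 'condition': 'and',
     'conditions': [{'field': ['start-node', 'answer'], 'compare': 'eq', 'value': 'yes'}]},
    # A condition-less 'and' branch passes vacuously (all([]) is True),
    # which is one way a trailing ELSE branch can always match.
    {'id': 'b2', 'type': 'ELSE', 'condition': 'and', 'conditions': []},
]
print(pick_branch(branches, is_yes)['id'])  # b1
```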
+++ b/apps/application/flow/step_node/direct_reply_node/impl/__init__.py @@ -0,0 +1,9 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py + @date:2024/6/11 17:49 + @desc: +""" +from .base_reply_node import * \ No newline at end of file diff --git a/apps/application/flow/step_node/direct_reply_node/impl/base_reply_node.py b/apps/application/flow/step_node/direct_reply_node/impl/base_reply_node.py new file mode 100644 index 00000000000..1d3115e4c67 --- /dev/null +++ b/apps/application/flow/step_node/direct_reply_node/impl/base_reply_node.py @@ -0,0 +1,45 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: base_reply_node.py + @date:2024/6/11 17:25 + @desc: +""" +from typing import List + +from application.flow.i_step_node import NodeResult +from application.flow.step_node.direct_reply_node.i_reply_node import IReplyNode + + +class BaseReplyNode(IReplyNode): + def save_context(self, details, workflow_manage): + self.context['answer'] = details.get('answer') + if self.node_params.get('is_result', False): + self.answer_text = details.get('answer') + + def execute(self, reply_type, stream, fields=None, content=None, **kwargs) -> NodeResult: + if reply_type == 'referencing': + result = self.get_reference_content(fields) + else: + result = self.generate_reply_content(content) + return NodeResult({'answer': result}, {}) + + def generate_reply_content(self, prompt): + return self.workflow_manage.generate_prompt(prompt) + + def get_reference_content(self, fields: List[str]): + return str(self.workflow_manage.get_reference_field( + fields[0], + fields[1:])) + + def get_details(self, index: int, **kwargs): + return { + 'name': self.node.properties.get('stepName'), + "index": index, + 'run_time': self.context.get('run_time'), + 'type': self.node.type, + 'answer': self.context.get('answer'), + 'status': self.status, + 'err_message': self.err_message + } diff --git a/apps/application/flow/step_node/document_extract_node/__init__.py b/apps/application/flow/step_node/document_extract_node/__init__.py new file mode 100644 index 00000000000..ce8f10f3e24 --- /dev/null +++ b/apps/application/flow/step_node/document_extract_node/__init__.py @@ -0,0 +1 @@ +from .impl import * \ No newline at end of file diff --git a/apps/application/flow/step_node/document_extract_node/i_document_extract_node.py b/apps/application/flow/step_node/document_extract_node/i_document_extract_node.py new file mode 100644 index 00000000000..93d2b5b987b --- /dev/null +++ b/apps/application/flow/step_node/document_extract_node/i_document_extract_node.py @@ -0,0 +1,28 @@ +# coding=utf-8 + +from typing import Type + +from django.utils.translation import gettext_lazy as _ +from rest_framework import serializers + +from application.flow.i_step_node import INode, NodeResult +from common.util.field_message import ErrMessage + + +class DocumentExtractNodeSerializer(serializers.Serializer): + document_list = serializers.ListField(required=False, error_messages=ErrMessage.list(_("document"))) + + +class IDocumentExtractNode(INode): + type = 'document-extract-node' + + def get_node_params_serializer_class(self) -> Type[serializers.Serializer]: + return DocumentExtractNodeSerializer + + def _run(self): + res = self.workflow_manage.get_reference_field(self.node_params_serializer.data.get('document_list')[0], + self.node_params_serializer.data.get('document_list')[1:]) + return self.execute(document=res, **self.flow_params_serializer.data) + + def execute(self, document, chat_id, **kwargs) -> NodeResult: + pass diff 
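The reply node and the document-extract node both resolve their inputs through `workflow_manage.get_reference_field(fields[0], fields[1:])`: a stored reference is a path whose head is a node id and whose tail indexes into that node's output. The resolver itself is not part of this hunk, so the following is only a toy illustration of the convention, assuming a flat dict of per-node context:

```python
def get_reference_field(node_context, node_id, path):
    # Hypothetical resolver: look up the node's context by id, then walk the
    # remaining keys. The real WorkflowManage implementation also handles the
    # global scope and missing nodes.
    value = node_context[node_id]
    for key in path:
        value = value[key]
    return value


ctx = {'start-node': {'document': [{'file_id': '42', 'name': 'a.txt'}]}}
print(get_reference_field(ctx, 'start-node', ['document']))
# [{'file_id': '42', 'name': 'a.txt'}]
```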
--git a/apps/application/flow/step_node/document_extract_node/impl/__init__.py b/apps/application/flow/step_node/document_extract_node/impl/__init__.py new file mode 100644 index 00000000000..cf9d55ecde8 --- /dev/null +++ b/apps/application/flow/step_node/document_extract_node/impl/__init__.py @@ -0,0 +1 @@ +from .base_document_extract_node import BaseDocumentExtractNode diff --git a/apps/application/flow/step_node/document_extract_node/impl/base_document_extract_node.py b/apps/application/flow/step_node/document_extract_node/impl/base_document_extract_node.py new file mode 100644 index 00000000000..6ddcb6e2fca --- /dev/null +++ b/apps/application/flow/step_node/document_extract_node/impl/base_document_extract_node.py @@ -0,0 +1,94 @@ +# coding=utf-8 +import io +import mimetypes + +from django.core.files.uploadedfile import InMemoryUploadedFile +from django.db.models import QuerySet + +from application.flow.i_step_node import NodeResult +from application.flow.step_node.document_extract_node.i_document_extract_node import IDocumentExtractNode +from dataset.models import File +from dataset.serializers.document_serializers import split_handles, parse_table_handle_list, FileBufferHandle +from dataset.serializers.file_serializers import FileSerializer + + +def bytes_to_uploaded_file(file_bytes, file_name="file.txt"): + content_type, _ = mimetypes.guess_type(file_name) + if content_type is None: + # 如果未能识别,设置为默认的二进制文件类型 + content_type = "application/octet-stream" + # 创建一个内存中的字节流对象 + file_stream = io.BytesIO(file_bytes) + + # 获取文件大小 + file_size = len(file_bytes) + + # 创建 InMemoryUploadedFile 对象 + uploaded_file = InMemoryUploadedFile( + file=file_stream, + field_name=None, + name=file_name, + content_type=content_type, + size=file_size, + charset=None, + ) + return uploaded_file + + +splitter = '\n`-----------------------------------`\n' + +class BaseDocumentExtractNode(IDocumentExtractNode): + def save_context(self, details, workflow_manage): + self.context['content'] = details.get('content') + + + def execute(self, document, chat_id, **kwargs): + get_buffer = FileBufferHandle().get_buffer + + self.context['document_list'] = document + content = [] + if document is None or not isinstance(document, list): + return NodeResult({'content': ''}, {}) + + application = self.workflow_manage.work_flow_post_handler.chat_info.application + + # doc文件中的图片保存 + def save_image(image_list): + for image in image_list: + meta = { + 'debug': False if application.id else True, + 'chat_id': chat_id, + 'application_id': str(application.id) if application.id else None, + 'file_id': str(image.id) + } + file = bytes_to_uploaded_file(image.image, image.image_name) + FileSerializer(data={'file': file, 'meta': meta}).upload() + + for doc in document: + file = QuerySet(File).filter(id=doc['file_id']).first() + buffer = io.BytesIO(file.get_byte().tobytes()) + buffer.name = doc['name'] # this is the important line + + for split_handle in (parse_table_handle_list + split_handles): + if split_handle.support(buffer, get_buffer): + # 回到文件头 + buffer.seek(0) + file_content = split_handle.get_content(buffer, save_image) + content.append('### ' + doc['name'] + '\n' + file_content) + break + + return NodeResult({'content': splitter.join(content)}, {}) + + def get_details(self, index: int, **kwargs): + content = self.context.get('content', '').split(splitter) + # 不保存content全部内容,因为content内容可能会很大 + return { + 'name': self.node.properties.get('stepName'), + "index": index, + 'run_time': self.context.get('run_time'), + 'type': 
self.node.type, + 'content': [file_content[:500] for file_content in content], + 'status': self.status, + 'err_message': self.err_message, + 'document_list': self.context.get('document_list') + } diff --git a/apps/application/flow/step_node/form_node/__init__.py b/apps/application/flow/step_node/form_node/__init__.py new file mode 100644 index 00000000000..ce04b64aea8 --- /dev/null +++ b/apps/application/flow/step_node/form_node/__init__.py @@ -0,0 +1,9 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: __init__.py.py + @date:2024/11/4 14:48 + @desc: +""" +from .impl import * \ No newline at end of file diff --git a/apps/application/flow/step_node/form_node/i_form_node.py b/apps/application/flow/step_node/form_node/i_form_node.py new file mode 100644 index 00000000000..7e82494293d --- /dev/null +++ b/apps/application/flow/step_node/form_node/i_form_node.py @@ -0,0 +1,35 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: i_form_node.py + @date:2024/11/4 14:48 + @desc: +""" +from typing import Type + +from rest_framework import serializers + +from application.flow.i_step_node import INode, NodeResult +from common.util.field_message import ErrMessage +from django.utils.translation import gettext_lazy as _ + + +class FormNodeParamsSerializer(serializers.Serializer): + form_field_list = serializers.ListField(required=True, error_messages=ErrMessage.list(_("Form Configuration"))) + form_content_format = serializers.CharField(required=True, error_messages=ErrMessage.char(_('Form output content'))) + form_data = serializers.DictField(required=False, allow_null=True, error_messages=ErrMessage.dict(_("Form Data"))) + + +class IFormNode(INode): + type = 'form-node' + view_type = 'single_view' + + def get_node_params_serializer_class(self) -> Type[serializers.Serializer]: + return FormNodeParamsSerializer + + def _run(self): + return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data) + + def execute(self, form_field_list, form_content_format, form_data, **kwargs) -> NodeResult: + pass diff --git a/apps/application/flow/step_node/form_node/impl/__init__.py b/apps/application/flow/step_node/form_node/impl/__init__.py new file mode 100644 index 00000000000..4cea85e1d9e --- /dev/null +++ b/apps/application/flow/step_node/form_node/impl/__init__.py @@ -0,0 +1,9 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: __init__.py.py + @date:2024/11/4 14:49 + @desc: +""" +from .base_form_node import BaseFormNode diff --git a/apps/application/flow/step_node/form_node/impl/base_form_node.py b/apps/application/flow/step_node/form_node/impl/base_form_node.py new file mode 100644 index 00000000000..dcf35dd3cfd --- /dev/null +++ b/apps/application/flow/step_node/form_node/impl/base_form_node.py @@ -0,0 +1,107 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: base_form_node.py + @date:2024/11/4 14:52 + @desc: +""" +import json +import time +from typing import Dict, List + +from langchain_core.prompts import PromptTemplate + +from application.flow.common import Answer +from application.flow.i_step_node import NodeResult +from application.flow.step_node.form_node.i_form_node import IFormNode + + +def write_context(step_variable: Dict, global_variable: Dict, node, workflow): + if step_variable is not None: + for key in step_variable: + node.context[key] = step_variable[key] + if workflow.is_result(node, NodeResult(step_variable, global_variable)) and 'result' in step_variable: + result = step_variable['result'] + yield result + 
node.answer_text = result + node.context['run_time'] = time.time() - node.context['start_time'] + + +class BaseFormNode(IFormNode): + def save_context(self, details, workflow_manage): + form_data = details.get('form_data', None) + self.context['result'] = details.get('result') + self.context['form_content_format'] = details.get('form_content_format') + self.context['form_field_list'] = details.get('form_field_list') + self.context['run_time'] = details.get('run_time') + self.context['start_time'] = details.get('start_time') + self.context['form_data'] = form_data + self.context['is_submit'] = details.get('is_submit') + if self.node_params.get('is_result', False): + self.answer_text = details.get('result') + if form_data is not None: + for key in form_data: + self.context[key] = form_data[key] + + def execute(self, form_field_list, form_content_format, form_data, **kwargs) -> NodeResult: + if form_data is not None: + self.context['is_submit'] = True + self.context['form_data'] = form_data + for key in form_data: + self.context[key] = form_data.get(key) + else: + self.context['is_submit'] = False + form_setting = {"form_field_list": form_field_list, "runtime_node_id": self.runtime_node_id, + "chat_record_id": self.flow_params_serializer.data.get("chat_record_id"), + "is_submit": self.context.get("is_submit", False)} + form = f'<form_rander>{json.dumps(form_setting, ensure_ascii=False)}</form_rander>' + context = self.workflow_manage.get_workflow_content() + form_content_format = self.workflow_manage.reset_prompt(form_content_format) + prompt_template = PromptTemplate.from_template(form_content_format, template_format='jinja2') + value = prompt_template.format(form=form, context=context) + return NodeResult( + {'result': value, 'form_field_list': form_field_list, 'form_content_format': form_content_format}, {}, + _write_context=write_context) + + def get_answer_list(self) -> List[Answer] | None: + form_content_format = self.context.get('form_content_format') + form_field_list = self.context.get('form_field_list') + form_setting = {"form_field_list": form_field_list, "runtime_node_id": self.runtime_node_id, + "chat_record_id": self.flow_params_serializer.data.get("chat_record_id"), + 'form_data': self.context.get('form_data', {}), + "is_submit": self.context.get("is_submit", False)} + form = f'<form_rander>{json.dumps(form_setting, ensure_ascii=False)}</form_rander>' + context = self.workflow_manage.get_workflow_content() + form_content_format = self.workflow_manage.reset_prompt(form_content_format) + prompt_template = PromptTemplate.from_template(form_content_format, template_format='jinja2') + value = prompt_template.format(form=form, context=context) + return [Answer(value, self.view_type, self.runtime_node_id, self.workflow_params['chat_record_id'], None, + self.runtime_node_id, '')] + + def get_details(self, index: int, **kwargs): + form_content_format = self.context.get('form_content_format') + form_field_list = self.context.get('form_field_list') + form_setting = {"form_field_list": form_field_list, "runtime_node_id": self.runtime_node_id, + "chat_record_id": self.flow_params_serializer.data.get("chat_record_id"), + 'form_data': self.context.get('form_data', {}), + "is_submit": self.context.get("is_submit", False)} + form = f'<form_rander>{json.dumps(form_setting, ensure_ascii=False)}</form_rander>' + context = self.workflow_manage.get_workflow_content() + form_content_format = self.workflow_manage.reset_prompt(form_content_format) + prompt_template = PromptTemplate.from_template(form_content_format, template_format='jinja2') + value = 
prompt_template.format(form=form, context=context) + return { + 'name': self.node.properties.get('stepName'), + "index": index, + "result": value, + "form_content_format": self.context.get('form_content_format'), + "form_field_list": self.context.get('form_field_list'), + 'form_data': self.context.get('form_data'), + 'start_time': self.context.get('start_time'), + 'is_submit': self.context.get('is_submit'), + 'run_time': self.context.get('run_time'), + 'type': self.node.type, + 'status': self.status, + 'err_message': self.err_message + } diff --git a/apps/application/flow/step_node/function_lib_node/__init__.py b/apps/application/flow/step_node/function_lib_node/__init__.py new file mode 100644 index 00000000000..7422965c365 --- /dev/null +++ b/apps/application/flow/step_node/function_lib_node/__init__.py @@ -0,0 +1,9 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: __init__.py + @date:2024/8/8 17:45 + @desc: +""" +from .impl import * \ No newline at end of file diff --git a/apps/application/flow/step_node/function_lib_node/i_function_lib_node.py b/apps/application/flow/step_node/function_lib_node/i_function_lib_node.py new file mode 100644 index 00000000000..c84782ff6a9 --- /dev/null +++ b/apps/application/flow/step_node/function_lib_node/i_function_lib_node.py @@ -0,0 +1,48 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: i_function_lib_node.py + @date:2024/8/8 16:21 + @desc: +""" +from typing import Type + +from django.db.models import QuerySet +from rest_framework import serializers + +from application.flow.i_step_node import INode, NodeResult +from common.field.common import ObjectField +from common.util.field_message import ErrMessage +from function_lib.models.function import FunctionLib +from django.utils.translation import gettext_lazy as _ + + +class InputField(serializers.Serializer): + name = serializers.CharField(required=True, error_messages=ErrMessage.char(_('Variable Name'))) + value = ObjectField(required=True, error_messages=ErrMessage.char(_("Variable Value")), model_type_list=[str, list]) + + +class FunctionLibNodeParamsSerializer(serializers.Serializer): + function_lib_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid(_('Library ID'))) + input_field_list = InputField(required=True, many=True) + is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean(_('Whether to return content'))) + + def is_valid(self, *, raise_exception=False): + super().is_valid(raise_exception=True) + f_lib = QuerySet(FunctionLib).filter(id=self.data.get('function_lib_id')).first() + if f_lib is None: + raise Exception(_('The function has been deleted')) + + +class IFunctionLibNode(INode): + type = 'function-lib-node' + + def get_node_params_serializer_class(self) -> Type[serializers.Serializer]: + return FunctionLibNodeParamsSerializer + + def _run(self): + return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data) + + def execute(self, function_lib_id, input_field_list, **kwargs) -> NodeResult: + pass diff --git a/apps/application/flow/step_node/function_lib_node/impl/__init__.py b/apps/application/flow/step_node/function_lib_node/impl/__init__.py new file mode 100644 index 00000000000..96681474f19 --- /dev/null +++ b/apps/application/flow/step_node/function_lib_node/impl/__init__.py @@ -0,0 +1,9 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: __init__.py + @date:2024/8/8 17:48 + @desc: +""" +from .base_function_lib_node import BaseFunctionLibNodeNode diff --git 
a/apps/application/flow/step_node/function_lib_node/impl/base_function_lib_node.py b/apps/application/flow/step_node/function_lib_node/impl/base_function_lib_node.py new file mode 100644 index 00000000000..341bb91da63 --- /dev/null +++ b/apps/application/flow/step_node/function_lib_node/impl/base_function_lib_node.py @@ -0,0 +1,150 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: base_function_lib_node.py + @date:2024/8/8 17:49 + @desc: +""" +import json +import time +from typing import Dict + +from django.db.models import QuerySet +from django.utils.translation import gettext as _ + +from application.flow.i_step_node import NodeResult +from application.flow.step_node.function_lib_node.i_function_lib_node import IFunctionLibNode +from common.exception.app_exception import AppApiException +from common.util.function_code import FunctionExecutor +from common.util.rsa_util import rsa_long_decrypt +from function_lib.models.function import FunctionLib +from smartdoc.const import CONFIG + +function_executor = FunctionExecutor(CONFIG.get('SANDBOX')) + + +def write_context(step_variable: Dict, global_variable: Dict, node, workflow): + if step_variable is not None: + for key in step_variable: + node.context[key] = step_variable[key] + if workflow.is_result(node, NodeResult(step_variable, global_variable)) and 'result' in step_variable: + result = str(step_variable['result']) + '\n' + yield result + node.answer_text = result + node.context['run_time'] = time.time() - node.context['start_time'] + + +def get_field_value(debug_field_list, name, is_required): + result = [field for field in debug_field_list if field.get('name') == name] + if len(result) > 0: + return result[-1]['value'] + if is_required: + raise AppApiException(500, _('Field: {name} No value set').format(name=name)) + return None + + +def valid_reference_value(_type, value, name): + if _type == 'int': + instance_type = int | float + elif _type == 'float': + instance_type = float | int + elif _type == 'dict': + instance_type = dict + elif _type == 'array': + instance_type = list + elif _type == 'string': + instance_type = str + else: + raise Exception(_('Field: {name} Type: {_type} Value: {value} Unsupported types').format(name=name, + _type=_type)) + if not isinstance(value, instance_type): + raise Exception( + _('Field: {name} Type: {_type} Value: {value} Type error').format(name=name, _type=_type, + value=value)) + + +def convert_value(name: str, value, _type, is_required, source, node): + if not is_required and (value is None or (isinstance(value, str) and len(value) == 0)): + return None + if not is_required and source == 'reference' and (value is None or len(value) == 0): + return None + if source == 'reference': + value = node.workflow_manage.get_reference_field( + value[0], + value[1:]) + valid_reference_value(_type, value, name) + if _type == 'int': + return int(value) + if _type == 'float': + return float(value) + return value + try: + if _type == 'int': + return int(value) + if _type == 'float': + return float(value) + if _type == 'dict': + v = json.loads(value) + if isinstance(v, dict): + return v + raise Exception(_('type error')) + if _type == 'array': + v = json.loads(value) + if isinstance(v, list): + return v + raise Exception(_('type error')) + return value + except Exception as e: + raise Exception( + _('Field: {name} Type: {_type} Value: {value} Type error').format(name=name, _type=_type, + value=value)) + + +def valid_function(function_lib, user_id): + if function_lib is None: + raise 
Exception(_('Function does not exist')) + if function_lib.permission_type == 'PRIVATE' and str(function_lib.user_id) != str(user_id): + raise Exception(_('No permission to use this function {name}').format(name=function_lib.name)) + if not function_lib.is_active: + raise Exception(_('Function {name} is unavailable').format(name=function_lib.name)) + + +class BaseFunctionLibNodeNode(IFunctionLibNode): + def save_context(self, details, workflow_manage): + self.context['result'] = details.get('result') + if self.node_params.get('is_result'): + self.answer_text = str(details.get('result')) + + def execute(self, function_lib_id, input_field_list, **kwargs) -> NodeResult: + function_lib = QuerySet(FunctionLib).filter(id=function_lib_id).first() + valid_function(function_lib, self.flow_params_serializer.data.get('user_id')) + params = {field.get('name'): convert_value(field.get('name'), field.get('value'), field.get('type'), + field.get('is_required'), + field.get('source'), self) + for field in + [{'value': get_field_value(input_field_list, field.get('name'), field.get('is_required'), + ), **field} + for field in + function_lib.input_field_list]} + + self.context['params'] = params + # Merge initialization parameters + if function_lib.init_params is not None: + all_params = json.loads(rsa_long_decrypt(function_lib.init_params)) | params + else: + all_params = params + result = function_executor.exec_code(function_lib.code, all_params) + return NodeResult({'result': result}, {}, _write_context=write_context) + + def get_details(self, index: int, **kwargs): + return { + 'name': self.node.properties.get('stepName'), + "index": index, + "result": self.context.get('result'), + "params": self.context.get('params'), + 'run_time': self.context.get('run_time'), + 'type': self.node.type, + 'status': self.status, + 'err_message': self.err_message + } diff --git a/apps/application/flow/step_node/function_node/__init__.py b/apps/application/flow/step_node/function_node/__init__.py new file mode 100644 index 00000000000..ebfbe8d8bb4 --- /dev/null +++ b/apps/application/flow/step_node/function_node/__init__.py @@ -0,0 +1,9 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: __init__.py + @date:2024/8/13 10:43 + @desc: +""" +from .impl import * \ No newline at end of file diff --git a/apps/application/flow/step_node/function_node/i_function_node.py b/apps/application/flow/step_node/function_node/i_function_node.py new file mode 100644 index 00000000000..bbaae6c73fe --- /dev/null +++ b/apps/application/flow/step_node/function_node/i_function_node.py @@ -0,0 +1,63 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: i_function_node.py + @date:2024/8/8 16:21 + @desc: +""" +import re +from typing import Type + +from django.core import validators +from rest_framework import serializers + +from application.flow.i_step_node import INode, NodeResult +from common.exception.app_exception import AppApiException +from common.field.common import ObjectField +from common.util.field_message import ErrMessage +from django.utils.translation import gettext_lazy as _ +from rest_framework.utils.formatting import lazy_format + + +class InputField(serializers.Serializer): + name = serializers.CharField(required=True, error_messages=ErrMessage.char(_('Variable Name'))) + is_required = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean(_("Is this field required"))) + type = serializers.CharField(required=True, error_messages=ErrMessage.char(_("type")), validators=[ 
validators.RegexValidator(regex=re.compile("^(string|int|dict|array|float)$"), + message=_("The field only supports string|int|dict|array|float"), code=500) + ]) + source = serializers.CharField(required=True, error_messages=ErrMessage.char(_("source")), validators=[ + validators.RegexValidator(regex=re.compile("^(custom|reference)$"), + message=_("The field only supports custom|reference"), code=500) + ]) + value = ObjectField(required=True, error_messages=ErrMessage.char(_("Variable Value")), model_type_list=[str, list]) + + def is_valid(self, *, raise_exception=False): + super().is_valid(raise_exception=True) + is_required = self.data.get('is_required') + if is_required and self.data.get('value') is None: + message = lazy_format(_('{field}, this field is required.'), field=self.data.get("name")) + raise AppApiException(500, message) + + +class FunctionNodeParamsSerializer(serializers.Serializer): + input_field_list = InputField(required=True, many=True) + code = serializers.CharField(required=True, error_messages=ErrMessage.char(_("function"))) + is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean(_('Whether to return content'))) + + def is_valid(self, *, raise_exception=False): + super().is_valid(raise_exception=True) + + +class IFunctionNode(INode): + type = 'function-node' + + def get_node_params_serializer_class(self) -> Type[serializers.Serializer]: + return FunctionNodeParamsSerializer + + def _run(self): + return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data) + + def execute(self, input_field_list, code, **kwargs) -> NodeResult: + pass diff --git a/apps/application/flow/step_node/function_node/impl/__init__.py b/apps/application/flow/step_node/function_node/impl/__init__.py new file mode 100644 index 00000000000..1a096368f84 --- /dev/null +++ b/apps/application/flow/step_node/function_node/impl/__init__.py @@ -0,0 +1,9 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: __init__.py + @date:2024/8/13 11:19 + @desc: +""" +from .base_function_node import BaseFunctionNodeNode diff --git a/apps/application/flow/step_node/function_node/impl/base_function_node.py b/apps/application/flow/step_node/function_node/impl/base_function_node.py new file mode 100644 index 00000000000..d659227f1ee --- /dev/null +++ b/apps/application/flow/step_node/function_node/impl/base_function_node.py @@ -0,0 +1,108 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: base_function_node.py + @date:2024/8/8 17:49 + @desc: +""" +import json +import time + +from typing import Dict + +from application.flow.i_step_node import NodeResult +from application.flow.step_node.function_node.i_function_node import IFunctionNode +from common.exception.app_exception import AppApiException +from common.util.function_code import FunctionExecutor +from smartdoc.const import CONFIG + +function_executor = FunctionExecutor(CONFIG.get('SANDBOX')) + + +def write_context(step_variable: Dict, global_variable: Dict, node, workflow): + if step_variable is not None: + for key in step_variable: + node.context[key] = step_variable[key] + if workflow.is_result(node, NodeResult(step_variable, global_variable)) and 'result' in step_variable: + result = str(step_variable['result']) + '\n' + yield result + node.answer_text = result + node.context['run_time'] = time.time() - node.context['start_time'] + + +def valid_reference_value(_type, value, name): + if _type == 'int': + instance_type = int | float + elif _type == 'float': + instance_type = float 
| int + elif _type == 'dict': + instance_type = dict + elif _type == 'array': + instance_type = list + elif _type == 'string': + instance_type = str + else: + raise Exception(f'Field: {name} Type: {_type} unsupported type') + if not isinstance(value, instance_type): + raise Exception(f'Field: {name} Type: {_type} Value: {value} type error') + + +def convert_value(name: str, value, _type, is_required, source, node): + if not is_required and (value is None or (isinstance(value, str) and len(value) == 0)): + return None + if source == 'reference': + value = node.workflow_manage.get_reference_field( + value[0], + value[1:]) + valid_reference_value(_type, value, name) + if _type == 'int': + return int(value) + if _type == 'float': + return float(value) + return value + try: + if _type == 'int': + return int(value) + if _type == 'float': + return float(value) + if _type == 'dict': + v = json.loads(value) + if isinstance(v, dict): + return v + raise Exception('type error') + if _type == 'array': + v = json.loads(value) + if isinstance(v, list): + return v + raise Exception('type error') + return value + except Exception as e: + raise Exception(f'Field: {name} Type: {_type} Value: {value} type error') + + +class BaseFunctionNodeNode(IFunctionNode): + def save_context(self, details, workflow_manage): + self.context['result'] = details.get('result') + if self.node_params.get('is_result', False): + self.answer_text = str(details.get('result')) + + def execute(self, input_field_list, code, **kwargs) -> NodeResult: + params = {field.get('name'): convert_value(field.get('name'), field.get('value'), field.get('type'), + field.get('is_required'), field.get('source'), self) + for field in input_field_list} + result = function_executor.exec_code(code, params) + self.context['params'] = params + return NodeResult({'result': result}, {}, _write_context=write_context) + + def get_details(self, index: int, **kwargs): + return { + 'name': self.node.properties.get('stepName'), + "index": index, + "result": self.context.get('result'), + "params": self.context.get('params'), + 'run_time': self.context.get('run_time'), + 'type': self.node.type, + 'status': self.status, + 'err_message': self.err_message + } diff --git a/apps/application/flow/step_node/image_generate_step_node/__init__.py b/apps/application/flow/step_node/image_generate_step_node/__init__.py new file mode 100644 index 00000000000..f3feecc9ce2 --- /dev/null +++ b/apps/application/flow/step_node/image_generate_step_node/__init__.py @@ -0,0 +1,3 @@ +# coding=utf-8 + +from .impl import * diff --git a/apps/application/flow/step_node/image_generate_step_node/i_image_generate_node.py b/apps/application/flow/step_node/image_generate_step_node/i_image_generate_node.py new file mode 100644 index 00000000000..56a214cf9b9 --- /dev/null +++ b/apps/application/flow/step_node/image_generate_step_node/i_image_generate_node.py @@ -0,0 +1,45 @@ +# coding=utf-8 + +from typing import Type + +from rest_framework import serializers + +from application.flow.i_step_node import INode, NodeResult +from common.util.field_message import ErrMessage +from django.utils.translation import gettext_lazy as _ + + +class ImageGenerateNodeSerializer(serializers.Serializer): + model_id = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Model id"))) + + prompt = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Prompt word (positive)"))) + + negative_prompt = serializers.CharField(required=False, error_messages=ErrMessage.char(_("Prompt word (negative)")), + allow_null=True, allow_blank=True, ) + # Number of multi-round conversations + 
dialogue_number = serializers.IntegerField(required=False, default=0, + error_messages=ErrMessage.integer(_("Number of multi-round conversations"))) + + dialogue_type = serializers.CharField(required=False, default='NODE', + error_messages=ErrMessage.char(_("Conversation storage type"))) + + is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean(_('Whether to return content'))) + + model_params_setting = serializers.JSONField(required=False, default=dict, + error_messages=ErrMessage.json(_("Model parameter settings"))) + + +class IImageGenerateNode(INode): + type = 'image-generate-node' + + def get_node_params_serializer_class(self) -> Type[serializers.Serializer]: + return ImageGenerateNodeSerializer + + def _run(self): + return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data) + + def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_type, history_chat_record, chat_id, + model_params_setting, + chat_record_id, + **kwargs) -> NodeResult: + pass diff --git a/apps/application/flow/step_node/image_generate_step_node/impl/__init__.py b/apps/application/flow/step_node/image_generate_step_node/impl/__init__.py new file mode 100644 index 00000000000..14a21a9159c --- /dev/null +++ b/apps/application/flow/step_node/image_generate_step_node/impl/__init__.py @@ -0,0 +1,3 @@ +# coding=utf-8 + +from .base_image_generate_node import BaseImageGenerateNode diff --git a/apps/application/flow/step_node/image_generate_step_node/impl/base_image_generate_node.py b/apps/application/flow/step_node/image_generate_step_node/impl/base_image_generate_node.py new file mode 100644 index 00000000000..16423eafd61 --- /dev/null +++ b/apps/application/flow/step_node/image_generate_step_node/impl/base_image_generate_node.py @@ -0,0 +1,122 @@ +# coding=utf-8 +from functools import reduce +from typing import List + +import requests +from langchain_core.messages import BaseMessage, HumanMessage, AIMessage + +from application.flow.i_step_node import NodeResult +from application.flow.step_node.image_generate_step_node.i_image_generate_node import IImageGenerateNode +from common.util.common import bytes_to_uploaded_file +from dataset.serializers.file_serializers import FileSerializer +from setting.models_provider.tools import get_model_instance_by_model_user_id + + +class BaseImageGenerateNode(IImageGenerateNode): + def save_context(self, details, workflow_manage): + self.context['answer'] = details.get('answer') + self.context['question'] = details.get('question') + if self.node_params.get('is_result', False): + self.answer_text = details.get('answer') + + def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_type, history_chat_record, chat_id, + model_params_setting, + chat_record_id, + **kwargs) -> NodeResult: + application = self.workflow_manage.work_flow_post_handler.chat_info.application + tti_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'), + **model_params_setting) + history_message = self.get_history_message(history_chat_record, dialogue_number) + self.context['history_message'] = history_message + question = self.generate_prompt_question(prompt) + self.context['question'] = question + message_list = self.generate_message_list(question, history_message) + self.context['message_list'] = message_list + self.context['dialogue_type'] = dialogue_type + image_urls = tti_model.generate_image(question, 
negative_prompt) + # Save the generated images + file_urls = [] + for image_url in image_urls: + file_name = 'generated_image.png' + file = bytes_to_uploaded_file(requests.get(image_url).content, file_name) + meta = { + 'debug': False if application.id else True, + 'chat_id': chat_id, + 'application_id': str(application.id) if application.id else None, + } + file_url = FileSerializer(data={'file': file, 'meta': meta}).upload() + file_urls.append(file_url) + self.context['image_list'] = [{'file_id': path.split('/')[-1], 'url': path} for path in file_urls] + answer = ' '.join([f"![Image]({path})" for path in file_urls]) + return NodeResult({'answer': answer, 'chat_model': tti_model, 'message_list': message_list, + 'image': [{'file_id': path.split('/')[-1], 'url': path} for path in file_urls], + 'history_message': history_message, 'question': question}, {}) + + def generate_history_ai_message(self, chat_record): + for val in chat_record.details.values(): + if self.node.id == val['node_id'] and 'image_list' in val: + if val['dialogue_type'] == 'WORKFLOW': + return chat_record.get_ai_message() + image_list = val['image_list'] + return AIMessage(content=[ + *[{'type': 'image_url', 'image_url': {'url': file_url.get('url')}} for file_url in image_list] + ]) + return chat_record.get_ai_message() + + def get_history_message(self, history_chat_record, dialogue_number): + start_index = len(history_chat_record) - dialogue_number + history_message = reduce(lambda x, y: [*x, *y], [ + [self.generate_history_human_message(history_chat_record[index]), + self.generate_history_ai_message(history_chat_record[index])] + for index in + range(start_index if start_index > 0 else 0, len(history_chat_record))], []) + return history_message + + def generate_history_human_message(self, chat_record): + + for data in chat_record.details.values(): + if self.node.id == data['node_id'] and 'image_list' in data: + image_list = data['image_list'] + if len(image_list) == 0 or data['dialogue_type'] == 'WORKFLOW': + return HumanMessage(content=chat_record.problem_text) + return HumanMessage(content=data['question']) + return HumanMessage(content=chat_record.problem_text) + + def generate_prompt_question(self, prompt): + return self.workflow_manage.generate_prompt(prompt) + + def generate_message_list(self, question: str, history_message): + return [ + *history_message, + question + ] + + @staticmethod + def reset_message_list(message_list: List[BaseMessage], answer_text): + result = [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content} for + message + in + message_list] + result.append({'role': 'ai', 'content': answer_text}) + return result + + def get_details(self, index: int, **kwargs): + return { + 'name': self.node.properties.get('stepName'), + "index": index, + 'run_time': self.context.get('run_time'), + 'history_message': [{'content': message.content, 'role': message.type} for message in + (self.context.get('history_message') if self.context.get( + 'history_message') is not None else [])], + 'question': self.context.get('question'), + 'answer': self.context.get('answer'), + 'type': self.node.type, + 'message_tokens': self.context.get('message_tokens'), + 'answer_tokens': self.context.get('answer_tokens'), + 'status': self.status, + 'err_message': self.err_message, + 'image_list': self.context.get('image_list'), + 'dialogue_type': self.context.get('dialogue_type') + } diff --git a/apps/application/flow/step_node/image_understand_step_node/__init__.py 
b/apps/application/flow/step_node/image_understand_step_node/__init__.py new file mode 100644 index 00000000000..f3feecc9ce2 --- /dev/null +++ b/apps/application/flow/step_node/image_understand_step_node/__init__.py @@ -0,0 +1,3 @@ +# coding=utf-8 + +from .impl import * diff --git a/apps/application/flow/step_node/image_understand_step_node/i_image_understand_node.py b/apps/application/flow/step_node/image_understand_step_node/i_image_understand_node.py new file mode 100644 index 00000000000..5ef4c101708 --- /dev/null +++ b/apps/application/flow/step_node/image_understand_step_node/i_image_understand_node.py @@ -0,0 +1,46 @@ +# coding=utf-8 + +from typing import Type + +from rest_framework import serializers + +from application.flow.i_step_node import INode, NodeResult +from common.util.field_message import ErrMessage +from django.utils.translation import gettext_lazy as _ + + +class ImageUnderstandNodeSerializer(serializers.Serializer): + model_id = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Model id"))) + system = serializers.CharField(required=False, allow_blank=True, allow_null=True, + error_messages=ErrMessage.char(_("Role Setting"))) + prompt = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Prompt word"))) + # Number of multi-round conversations + dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer(_("Number of multi-round conversations"))) + + dialogue_type = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Conversation storage type"))) + + is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean(_('Whether to return content'))) + + image_list = serializers.ListField(required=False, error_messages=ErrMessage.list(_("picture"))) + + model_params_setting = serializers.JSONField(required=False, default=dict, + error_messages=ErrMessage.json(_("Model parameter settings"))) + + +class IImageUnderstandNode(INode): + type = 'image-understand-node' + + def get_node_params_serializer_class(self) -> Type[serializers.Serializer]: + return ImageUnderstandNodeSerializer + + def _run(self): + res = self.workflow_manage.get_reference_field(self.node_params_serializer.data.get('image_list')[0], + self.node_params_serializer.data.get('image_list')[1:]) + return self.execute(image=res, **self.node_params_serializer.data, **self.flow_params_serializer.data) + + def execute(self, model_id, system, prompt, dialogue_number, dialogue_type, history_chat_record, stream, chat_id, + model_params_setting, + chat_record_id, + image, + **kwargs) -> NodeResult: + pass diff --git a/apps/application/flow/step_node/image_understand_step_node/impl/__init__.py b/apps/application/flow/step_node/image_understand_step_node/impl/__init__.py new file mode 100644 index 00000000000..ba251283921 --- /dev/null +++ b/apps/application/flow/step_node/image_understand_step_node/impl/__init__.py @@ -0,0 +1,3 @@ +# coding=utf-8 + +from .base_image_understand_node import BaseImageUnderstandNode diff --git a/apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py b/apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py new file mode 100644 index 00000000000..44765bc4f93 --- /dev/null +++ b/apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py @@ -0,0 +1,224 @@ +# coding=utf-8 +import base64 +import os +import time +from functools import reduce +from typing import List, Dict + +from django.db.models import 
QuerySet +from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage, AIMessage + +from application.flow.i_step_node import NodeResult, INode +from application.flow.step_node.image_understand_step_node.i_image_understand_node import IImageUnderstandNode +from dataset.models import File +from setting.models_provider.tools import get_model_instance_by_model_user_id +from imghdr import what + + +def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str): + chat_model = node_variable.get('chat_model') + message_tokens = node_variable['usage_metadata']['output_tokens'] if 'usage_metadata' in node_variable else 0 + answer_tokens = chat_model.get_num_tokens(answer) + node.context['message_tokens'] = message_tokens + node.context['answer_tokens'] = answer_tokens + node.context['answer'] = answer + node.context['history_message'] = node_variable['history_message'] + node.context['question'] = node_variable['question'] + node.context['run_time'] = time.time() - node.context['start_time'] + if workflow.is_result(node, NodeResult(node_variable, workflow_variable)): + node.answer_text = answer + + +def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow): + """ + Write the context data (streaming) + @param node_variable: node data + @param workflow_variable: global data + @param node: node + @param workflow: workflow manager + """ + response = node_variable.get('result') + answer = '' + for chunk in response: + answer += chunk.content + yield chunk.content + _write_context(node_variable, workflow_variable, node, workflow, answer) + + +def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow): + """ + Write the context data + @param node_variable: node data + @param workflow_variable: global data + @param node: node instance + @param workflow: workflow manager + """ + response = node_variable.get('result') + answer = response.content + _write_context(node_variable, workflow_variable, node, workflow, answer) + + +def file_id_to_base64(file_id: str): + file = QuerySet(File).filter(id=file_id).first() + file_bytes = file.get_byte() + base64_image = base64.b64encode(file_bytes).decode("utf-8") + return [base64_image, what(None, file_bytes.tobytes())] + + +class BaseImageUnderstandNode(IImageUnderstandNode): + def save_context(self, details, workflow_manage): + self.context['answer'] = details.get('answer') + self.context['question'] = details.get('question') + if self.node_params.get('is_result', False): + self.answer_text = details.get('answer') + + def execute(self, model_id, system, prompt, dialogue_number, dialogue_type, history_chat_record, stream, chat_id, + model_params_setting, + chat_record_id, + image, + **kwargs) -> NodeResult: + # Handle invalid parameters + if image is None or not isinstance(image, list): + image = [] + image_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'), **model_params_setting) + # History messages in the execution details do not need the image content + history_message = self.get_history_message_for_details(history_chat_record, dialogue_number) + self.context['history_message'] = history_message + question = self.generate_prompt_question(prompt) + self.context['question'] = question.content + # Generate the message list with the real history_message + message_list = self.generate_message_list(image_model, system, prompt, + self.get_history_message(history_chat_record, dialogue_number), image) + self.context['message_list'] = message_list + self.context['image_list'] = image + self.context['dialogue_type'] = dialogue_type + if stream: + r = 
image_model.stream(message_list) + return NodeResult({'result': r, 'chat_model': image_model, 'message_list': message_list, + 'history_message': history_message, 'question': question.content}, {}, + _write_context=write_context_stream) + else: + r = image_model.invoke(message_list) + return NodeResult({'result': r, 'chat_model': image_model, 'message_list': message_list, + 'history_message': history_message, 'question': question.content}, {}, + _write_context=write_context) + + def get_history_message_for_details(self, history_chat_record, dialogue_number): + start_index = len(history_chat_record) - dialogue_number + history_message = reduce(lambda x, y: [*x, *y], [ + [self.generate_history_human_message_for_details(history_chat_record[index]), + self.generate_history_ai_message(history_chat_record[index])] + for index in + range(start_index if start_index > 0 else 0, len(history_chat_record))], []) + return history_message + + def generate_history_ai_message(self, chat_record): + for val in chat_record.details.values(): + if self.node.id == val['node_id'] and 'image_list' in val: + if val['dialogue_type'] == 'WORKFLOW': + return chat_record.get_ai_message() + return AIMessage(content=val['answer']) + return chat_record.get_ai_message() + + def generate_history_human_message_for_details(self, chat_record): + for data in chat_record.details.values(): + if self.node.id == data['node_id'] and 'image_list' in data: + image_list = data['image_list'] + if len(image_list) == 0 or data['dialogue_type'] == 'WORKFLOW': + return HumanMessage(content=chat_record.problem_text) + file_id_list = [image.get('file_id') for image in image_list] + return HumanMessage(content=[ + {'type': 'text', 'text': data['question']}, + *[{'type': 'image_url', 'image_url': {'url': f'/api/file/{file_id}'}} for file_id in file_id_list] + + ]) + return HumanMessage(content=chat_record.problem_text) + + def get_history_message(self, history_chat_record, dialogue_number): + start_index = len(history_chat_record) - dialogue_number + history_message = reduce(lambda x, y: [*x, *y], [ + [self.generate_history_human_message(history_chat_record[index]), + self.generate_history_ai_message(history_chat_record[index])] + for index in + range(start_index if start_index > 0 else 0, len(history_chat_record))], []) + return history_message + + def generate_history_human_message(self, chat_record): + + for data in chat_record.details.values(): + if self.node.id == data['node_id'] and 'image_list' in data: + image_list = data['image_list'] + if len(image_list) == 0 or data['dialogue_type'] == 'WORKFLOW': + return HumanMessage(content=chat_record.problem_text) + image_base64_list = [file_id_to_base64(image.get('file_id')) for image in image_list] + return HumanMessage( + content=[ + {'type': 'text', 'text': data['question']}, + *[{'type': 'image_url', 'image_url': {'url': f'data:image/{base64_image[1]};base64,{base64_image[0]}'}} for + base64_image in image_base64_list] + ]) + return HumanMessage(content=chat_record.problem_text) + + def generate_prompt_question(self, prompt): + return HumanMessage(self.workflow_manage.generate_prompt(prompt)) + + def generate_message_list(self, image_model, system: str, prompt: str, history_message, image): + if image is not None and len(image) > 0: + # Handle multiple images + images = [] + for img in image: + file_id = img['file_id'] + file = QuerySet(File).filter(id=file_id).first() + image_bytes = file.get_byte() + base64_image = base64.b64encode(image_bytes).decode("utf-8") + image_format = what(None, 
image_bytes.tobytes()) + images.append({'type': 'image_url', 'image_url': {'url': f'data:image/{image_format};base64,{base64_image}'}}) + messages = [HumanMessage( + content=[ + {'type': 'text', 'text': self.workflow_manage.generate_prompt(prompt)}, + *images + ])] + else: + messages = [HumanMessage(self.workflow_manage.generate_prompt(prompt))] + + if system is not None and len(system) > 0: + return [ + SystemMessage(self.workflow_manage.generate_prompt(system)), + *history_message, + *messages + ] + else: + return [ + *history_message, + *messages + ] + + @staticmethod + def reset_message_list(message_list: List[BaseMessage], answer_text): + result = [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content} for + message + in + message_list] + result.append({'role': 'ai', 'content': answer_text}) + return result + + def get_details(self, index: int, **kwargs): + return { + 'name': self.node.properties.get('stepName'), + "index": index, + 'run_time': self.context.get('run_time'), + 'system': self.node_params.get('system'), + 'history_message': [{'content': message.content, 'role': message.type} for message in + (self.context.get('history_message') if self.context.get( + 'history_message') is not None else [])], + 'question': self.context.get('question'), + 'answer': self.context.get('answer'), + 'type': self.node.type, + 'message_tokens': self.context.get('message_tokens'), + 'answer_tokens': self.context.get('answer_tokens'), + 'status': self.status, + 'err_message': self.err_message, + 'image_list': self.context.get('image_list'), + 'dialogue_type': self.context.get('dialogue_type') + } diff --git a/apps/application/flow/step_node/mcp_node/__init__.py b/apps/application/flow/step_node/mcp_node/__init__.py new file mode 100644 index 00000000000..f3feecc9ce2 --- /dev/null +++ b/apps/application/flow/step_node/mcp_node/__init__.py @@ -0,0 +1,3 @@ +# coding=utf-8 + +from .impl import * diff --git a/apps/application/flow/step_node/mcp_node/i_mcp_node.py b/apps/application/flow/step_node/mcp_node/i_mcp_node.py new file mode 100644 index 00000000000..94cb4da7729 --- /dev/null +++ b/apps/application/flow/step_node/mcp_node/i_mcp_node.py @@ -0,0 +1,35 @@ +# coding=utf-8 + +from typing import Type + +from rest_framework import serializers + +from application.flow.i_step_node import INode, NodeResult +from common.util.field_message import ErrMessage +from django.utils.translation import gettext_lazy as _ + + +class McpNodeSerializer(serializers.Serializer): + mcp_servers = serializers.JSONField(required=True, + error_messages=ErrMessage.char(_("Mcp servers"))) + + mcp_server = serializers.CharField(required=True, + error_messages=ErrMessage.char(_("Mcp server"))) + + mcp_tool = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Mcp tool"))) + + tool_params = serializers.DictField(required=True, + error_messages=ErrMessage.char(_("Tool parameters"))) + + +class IMcpNode(INode): + type = 'mcp-node' + + def get_node_params_serializer_class(self) -> Type[serializers.Serializer]: + return McpNodeSerializer + + def _run(self): + return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data) + + def execute(self, mcp_servers, mcp_server, mcp_tool, tool_params, **kwargs) -> NodeResult: + pass diff --git a/apps/application/flow/step_node/mcp_node/impl/__init__.py b/apps/application/flow/step_node/mcp_node/impl/__init__.py new file mode 100644 index 00000000000..8c9a5ee197c --- /dev/null +++ 
b/apps/application/flow/step_node/mcp_node/impl/__init__.py @@ -0,0 +1,3 @@ +# coding=utf-8 + +from .base_mcp_node import BaseMcpNode diff --git a/apps/application/flow/step_node/mcp_node/impl/base_mcp_node.py b/apps/application/flow/step_node/mcp_node/impl/base_mcp_node.py new file mode 100644 index 00000000000..e49ef7019f6 --- /dev/null +++ b/apps/application/flow/step_node/mcp_node/impl/base_mcp_node.py @@ -0,0 +1,61 @@ +# coding=utf-8 +import asyncio +import json +from typing import List + +from langchain_mcp_adapters.client import MultiServerMCPClient + +from application.flow.i_step_node import NodeResult +from application.flow.step_node.mcp_node.i_mcp_node import IMcpNode + + +class BaseMcpNode(IMcpNode): + def save_context(self, details, workflow_manage): + self.context['result'] = details.get('result') + self.context['tool_params'] = details.get('tool_params') + self.context['mcp_tool'] = details.get('mcp_tool') + if self.node_params.get('is_result', False): + self.answer_text = details.get('result') + + def execute(self, mcp_servers, mcp_server, mcp_tool, tool_params, **kwargs) -> NodeResult: + servers = json.loads(mcp_servers) + params = json.loads(json.dumps(tool_params)) + params = self.handle_variables(params) + + async def call_tool(s, session, t, a): + async with MultiServerMCPClient(s) as client: + s = await client.sessions[session].call_tool(t, a) + return s + + res = asyncio.run(call_tool(servers, mcp_server, mcp_tool, params)) + return NodeResult( + {'result': [content.text for content in res.content], 'tool_params': params, 'mcp_tool': mcp_tool}, {}) + + def handle_variables(self, tool_params): + # Resolve variables referenced in the tool parameters + for k, v in tool_params.items(): + if isinstance(v, str): + tool_params[k] = self.workflow_manage.generate_prompt(tool_params[k]) + if isinstance(v, dict): + self.handle_variables(v) + if isinstance(v, list) and len(v) > 0 and isinstance(v[0], str): + tool_params[k] = self.get_reference_content(v) + return tool_params + + def get_reference_content(self, fields: List[str]): + return str(self.workflow_manage.get_reference_field( + fields[0], + fields[1:])) + + def get_details(self, index: int, **kwargs): + return { + 'name': self.node.properties.get('stepName'), + "index": index, + 'run_time': self.context.get('run_time'), + 'status': self.status, + 'err_message': self.err_message, + 'type': self.node.type, + 'mcp_tool': self.context.get('mcp_tool'), + 'tool_params': self.context.get('tool_params'), + 'result': self.context.get('result'), + } diff --git a/apps/application/flow/step_node/question_node/__init__.py b/apps/application/flow/step_node/question_node/__init__.py new file mode 100644 index 00000000000..98a1afcd904 --- /dev/null +++ b/apps/application/flow/step_node/question_node/__init__.py @@ -0,0 +1,9 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py + @date:2024/6/11 15:30 + @desc: +""" +from .impl import * diff --git a/apps/application/flow/step_node/question_node/i_question_node.py b/apps/application/flow/step_node/question_node/i_question_node.py new file mode 100644 index 00000000000..57898bf2206 --- /dev/null +++ b/apps/application/flow/step_node/question_node/i_question_node.py @@ -0,0 +1,42 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: i_question_node.py + @date:2024/6/4 13:58 + @desc: +""" +from typing import Type + +from rest_framework import serializers + +from application.flow.i_step_node import INode, NodeResult +from common.util.field_message import ErrMessage +from django.utils.translation import gettext_lazy as _ + + +class 
QuestionNodeSerializer(serializers.Serializer): + model_id = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Model id"))) + system = serializers.CharField(required=False, allow_blank=True, allow_null=True, + error_messages=ErrMessage.char(_("Role Setting"))) + prompt = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Prompt word"))) + # Number of multi-round conversations + dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer(_("Number of multi-round conversations"))) + + is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean(_('Whether to return content'))) + model_params_setting = serializers.DictField(required=False, error_messages=ErrMessage.dict(_("Model parameter settings"))) + + +class IQuestionNode(INode): + type = 'question-node' + + def get_node_params_serializer_class(self) -> Type[serializers.Serializer]: + return QuestionNodeSerializer + + def _run(self): + return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data) + + def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id, + model_params_setting=None, + **kwargs) -> NodeResult: + pass diff --git a/apps/application/flow/step_node/question_node/impl/__init__.py b/apps/application/flow/step_node/question_node/impl/__init__.py new file mode 100644 index 00000000000..d85aa8724ac --- /dev/null +++ b/apps/application/flow/step_node/question_node/impl/__init__.py @@ -0,0 +1,9 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py + @date:2024/6/11 15:35 + @desc: +""" +from .base_question_node import BaseQuestionNode diff --git a/apps/application/flow/step_node/question_node/impl/base_question_node.py b/apps/application/flow/step_node/question_node/impl/base_question_node.py new file mode 100644 index 00000000000..e1fd5b86069 --- /dev/null +++ b/apps/application/flow/step_node/question_node/impl/base_question_node.py @@ -0,0 +1,159 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: base_question_node.py + @date:2024/6/4 14:30 + @desc: +""" +import re +import time +from functools import reduce +from typing import List, Dict + +from django.db.models import QuerySet +from langchain.schema import HumanMessage, SystemMessage +from langchain_core.messages import BaseMessage + +from application.flow.i_step_node import NodeResult, INode +from application.flow.step_node.question_node.i_question_node import IQuestionNode +from setting.models import Model +from setting.models_provider import get_model_credential +from setting.models_provider.tools import get_model_instance_by_model_user_id + + +def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str): + chat_model = node_variable.get('chat_model') + message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list')) + answer_tokens = chat_model.get_num_tokens(answer) + node.context['message_tokens'] = message_tokens + node.context['answer_tokens'] = answer_tokens + node.context['answer'] = answer + node.context['history_message'] = node_variable['history_message'] + node.context['question'] = node_variable['question'] + node.context['run_time'] = time.time() - node.context['start_time'] + if workflow.is_result(node, NodeResult(node_variable, workflow_variable)): + node.answer_text = answer + + +def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow): + """ + Write the context data (streaming) + @param 
node_variable: node data + @param workflow_variable: global data + @param node: node + @param workflow: workflow manager + """ + response = node_variable.get('result') + answer = '' + for chunk in response: + answer += chunk.content + yield chunk.content + _write_context(node_variable, workflow_variable, node, workflow, answer) + + +def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow): + """ + Write the context data + @param node_variable: node data + @param workflow_variable: global data + @param node: node instance + @param workflow: workflow manager + """ + response = node_variable.get('result') + answer = response.content + _write_context(node_variable, workflow_variable, node, workflow, answer) + + +def get_default_model_params_setting(model_id): + model = QuerySet(Model).filter(id=model_id).first() + credential = get_model_credential(model.provider, model.model_type, model.model_name) + model_params_setting = credential.get_model_params_setting_form( + model.model_name).get_default_form_data() + return model_params_setting + + +class BaseQuestionNode(IQuestionNode): + def save_context(self, details, workflow_manage): + self.context['run_time'] = details.get('run_time') + self.context['question'] = details.get('question') + self.context['answer'] = details.get('answer') + self.context['message_tokens'] = details.get('message_tokens') + self.context['answer_tokens'] = details.get('answer_tokens') + if self.node_params.get('is_result', False): + self.answer_text = details.get('answer') + + def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id, + model_params_setting=None, + **kwargs) -> NodeResult: + if model_params_setting is None: + model_params_setting = get_default_model_params_setting(model_id) + chat_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'), + **model_params_setting) + history_message = self.get_history_message(history_chat_record, dialogue_number) + self.context['history_message'] = history_message + question = self.generate_prompt_question(prompt) + self.context['question'] = question.content + system = self.workflow_manage.generate_prompt(system) + self.context['system'] = system + message_list = self.generate_message_list(system, prompt, history_message) + self.context['message_list'] = message_list + if stream: + r = chat_model.stream(message_list) + return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list, + 'history_message': history_message, 'question': question.content}, {}, + _write_context=write_context_stream) + else: + r = chat_model.invoke(message_list) + return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list, + 'history_message': history_message, 'question': question.content}, {}, + _write_context=write_context) + + @staticmethod + def get_history_message(history_chat_record, dialogue_number): + start_index = len(history_chat_record) - dialogue_number + history_message = reduce(lambda x, y: [*x, *y], [ + [history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()] + for index in + range(start_index if start_index > 0 else 0, len(history_chat_record))], []) + for message in history_message: + if isinstance(message.content, str): + message.content = re.sub('[\d\D]*?</form_rander>', '', message.content) + return history_message + + def generate_prompt_question(self, prompt): + return HumanMessage(self.workflow_manage.generate_prompt(prompt)) + + def generate_message_list(self, system: str, 
prompt: str, history_message): + if system is not None and len(system) > 0: + return [SystemMessage(self.workflow_manage.generate_prompt(system)), *history_message, + HumanMessage(self.workflow_manage.generate_prompt(prompt))] + else: + return [*history_message, HumanMessage(self.workflow_manage.generate_prompt(prompt))] + + @staticmethod + def reset_message_list(message_list: List[BaseMessage], answer_text): + result = [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content} for + message + in + message_list] + result.append({'role': 'ai', 'content': answer_text}) + return result + + def get_details(self, index: int, **kwargs): + return { + 'name': self.node.properties.get('stepName'), + "index": index, + 'run_time': self.context.get('run_time'), + 'system': self.context.get('system'), + 'history_message': [{'content': message.content, 'role': message.type} for message in + (self.context.get('history_message') if self.context.get( + 'history_message') is not None else [])], + 'question': self.context.get('question'), + 'answer': self.context.get('answer'), + 'type': self.node.type, + 'message_tokens': self.context.get('message_tokens'), + 'answer_tokens': self.context.get('answer_tokens'), + 'status': self.status, + 'err_message': self.err_message + } diff --git a/apps/application/flow/step_node/reranker_node/__init__.py b/apps/application/flow/step_node/reranker_node/__init__.py new file mode 100644 index 00000000000..881d0f8a393 --- /dev/null +++ b/apps/application/flow/step_node/reranker_node/__init__.py @@ -0,0 +1,9 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: __init__.py + @date:2024/9/4 11:37 + @desc: +""" +from .impl import * diff --git a/apps/application/flow/step_node/reranker_node/i_reranker_node.py b/apps/application/flow/step_node/reranker_node/i_reranker_node.py new file mode 100644 index 00000000000..3b95e4dd632 --- /dev/null +++ b/apps/application/flow/step_node/reranker_node/i_reranker_node.py @@ -0,0 +1,60 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: i_reranker_node.py + @date:2024/9/4 10:40 + @desc: +""" +from typing import Type + +from rest_framework import serializers + +from application.flow.i_step_node import INode, NodeResult +from common.util.field_message import ErrMessage +from django.utils.translation import gettext_lazy as _ + + +class RerankerSettingSerializer(serializers.Serializer): + # Number of segments to retrieve + top_n = serializers.IntegerField(required=True, + error_messages=ErrMessage.integer(_("Reference segment number"))) + # Similarity, between 0 and 1 + similarity = serializers.FloatField(required=True, max_value=2, min_value=0, + error_messages=ErrMessage.float(_('similarity'))) + max_paragraph_char_number = serializers.IntegerField(required=True, + error_messages=ErrMessage.integer(_("Maximum number of words in a quoted segment"))) + + +class RerankerStepNodeSerializer(serializers.Serializer): + reranker_setting = RerankerSettingSerializer(required=True) + + question_reference_address = serializers.ListField(required=True) + reranker_model_id = serializers.UUIDField(required=True) + reranker_reference_list = serializers.ListField(required=True, child=serializers.ListField(required=True)) + + def is_valid(self, *, raise_exception=False): + super().is_valid(raise_exception=True) + + +class IRerankerNode(INode): + type = 'reranker-node' + + def get_node_params_serializer_class(self) -> Type[serializers.Serializer]: + return RerankerStepNodeSerializer + + def _run(self): + question = 
self.workflow_manage.get_reference_field( + self.node_params_serializer.data.get('question_reference_address')[0], + self.node_params_serializer.data.get('question_reference_address')[1:]) + reranker_list = [self.workflow_manage.get_reference_field( + reference[0], + reference[1:]) for reference in + self.node_params_serializer.data.get('reranker_reference_list')] + return self.execute(**self.node_params_serializer.data, question=str(question), + + reranker_list=reranker_list) + + def execute(self, question, reranker_setting, reranker_list, reranker_model_id, + **kwargs) -> NodeResult: + pass diff --git a/apps/application/flow/step_node/reranker_node/impl/__init__.py b/apps/application/flow/step_node/reranker_node/impl/__init__.py new file mode 100644 index 00000000000..ef5ca80585b --- /dev/null +++ b/apps/application/flow/step_node/reranker_node/impl/__init__.py @@ -0,0 +1,9 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: __init__.py + @date:2024/9/4 11:39 + @desc: +""" +from .base_reranker_node import * diff --git a/apps/application/flow/step_node/reranker_node/impl/base_reranker_node.py b/apps/application/flow/step_node/reranker_node/impl/base_reranker_node.py new file mode 100644 index 00000000000..ee92b88a52c --- /dev/null +++ b/apps/application/flow/step_node/reranker_node/impl/base_reranker_node.py @@ -0,0 +1,106 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: base_reranker_node.py + @date:2024/9/4 11:41 + @desc: +""" +from typing import List + +from langchain_core.documents import Document + +from application.flow.i_step_node import NodeResult +from application.flow.step_node.reranker_node.i_reranker_node import IRerankerNode +from setting.models_provider.tools import get_model_instance_by_model_user_id + + +def merge_reranker_list(reranker_list, result=None): + if result is None: + result = [] + for document in reranker_list: + if isinstance(document, list): + merge_reranker_list(document, result) + elif isinstance(document, dict): + content = document.get('title', '') + document.get('content', '') + title = document.get("title") + dataset_name = document.get("dataset_name") + document_name = document.get('document_name') + result.append( + Document(page_content=str(document) if len(content) == 0 else content, + metadata={'title': title, 'dataset_name': dataset_name, 'document_name': document_name})) + else: + result.append(Document(page_content=str(document), metadata={})) + return result + + +def filter_result(document_list: List[Document], max_paragraph_char_number, top_n, similarity): + use_len = 0 + result = [] + for index in range(len(document_list)): + document = document_list[index] + if use_len >= max_paragraph_char_number or index >= top_n or document.metadata.get( + 'relevance_score') < similarity: + break + content = document.page_content[0:max_paragraph_char_number - use_len] + use_len = use_len + len(content) + result.append({'page_content': content, 'metadata': document.metadata}) + return result + + +def reset_result_list(result_list: List[Document], document_list: List[Document]): + r = [] + document_list = document_list.copy() + for result in result_list: + filter_result_list = [document for document in document_list if document.page_content == result.page_content] + if len(filter_result_list) > 0: + item = filter_result_list[0] + document_list.remove(item) + r.append(Document(page_content=item.page_content, + metadata={**item.metadata, 'relevance_score': result.metadata.get('relevance_score')})) + else: + r.append(result) + return r 
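[Editor's aside] `reset_result_list` above re-attaches the reranker's `relevance_score` to the original documents, and `filter_result` then enforces the three budgets: the overall character budget, `top_n`, and the similarity threshold. A small self-contained sketch of the filtering behavior, using made-up scores (values are illustrative only; `filter_result` is the helper defined just above):

```python
from langchain_core.documents import Document

docs = [
    Document(page_content='Paris is the capital of France.', metadata={'relevance_score': 0.92}),
    Document(page_content='Berlin is the capital of Germany.', metadata={'relevance_score': 0.55}),
    Document(page_content='Rome is the capital of Italy.', metadata={'relevance_score': 0.10}),
]

# Keep at most 2 segments, drop anything scored under 0.5, and cap the
# combined text at 40 characters.
rows = filter_result(docs, max_paragraph_char_number=40, top_n=2, similarity=0.5)

# The first document survives intact (31 chars); the second is truncated to
# the remaining 9-character budget; the third is never reached.
for row in rows:
    print(row['metadata']['relevance_score'], repr(row['page_content']))
```

Note that all three checks `break` out of the loop rather than skip a document, which relies on the candidate list arriving sorted by relevance score, as `execute` below guarantees for the search results it passes in.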
+ + +class BaseRerankerNode(IRerankerNode): + def save_context(self, details, workflow_manage): + self.context['document_list'] = details.get('document_list', []) + self.context['question'] = details.get('question') + self.context['run_time'] = details.get('run_time') + self.context['result_list'] = details.get('result_list') + self.context['result'] = details.get('result') + + def execute(self, question, reranker_setting, reranker_list, reranker_model_id, + **kwargs) -> NodeResult: + documents = merge_reranker_list(reranker_list) + top_n = reranker_setting.get('top_n', 3) + self.context['document_list'] = [{'page_content': document.page_content, 'metadata': document.metadata} for + document in documents] + self.context['question'] = question + reranker_model = get_model_instance_by_model_user_id(reranker_model_id, + self.flow_params_serializer.data.get('user_id'), + top_n=top_n) + result = reranker_model.compress_documents( + documents, + question) + similarity = reranker_setting.get('similarity', 0.6) + max_paragraph_char_number = reranker_setting.get('max_paragraph_char_number', 5000) + result = reset_result_list(result, documents) + r = filter_result(result, max_paragraph_char_number, top_n, similarity) + return NodeResult({'result_list': r, 'result': ''.join([item.get('page_content') for item in r])}, {}) + + def get_details(self, index: int, **kwargs): + return { + 'name': self.node.properties.get('stepName'), + "index": index, + 'document_list': self.context.get('document_list'), + "question": self.context.get('question'), + 'run_time': self.context.get('run_time'), + 'type': self.node.type, + 'reranker_setting': self.node_params_serializer.data.get('reranker_setting'), + 'result_list': self.context.get('result_list'), + 'result': self.context.get('result'), + 'status': self.status, + 'err_message': self.err_message + } diff --git a/apps/application/flow/step_node/search_dataset_node/__init__.py b/apps/application/flow/step_node/search_dataset_node/__init__.py new file mode 100644 index 00000000000..98a1afcd904 --- /dev/null +++ b/apps/application/flow/step_node/search_dataset_node/__init__.py @@ -0,0 +1,9 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py + @date:2024/6/11 15:30 + @desc: +""" +from .impl import * diff --git a/apps/application/flow/step_node/search_dataset_node/i_search_dataset_node.py b/apps/application/flow/step_node/search_dataset_node/i_search_dataset_node.py new file mode 100644 index 00000000000..8f15c7a3203 --- /dev/null +++ b/apps/application/flow/step_node/search_dataset_node/i_search_dataset_node.py @@ -0,0 +1,79 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: i_search_dataset_node.py + @date:2024/6/3 17:52 + @desc: +""" +import re +from typing import Type + +from django.core import validators +from rest_framework import serializers + +from application.flow.i_step_node import INode, NodeResult +from common.util.common import flat_map +from common.util.field_message import ErrMessage +from django.utils.translation import gettext_lazy as _ + + +class DatasetSettingSerializer(serializers.Serializer): + # Number of segments to retrieve + top_n = serializers.IntegerField(required=True, + error_messages=ErrMessage.integer(_("Reference segment number"))) + # Similarity, between 0 and 1 + similarity = serializers.FloatField(required=True, max_value=2, min_value=0, + error_messages=ErrMessage.float(_('similarity'))) + search_mode = serializers.CharField(required=True, validators=[ + validators.RegexValidator(regex=re.compile("^(embedding|keywords|blend)$"), + 
message=_("The type only supports embedding|keywords|blend"), code=500) + ], error_messages=ErrMessage.char(_("Retrieval Mode"))) + max_paragraph_char_number = serializers.IntegerField(required=True, + error_messages=ErrMessage.integer(_("Maximum number of words in a quoted segment"))) + + +class SearchDatasetStepNodeSerializer(serializers.Serializer): + # List of dataset ids to query + dataset_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True), + error_messages=ErrMessage.list(_("Dataset id list"))) + dataset_setting = DatasetSettingSerializer(required=True) + + question_reference_address = serializers.ListField(required=True) + + def is_valid(self, *, raise_exception=False): + super().is_valid(raise_exception=True) + + +def get_paragraph_list(chat_record, node_id): + return flat_map([chat_record.details[key].get('paragraph_list', []) for key in chat_record.details if + (chat_record.details[ + key].get('type', '') == 'search-dataset-node') and chat_record.details[key].get( + 'paragraph_list', []) is not None and key == node_id]) + + +class ISearchDatasetStepNode(INode): + type = 'search-dataset-node' + + def get_node_params_serializer_class(self) -> Type[serializers.Serializer]: + return SearchDatasetStepNodeSerializer + + def _run(self): + question = self.workflow_manage.get_reference_field( + self.node_params_serializer.data.get('question_reference_address')[0], + self.node_params_serializer.data.get('question_reference_address')[1:]) + exclude_paragraph_id_list = [] + if self.flow_params_serializer.data.get('re_chat', False): + history_chat_record = self.flow_params_serializer.data.get('history_chat_record', []) + paragraph_id_list = [p.get('id') for p in flat_map( + [get_paragraph_list(chat_record, self.runtime_node_id) for chat_record in history_chat_record if + chat_record.problem_text == question])] + exclude_paragraph_id_list = list(set(paragraph_id_list)) + + return self.execute(**self.node_params_serializer.data, question=str(question), + exclude_paragraph_id_list=exclude_paragraph_id_list) + + def execute(self, dataset_id_list, dataset_setting, question, + exclude_paragraph_id_list=None, + **kwargs) -> NodeResult: + pass diff --git a/apps/application/flow/step_node/search_dataset_node/impl/__init__.py b/apps/application/flow/step_node/search_dataset_node/impl/__init__.py new file mode 100644 index 00000000000..a9cff0d0941 --- /dev/null +++ b/apps/application/flow/step_node/search_dataset_node/impl/__init__.py @@ -0,0 +1,9 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py + @date:2024/6/11 15:35 + @desc: +""" +from .base_search_dataset_node import BaseSearchDatasetNode diff --git a/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py b/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py new file mode 100644 index 00000000000..5107d4ce2c8 --- /dev/null +++ b/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py @@ -0,0 +1,146 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: base_search_dataset_node.py + @date:2024/6/4 11:56 + @desc: +""" +import os +from typing import List, Dict + +from django.db.models import QuerySet +from django.db import connection +from application.flow.i_step_node import NodeResult +from application.flow.step_node.search_dataset_node.i_search_dataset_node import ISearchDatasetStepNode +from common.config.embedding_config import VectorStore +from common.db.search import native_search +from 
common.util.file_util import get_file_content +from dataset.models import Document, Paragraph, DataSet +from embedding.models import SearchMode +from setting.models_provider.tools import get_model_instance_by_model_user_id +from smartdoc.conf import PROJECT_DIR + + +def get_embedding_id(dataset_id_list):
 + dataset_list = QuerySet(DataSet).filter(id__in=dataset_id_list) + if len(set([dataset.embedding_mode_id for dataset in dataset_list])) > 1: + raise Exception("The embedding models of the associated knowledge bases are inconsistent, so segments cannot be recalled.") + if len(dataset_list) == 0: + raise Exception("Knowledge base configuration error, please reconfigure the knowledge base") + return dataset_list[0].embedding_mode_id + + +def get_none_result(question): + return NodeResult( + {'paragraph_list': [], 'is_hit_handling_method': [], 'question': question, 'data': '', + 'directly_return': ''}, {}) + + +def reset_title(title): + if title is None or len(title.strip()) == 0: + return "" + else: + return f"#### {title}\n" + + +class BaseSearchDatasetNode(ISearchDatasetStepNode): + def save_context(self, details, workflow_manage): + result = details.get('paragraph_list', []) + dataset_setting = self.node_params_serializer.data.get('dataset_setting') + directly_return = '\n'.join( + [f"{paragraph.get('title', '')}:{paragraph.get('content')}" for paragraph in result if + paragraph.get('is_hit_handling_method')]) + self.context['paragraph_list'] = result + self.context['question'] = details.get('question') + self.context['run_time'] = details.get('run_time') + self.context['is_hit_handling_method_list'] = [row for row in result if row.get('is_hit_handling_method')] + self.context['data'] = '\n'.join( + [f"{paragraph.get('title', '')}:{paragraph.get('content')}" for paragraph in + result])[0:dataset_setting.get('max_paragraph_char_number', 5000)] + self.context['directly_return'] = directly_return + + def execute(self, dataset_id_list, dataset_setting, question, + exclude_paragraph_id_list=None, + **kwargs) -> NodeResult: + self.context['question'] = question + if len(dataset_id_list) == 0: + return get_none_result(question) + model_id = get_embedding_id(dataset_id_list) + embedding_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id')) + embedding_value = embedding_model.embed_query(question) + vector = VectorStore.get_embedding_vector() + exclude_document_id_list = [str(document.id) for document in + QuerySet(Document).filter( + dataset_id__in=dataset_id_list, + is_active=False)] + embedding_list = vector.query(question, embedding_value, dataset_id_list, exclude_document_id_list, + exclude_paragraph_id_list, True, dataset_setting.get('top_n'), + dataset_setting.get('similarity'), SearchMode(dataset_setting.get('search_mode'))) + # Manually close the database connection + connection.close() + if embedding_list is None: + return get_none_result(question) + paragraph_list = self.list_paragraph(embedding_list, vector) + result = [self.reset_paragraph(paragraph, embedding_list) for paragraph in paragraph_list] + result = sorted(result, key=lambda p: p.get('similarity'), reverse=True) + return NodeResult({'paragraph_list': result, + 'is_hit_handling_method_list': [row for row in result if row.get('is_hit_handling_method')], + 'data': '\n'.join( + [f"{reset_title(paragraph.get('title', ''))}{paragraph.get('content')}" for paragraph in + result])[0:dataset_setting.get('max_paragraph_char_number', 5000)], + 'directly_return': '\n'.join( + [paragraph.get('content') for paragraph in + result if + paragraph.get('is_hit_handling_method')]), + 'question': question}, + + {}) + + @staticmethod + def 
reset_paragraph(paragraph: Dict, embedding_list: List): + filter_embedding_list = [embedding for embedding in embedding_list if + str(embedding.get('paragraph_id')) == str(paragraph.get('id'))] + if filter_embedding_list is not None and len(filter_embedding_list) > 0: + find_embedding = filter_embedding_list[-1] + return { + **paragraph, + 'similarity': find_embedding.get('similarity'), + 'is_hit_handling_method': find_embedding.get('similarity') > paragraph.get( + 'directly_return_similarity') and paragraph.get('hit_handling_method') == 'directly_return', + 'update_time': paragraph.get('update_time').strftime("%Y-%m-%d %H:%M:%S"), + 'create_time': paragraph.get('create_time').strftime("%Y-%m-%d %H:%M:%S"), + 'id': str(paragraph.get('id')), + 'dataset_id': str(paragraph.get('dataset_id')), + 'document_id': str(paragraph.get('document_id')) + } + + @staticmethod + def list_paragraph(embedding_list: List, vector): + paragraph_id_list = [row.get('paragraph_id') for row in embedding_list] + if paragraph_id_list is None or len(paragraph_id_list) == 0: + return [] + paragraph_list = native_search(QuerySet(Paragraph).filter(id__in=paragraph_id_list), + get_file_content( + os.path.join(PROJECT_DIR, "apps", "application", 'sql', + 'list_dataset_paragraph_by_paragraph_id.sql')), + with_table_name=True) + # 如果向量库中存在脏数据 直接删除 + if len(paragraph_list) != len(paragraph_id_list): + exist_paragraph_list = [row.get('id') for row in paragraph_list] + for paragraph_id in paragraph_id_list: + if not exist_paragraph_list.__contains__(paragraph_id): + vector.delete_by_paragraph_id(paragraph_id) + return paragraph_list + + def get_details(self, index: int, **kwargs): + return { + 'name': self.node.properties.get('stepName'), + 'question': self.context.get('question'), + "index": index, + 'run_time': self.context.get('run_time'), + 'paragraph_list': self.context.get('paragraph_list'), + 'type': self.node.type, + 'status': self.status, + 'err_message': self.err_message + } diff --git a/apps/application/flow/step_node/speech_to_text_step_node/__init__.py b/apps/application/flow/step_node/speech_to_text_step_node/__init__.py new file mode 100644 index 00000000000..f3feecc9ce2 --- /dev/null +++ b/apps/application/flow/step_node/speech_to_text_step_node/__init__.py @@ -0,0 +1,3 @@ +# coding=utf-8 + +from .impl import * diff --git a/apps/application/flow/step_node/speech_to_text_step_node/i_speech_to_text_node.py b/apps/application/flow/step_node/speech_to_text_step_node/i_speech_to_text_node.py new file mode 100644 index 00000000000..154762dca1a --- /dev/null +++ b/apps/application/flow/step_node/speech_to_text_step_node/i_speech_to_text_node.py @@ -0,0 +1,38 @@ +# coding=utf-8 + +from typing import Type + +from rest_framework import serializers + +from application.flow.i_step_node import INode, NodeResult +from common.util.field_message import ErrMessage +from django.utils.translation import gettext_lazy as _ + + +class SpeechToTextNodeSerializer(serializers.Serializer): + stt_model_id = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Model id"))) + + is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean(_('Whether to return content'))) + + audio_list = serializers.ListField(required=True, error_messages=ErrMessage.list(_("The audio file cannot be empty"))) + + +class ISpeechToTextNode(INode): + type = 'speech-to-text-node' + + def get_node_params_serializer_class(self) -> Type[serializers.Serializer]: + return SpeechToTextNodeSerializer + + def _run(self): + res 
= self.workflow_manage.get_reference_field(self.node_params_serializer.data.get('audio_list')[0], + self.node_params_serializer.data.get('audio_list')[1:]) + for audio in res: + if 'file_id' not in audio: + raise ValueError(_("Parameter value error: The uploaded audio lacks file_id, and the audio upload fails")) + + return self.execute(audio=res, **self.node_params_serializer.data, **self.flow_params_serializer.data) + + def execute(self, stt_model_id, chat_id, + audio, + **kwargs) -> NodeResult: + pass diff --git a/apps/application/flow/step_node/speech_to_text_step_node/impl/__init__.py b/apps/application/flow/step_node/speech_to_text_step_node/impl/__init__.py new file mode 100644 index 00000000000..9d2da615820 --- /dev/null +++ b/apps/application/flow/step_node/speech_to_text_step_node/impl/__init__.py @@ -0,0 +1,3 @@ +# coding=utf-8 + +from .base_speech_to_text_node import BaseSpeechToTextNode diff --git a/apps/application/flow/step_node/speech_to_text_step_node/impl/base_speech_to_text_node.py b/apps/application/flow/step_node/speech_to_text_step_node/impl/base_speech_to_text_node.py new file mode 100644 index 00000000000..13b954e4622 --- /dev/null +++ b/apps/application/flow/step_node/speech_to_text_step_node/impl/base_speech_to_text_node.py @@ -0,0 +1,72 @@ +# coding=utf-8 +import os +import tempfile +import time +import io +from typing import List, Dict + +from django.db.models import QuerySet +from pydub import AudioSegment +from concurrent.futures import ThreadPoolExecutor +from application.flow.i_step_node import NodeResult, INode +from application.flow.step_node.speech_to_text_step_node.i_speech_to_text_node import ISpeechToTextNode +from common.util.common import split_and_transcribe, any_to_mp3 +from dataset.models import File +from setting.models_provider.tools import get_model_instance_by_model_user_id + +class BaseSpeechToTextNode(ISpeechToTextNode): + + def save_context(self, details, workflow_manage): + self.context['answer'] = details.get('answer') + if self.node_params.get('is_result', False): + self.answer_text = details.get('answer') + + def execute(self, stt_model_id, chat_id, audio, **kwargs) -> NodeResult: + stt_model = get_model_instance_by_model_user_id(stt_model_id, self.flow_params_serializer.data.get('user_id')) + audio_list = audio + self.context['audio_list'] = audio + + def process_audio_item(audio_item, model): + file = QuerySet(File).filter(id=audio_item['file_id']).first() + # 根据file_name 吧文件转成mp3格式 + file_format = file.file_name.split('.')[-1] + with tempfile.NamedTemporaryFile(delete=False, suffix=f'.{file_format}') as temp_file: + temp_file.write(file.get_byte().tobytes()) + temp_file_path = temp_file.name + with tempfile.NamedTemporaryFile(delete=False, suffix='.mp3') as temp_amr_file: + temp_mp3_path = temp_amr_file.name + any_to_mp3(temp_file_path, temp_mp3_path) + try: + transcription = split_and_transcribe(temp_mp3_path, model) + return {file.file_name: transcription} + finally: + os.remove(temp_file_path) + os.remove(temp_mp3_path) + + def process_audio_items(audio_list, model): + with ThreadPoolExecutor(max_workers=5) as executor: + results = list(executor.map(lambda item: process_audio_item(item, model), audio_list)) + return results + + result = process_audio_items(audio_list, stt_model) + content = [] + result_content = [] + for item in result: + for key, value in item.items(): + content.append(f'### {key}\n{value}') + result_content.append(value) + return NodeResult({'answer': '\n'.join(result_content), 'result': 
'\n'.join(result_content), + 'content': content}, {}) + + def get_details(self, index: int, **kwargs): + return { + 'name': self.node.properties.get('stepName'), + "index": index, + 'run_time': self.context.get('run_time'), + 'answer': self.context.get('answer'), + 'content': self.context.get('content'), + 'type': self.node.type, + 'status': self.status, + 'err_message': self.err_message, + 'audio_list': self.context.get('audio_list'), + } diff --git a/apps/application/flow/step_node/start_node/__init__.py b/apps/application/flow/step_node/start_node/__init__.py new file mode 100644 index 00000000000..98a1afcd904 --- /dev/null +++ b/apps/application/flow/step_node/start_node/__init__.py @@ -0,0 +1,9 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py + @date:2024/6/11 15:30 + @desc: +""" +from .impl import * diff --git a/apps/application/flow/step_node/start_node/i_start_node.py b/apps/application/flow/step_node/start_node/i_start_node.py new file mode 100644 index 00000000000..41d73f21811 --- /dev/null +++ b/apps/application/flow/step_node/start_node/i_start_node.py @@ -0,0 +1,20 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: i_start_node.py + @date:2024/6/3 16:54 + @desc: +""" + +from application.flow.i_step_node import INode, NodeResult + + +class IStarNode(INode): + type = 'start-node' + + def _run(self): + return self.execute(**self.flow_params_serializer.data) + + def execute(self, question, **kwargs) -> NodeResult: + pass diff --git a/apps/application/flow/step_node/start_node/impl/__init__.py b/apps/application/flow/step_node/start_node/impl/__init__.py new file mode 100644 index 00000000000..b68a92d021f --- /dev/null +++ b/apps/application/flow/step_node/start_node/impl/__init__.py @@ -0,0 +1,9 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py + @date:2024/6/11 15:36 + @desc: +""" +from .base_start_node import BaseStartStepNode diff --git a/apps/application/flow/step_node/start_node/impl/base_start_node.py b/apps/application/flow/step_node/start_node/impl/base_start_node.py new file mode 100644 index 00000000000..24b9684714e --- /dev/null +++ b/apps/application/flow/step_node/start_node/impl/base_start_node.py @@ -0,0 +1,92 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: base_start_node.py + @date:2024/6/3 17:17 + @desc: +""" +import time +from datetime import datetime +from typing import List, Type + +from rest_framework import serializers + +from application.flow.i_step_node import NodeResult +from application.flow.step_node.start_node.i_start_node import IStarNode + + +def get_default_global_variable(input_field_list: List): + return {item.get('variable'): item.get('default_value') for item in input_field_list if + item.get('default_value', None) is not None} + + +def get_global_variable(node): + history_chat_record = node.flow_params_serializer.data.get('history_chat_record', []) + history_context = [{'question': chat_record.problem_text, 'answer': chat_record.answer_text} for chat_record in + history_chat_record] + chat_id = node.flow_params_serializer.data.get('chat_id') + return {'time': datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 'start_time': time.time(), + 'history_context': history_context, 'chat_id': str(chat_id), **node.workflow_manage.form_data} + + +class BaseStartStepNode(IStarNode): + def save_context(self, details, workflow_manage): + base_node = self.workflow_manage.get_base_node() + default_global_variable = get_default_global_variable(base_node.properties.get('input_field_list', 
[])) + workflow_variable = {**default_global_variable, **get_global_variable(self)} + self.context['question'] = details.get('question') + self.context['run_time'] = details.get('run_time') + self.context['document'] = details.get('document_list') + self.context['image'] = details.get('image_list') + self.context['audio'] = details.get('audio_list') + self.context['other'] = details.get('other_list') + self.status = details.get('status') + self.err_message = details.get('err_message') + for key, value in workflow_variable.items(): + workflow_manage.context[key] = value + for item in details.get('global_fields', []): + workflow_manage.context[item.get('key')] = item.get('value') + + def get_node_params_serializer_class(self) -> Type[serializers.Serializer]: + pass + + def execute(self, question, **kwargs) -> NodeResult: + base_node = self.workflow_manage.get_base_node() + default_global_variable = get_default_global_variable(base_node.properties.get('input_field_list', [])) + workflow_variable = {**default_global_variable, **get_global_variable(self)} + """ + 开始节点 初始化全局变量 + """ + node_variable = { + 'question': question, + 'image': self.workflow_manage.image_list, + 'document': self.workflow_manage.document_list, + 'audio': self.workflow_manage.audio_list, + 'other': self.workflow_manage.other_list, + } + return NodeResult(node_variable, workflow_variable) + + def get_details(self, index: int, **kwargs): + global_fields = [] + for field in self.node.properties.get('config')['globalFields']: + key = field['value'] + global_fields.append({ + 'label': field['label'], + 'key': key, + 'value': self.workflow_manage.context[key] if key in self.workflow_manage.context else '' + }) + return { + 'name': self.node.properties.get('stepName'), + "index": index, + "question": self.context.get('question'), + 'run_time': self.context.get('run_time'), + 'type': self.node.type, + 'status': self.status, + 'err_message': self.err_message, + 'image_list': self.context.get('image'), + 'document_list': self.context.get('document'), + 'audio_list': self.context.get('audio'), + 'other_list': self.context.get('other'), + 'global_fields': global_fields + } diff --git a/apps/application/flow/step_node/text_to_speech_step_node/__init__.py b/apps/application/flow/step_node/text_to_speech_step_node/__init__.py new file mode 100644 index 00000000000..f3feecc9ce2 --- /dev/null +++ b/apps/application/flow/step_node/text_to_speech_step_node/__init__.py @@ -0,0 +1,3 @@ +# coding=utf-8 + +from .impl import * diff --git a/apps/application/flow/step_node/text_to_speech_step_node/i_text_to_speech_node.py b/apps/application/flow/step_node/text_to_speech_step_node/i_text_to_speech_node.py new file mode 100644 index 00000000000..68b53ea92db --- /dev/null +++ b/apps/application/flow/step_node/text_to_speech_step_node/i_text_to_speech_node.py @@ -0,0 +1,36 @@ +# coding=utf-8 + +from typing import Type + +from rest_framework import serializers + +from application.flow.i_step_node import INode, NodeResult +from common.util.field_message import ErrMessage +from django.utils.translation import gettext_lazy as _ + + +class TextToSpeechNodeSerializer(serializers.Serializer): + tts_model_id = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Model id"))) + + is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean(_('Whether to return content'))) + + content_list = serializers.ListField(required=True, error_messages=ErrMessage.list(_("Text content"))) + model_params_setting = 
serializers.DictField(required=False,
+                                                   error_messages=ErrMessage.integer(_("Model parameter settings")))
+
+
+class ITextToSpeechNode(INode):
+    type = 'text-to-speech-node'
+
+    def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
+        return TextToSpeechNodeSerializer
+
+    def _run(self):
+        content = self.workflow_manage.get_reference_field(self.node_params_serializer.data.get('content_list')[0],
+                                                           self.node_params_serializer.data.get('content_list')[1:])
+        return self.execute(content=content, **self.node_params_serializer.data, **self.flow_params_serializer.data)
+
+    def execute(self, tts_model_id, chat_id,
+                content, model_params_setting=None,
+                **kwargs) -> NodeResult:
+        pass diff --git a/apps/application/flow/step_node/text_to_speech_step_node/impl/__init__.py b/apps/application/flow/step_node/text_to_speech_step_node/impl/__init__.py new file mode 100644 index 00000000000..385b9718f6e --- /dev/null +++ b/apps/application/flow/step_node/text_to_speech_step_node/impl/__init__.py @@ -0,0 +1,3 @@ +# coding=utf-8 +
+from .base_text_to_speech_node import BaseTextToSpeechNode diff --git a/apps/application/flow/step_node/text_to_speech_step_node/impl/base_text_to_speech_node.py b/apps/application/flow/step_node/text_to_speech_step_node/impl/base_text_to_speech_node.py new file mode 100644 index 00000000000..97044729596 --- /dev/null +++ b/apps/application/flow/step_node/text_to_speech_step_node/impl/base_text_to_speech_node.py @@ -0,0 +1,76 @@ +# coding=utf-8 +import io
+import mimetypes
+
+from django.core.files.uploadedfile import InMemoryUploadedFile
+
+from application.flow.i_step_node import NodeResult, INode
+from application.flow.step_node.text_to_speech_step_node.i_text_to_speech_node import ITextToSpeechNode
+from dataset.models import File
+from dataset.serializers.file_serializers import FileSerializer
+from setting.models_provider.tools import get_model_instance_by_model_user_id
+
+
+def bytes_to_uploaded_file(file_bytes, file_name="generated_audio.mp3"):
+    content_type, _ = mimetypes.guess_type(file_name)
+    if content_type is None:
+        # Fall back to a generic binary content type when the name cannot be recognized
+        content_type = "application/octet-stream"
+    # Wrap the raw bytes in an in-memory stream
+    file_stream = io.BytesIO(file_bytes)
+
+    # File size in bytes
+    file_size = len(file_bytes)
+
+    uploaded_file = InMemoryUploadedFile(
+        file=file_stream,
+        field_name=None,
+        name=file_name,
+        content_type=content_type,
+        size=file_size,
+        charset=None,
+    )
+    return uploaded_file
+
+
+class BaseTextToSpeechNode(ITextToSpeechNode):
+    def save_context(self, details, workflow_manage):
+        self.context['answer'] = details.get('answer')
+        if self.node_params.get('is_result', False):
+            self.answer_text = details.get('answer')
+
+    def execute(self, tts_model_id, chat_id,
+                content, model_params_setting=None,
+                **kwargs) -> NodeResult:
+        self.context['content'] = content
+        model = get_model_instance_by_model_user_id(tts_model_id, self.flow_params_serializer.data.get('user_id'),
+                                                    **(model_params_setting or {}))
+        audio_byte = model.text_to_speech(content)
+        # The generated audio file needs to be persisted in the file store
+        file_name = 'generated_audio.mp3'
+        file = bytes_to_uploaded_file(audio_byte, file_name)
+        application = self.workflow_manage.work_flow_post_handler.chat_info.application
+        meta = {
+            'debug': False if application.id else True,
+            'chat_id': chat_id,
+            'application_id': str(application.id) if application.id else None,
+        }
+        file_url = FileSerializer(data={'file': file, 'meta': meta}).upload()
+        # Build an <audio> tag whose src attribute points at the uploaded file
+        audio_label = f'<audio src="{file_url}" controls></audio>'
+        file_id = file_url.split('/')[-1]
+        audio_list = [{'file_id': file_id, 'file_name': file_name, 'url': file_url}]
+        return NodeResult({'answer': audio_label, 'result': audio_list}, {})
+
+    def get_details(self, index: int, **kwargs):
+        return {
+            'name': self.node.properties.get('stepName'),
+            "index": index,
+            'run_time': self.context.get('run_time'),
+            'type': self.node.type,
+            'status': self.status,
+            'content': self.context.get('content'),
+            'err_message': self.err_message,
+            'answer': self.context.get('answer'),
+        } diff --git a/apps/application/flow/step_node/variable_assign_node/__init__.py b/apps/application/flow/step_node/variable_assign_node/__init__.py new file mode 100644 index 00000000000..2d231e6066d --- /dev/null +++ b/apps/application/flow/step_node/variable_assign_node/__init__.py @@ -0,0 +1,3 @@ +# coding=utf-8 +
+from .impl import * \ No newline at end of file diff --git a/apps/application/flow/step_node/variable_assign_node/i_variable_assign_node.py b/apps/application/flow/step_node/variable_assign_node/i_variable_assign_node.py new file mode 100644 index 00000000000..e4594183f35 --- /dev/null +++ b/apps/application/flow/step_node/variable_assign_node/i_variable_assign_node.py @@ -0,0 +1,27 @@ +# coding=utf-8 +
+from typing import Type
+
+from django.utils.translation import gettext_lazy as _
+from rest_framework import serializers
+
+from application.flow.i_step_node import INode, NodeResult
+from common.util.field_message import ErrMessage
+
+
+class VariableAssignNodeParamsSerializer(serializers.Serializer):
+    variable_list = serializers.ListField(required=True,
+                                          error_messages=ErrMessage.list(_("Reference Field")))
+
+
+class IVariableAssignNode(INode):
+    type = 'variable-assign-node'
+
+    def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
+        return VariableAssignNodeParamsSerializer
+
+    def _run(self):
+        return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)
+
+    def execute(self, variable_list, stream, **kwargs) -> NodeResult:
+        pass diff --git a/apps/application/flow/step_node/variable_assign_node/impl/__init__.py b/apps/application/flow/step_node/variable_assign_node/impl/__init__.py new file mode 100644 index 00000000000..7585cdd8fe4 --- /dev/null +++ b/apps/application/flow/step_node/variable_assign_node/impl/__init__.py @@ -0,0 +1,9 @@ +# coding=utf-8 +"""
+    @project: maxkb
+    @Author:虎
+    @file: __init__.py
+    @date:2024/6/11 17:49
+    @desc:
+"""
+from .base_variable_assign_node import * \ No newline at end of file diff --git a/apps/application/flow/step_node/variable_assign_node/impl/base_variable_assign_node.py b/apps/application/flow/step_node/variable_assign_node/impl/base_variable_assign_node.py new file mode 100644 index 00000000000..ce2906e6293 --- /dev/null +++ b/apps/application/flow/step_node/variable_assign_node/impl/base_variable_assign_node.py @@ -0,0 +1,65 @@ +# coding=utf-8 +import json
+from typing import List
+
+from application.flow.i_step_node import NodeResult
+from application.flow.step_node.variable_assign_node.i_variable_assign_node import IVariableAssignNode
+
+
+class BaseVariableAssignNode(IVariableAssignNode):
+    def save_context(self, details, workflow_manage):
+        self.context['variable_list'] = details.get('variable_list')
+        self.context['result_list'] = details.get('result_list')
+
+    def execute(self, variable_list, stream, **kwargs) -> NodeResult:
+        result_list = []
+        for variable in variable_list:
+            if 'fields' not in variable:
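+                # Added note: entries without a 'fields' reference are malformed and skipped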
+ continue + if 'global' == variable['fields'][0]: + result = { + 'name': variable['name'], + 'input_value': self.get_reference_content(variable['fields']), + } + if variable['source'] == 'custom': + if variable['type'] == 'json': + if isinstance(variable['value'], dict) or isinstance(variable['value'], list): + val = variable['value'] + else: + val = json.loads(variable['value']) + self.workflow_manage.context[variable['fields'][1]] = val + result['output_value'] = variable['value'] = val + elif variable['type'] == 'string': + # 变量解析 例如:{{global.xxx}} + val = self.workflow_manage.generate_prompt(variable['value']) + self.workflow_manage.context[variable['fields'][1]] = val + result['output_value'] = val + else: + val = variable['value'] + self.workflow_manage.context[variable['fields'][1]] = val + result['output_value'] = val + else: + reference = self.get_reference_content(variable['reference']) + self.workflow_manage.context[variable['fields'][1]] = reference + result['output_value'] = reference + result_list.append(result) + + return NodeResult({'variable_list': variable_list, 'result_list': result_list}, {}) + + def get_reference_content(self, fields: List[str]): + return str(self.workflow_manage.get_reference_field( + fields[0], + fields[1:])) + + def get_details(self, index: int, **kwargs): + return { + 'name': self.node.properties.get('stepName'), + "index": index, + 'run_time': self.context.get('run_time'), + 'type': self.node.type, + 'variable_list': self.context.get('variable_list'), + 'result_list': self.context.get('result_list'), + 'status': self.status, + 'err_message': self.err_message + } diff --git a/apps/application/flow/tools.py b/apps/application/flow/tools.py new file mode 100644 index 00000000000..dfbf69b3593 --- /dev/null +++ b/apps/application/flow/tools.py @@ -0,0 +1,191 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: utils.py + @date:2024/6/6 15:15 + @desc: +""" +import json +from typing import Iterator + +from django.http import StreamingHttpResponse +from langchain_core.messages import BaseMessageChunk, BaseMessage + +from application.flow.i_step_node import WorkFlowPostHandler +from common.response import result + + +class Reasoning: + def __init__(self, reasoning_content_start, reasoning_content_end): + self.content = "" + self.reasoning_content = "" + self.all_content = "" + self.reasoning_content_start_tag = reasoning_content_start + self.reasoning_content_end_tag = reasoning_content_end + self.reasoning_content_start_tag_len = len( + reasoning_content_start) if reasoning_content_start is not None else 0 + self.reasoning_content_end_tag_len = len(reasoning_content_end) if reasoning_content_end is not None else 0 + self.reasoning_content_end_tag_prefix = reasoning_content_end[ + 0] if self.reasoning_content_end_tag_len > 0 else '' + self.reasoning_content_is_start = False + self.reasoning_content_is_end = False + self.reasoning_content_chunk = "" + + def get_end_reasoning_content(self): + if not self.reasoning_content_is_start and not self.reasoning_content_is_end: + r = {'content': self.all_content, 'reasoning_content': ''} + self.reasoning_content_chunk = "" + return r + if self.reasoning_content_is_start and not self.reasoning_content_is_end: + r = {'content': '', 'reasoning_content': self.reasoning_content_chunk} + self.reasoning_content_chunk = "" + return r + return {'content': '', 'reasoning_content': ''} + + def get_reasoning_content(self, chunk): + # 如果没有开始思考过程标签那么就全是结果 + if self.reasoning_content_start_tag is None or 
len(self.reasoning_content_start_tag) == 0:
+            self.content += chunk.content
+            return {'content': chunk.content, 'reasoning_content': ''}
+        # Without an end tag for the reasoning section, every chunk counts as reasoning content
+        if self.reasoning_content_end_tag is None or len(self.reasoning_content_end_tag) == 0:
+            return {'content': '', 'reasoning_content': chunk.content}
+        self.all_content += chunk.content
+        if not self.reasoning_content_is_start and len(self.all_content) >= self.reasoning_content_start_tag_len:
+            if self.all_content.startswith(self.reasoning_content_start_tag):
+                self.reasoning_content_is_start = True
+                self.reasoning_content_chunk = self.all_content[self.reasoning_content_start_tag_len:]
+            else:
+                if not self.reasoning_content_is_end:
+                    self.reasoning_content_is_end = True
+                    self.content += self.all_content
+                    return {'content': self.all_content, 'reasoning_content': ''}
+        else:
+            if self.reasoning_content_is_start:
+                self.reasoning_content_chunk += chunk.content
+        reasoning_content_end_tag_prefix_index = self.reasoning_content_chunk.find(
+            self.reasoning_content_end_tag_prefix)
+        if self.reasoning_content_is_end:
+            self.content += chunk.content
+            return {'content': chunk.content, 'reasoning_content': ''}
+        # Does the buffered text contain the end tag?
+        if reasoning_content_end_tag_prefix_index > -1:
+            if len(self.reasoning_content_chunk) - reasoning_content_end_tag_prefix_index >= self.reasoning_content_end_tag_len:
+                reasoning_content_end_tag_index = self.reasoning_content_chunk.find(self.reasoning_content_end_tag)
+                if reasoning_content_end_tag_index > -1:
+                    reasoning_content_chunk = self.reasoning_content_chunk[0:reasoning_content_end_tag_index]
+                    content_chunk = self.reasoning_content_chunk[
+                                    reasoning_content_end_tag_index + self.reasoning_content_end_tag_len:]
+                    self.reasoning_content += reasoning_content_chunk
+                    self.content += content_chunk
+                    self.reasoning_content_chunk = ""
+                    self.reasoning_content_is_end = True
+                    return {'content': content_chunk, 'reasoning_content': reasoning_content_chunk}
+                else:
+                    reasoning_content_chunk = self.reasoning_content_chunk[0:reasoning_content_end_tag_prefix_index + 1]
+                    self.reasoning_content_chunk = self.reasoning_content_chunk.replace(reasoning_content_chunk, '')
+                    self.reasoning_content += reasoning_content_chunk
+                    return {'content': '', 'reasoning_content': reasoning_content_chunk}
+            else:
+                return {'content': '', 'reasoning_content': ''}
+
+        else:
+            if self.reasoning_content_is_end:
+                self.content += chunk.content
+                return {'content': chunk.content, 'reasoning_content': ''}
+            else:
+                # No end-tag prefix buffered: flush the pending reasoning content
+                result = {'content': '', 'reasoning_content': self.reasoning_content_chunk}
+                self.reasoning_content += self.reasoning_content_chunk
+                self.reasoning_content_chunk = ""
+                return result
+
+
+def event_content(chat_id, chat_record_id, response, workflow,
+                  write_context,
+                  post_handler: WorkFlowPostHandler):
+    """
+    Handle streaming output.
+    @param chat_id:        conversation id
+    @param chat_record_id: chat record id
+    @param response:       response data
+    @param workflow:       workflow manager
+    @param write_context   write the node context
+    @param post_handler:   post-processing handler
+    """
+    answer = ''
+    try:
+        for chunk in response:
+            answer += chunk.content
+            yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
+                                         'content': chunk.content, 'is_end': False}, ensure_ascii=False) + "\n\n"
+        write_context(answer, 200)
+        post_handler.handler(chat_id, chat_record_id, answer, workflow)
+        yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
+                                     'content': '', 'is_end': True}, ensure_ascii=False) + "\n\n"
+    except Exception as e:
+        answer = str(e)
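+        # Added note: on failure the error text itself becomes the final streamed answer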
+ write_context(answer, 500) + post_handler.handler(chat_id, chat_record_id, answer, workflow) + yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True, + 'content': answer, 'is_end': True}, ensure_ascii=False) + "\n\n" + + +def to_stream_response(chat_id, chat_record_id, response: Iterator[BaseMessageChunk], workflow, write_context, + post_handler): + """ + 将结果转换为服务流输出 + @param chat_id: 会话id + @param chat_record_id: 对话记录id + @param response: 响应数据 + @param workflow: 工作流管理器 + @param write_context 写入节点上下文 + @param post_handler: 后置处理器 + @return: 响应 + """ + r = StreamingHttpResponse( + streaming_content=event_content(chat_id, chat_record_id, response, workflow, write_context, post_handler), + content_type='text/event-stream;charset=utf-8', + charset='utf-8') + + r['Cache-Control'] = 'no-cache' + return r + + +def to_response(chat_id, chat_record_id, response: BaseMessage, workflow, write_context, + post_handler: WorkFlowPostHandler): + """ + 将结果转换为服务输出 + + @param chat_id: 会话id + @param chat_record_id: 对话记录id + @param response: 响应数据 + @param workflow: 工作流管理器 + @param write_context 写入节点上下文 + @param post_handler: 后置处理器 + @return: 响应 + """ + answer = response.content + write_context(answer) + post_handler.handler(chat_id, chat_record_id, answer, workflow) + return result.success({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True, + 'content': answer, 'is_end': True}) + + +def to_response_simple(chat_id, chat_record_id, response: BaseMessage, workflow, + post_handler: WorkFlowPostHandler): + answer = response.content + post_handler.handler(chat_id, chat_record_id, answer, workflow) + return result.success({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True, + 'content': answer, 'is_end': True}) + + +def to_stream_response_simple(stream_event): + r = StreamingHttpResponse( + streaming_content=stream_event, + content_type='text/event-stream;charset=utf-8', + charset='utf-8') + + r['Cache-Control'] = 'no-cache' + return r diff --git a/apps/application/flow/workflow_manage.py b/apps/application/flow/workflow_manage.py new file mode 100644 index 00000000000..0f7bc9c7576 --- /dev/null +++ b/apps/application/flow/workflow_manage.py @@ -0,0 +1,827 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: workflow_manage.py + @date:2024/1/9 17:40 + @desc: +""" +import concurrent +import json +import threading +import traceback +from concurrent.futures import ThreadPoolExecutor +from functools import reduce +from typing import List, Dict + +from django.db import close_old_connections +from django.db.models import QuerySet +from django.utils import translation +from django.utils.translation import get_language +from django.utils.translation import gettext as _ +from langchain_core.prompts import PromptTemplate +from rest_framework import status +from rest_framework.exceptions import ErrorDetail, ValidationError + +from application.flow import tools +from application.flow.i_step_node import INode, WorkFlowPostHandler, NodeResult +from application.flow.step_node import get_node +from common.exception.app_exception import AppApiException +from common.handle.base_to_response import BaseToResponse +from common.handle.impl.response.system_to_response import SystemToResponse +from function_lib.models.function import FunctionLib +from setting.models import Model +from setting.models_provider import get_model_credential + +executor = ThreadPoolExecutor(max_workers=200) + + +class Edge: + def __init__(self, _id: str, _type: str, 
sourceNodeId: str, targetNodeId: str, **keywords): + self.id = _id + self.type = _type + self.sourceNodeId = sourceNodeId + self.targetNodeId = targetNodeId + for keyword in keywords: + self.__setattr__(keyword, keywords.get(keyword)) + + +class Node: + def __init__(self, _id: str, _type: str, x: int, y: int, properties: dict, **kwargs): + self.id = _id + self.type = _type + self.x = x + self.y = y + self.properties = properties + for keyword in kwargs: + self.__setattr__(keyword, kwargs.get(keyword)) + + +end_nodes = ['ai-chat-node', 'reply-node', 'function-node', 'function-lib-node', 'application-node', + 'image-understand-node', 'speech-to-text-node', 'text-to-speech-node', 'image-generate-node'] + + +class Flow: + def __init__(self, nodes: List[Node], edges: List[Edge]): + self.nodes = nodes + self.edges = edges + + @staticmethod + def new_instance(flow_obj: Dict): + nodes = flow_obj.get('nodes') + edges = flow_obj.get('edges') + nodes = [Node(node.get('id'), node.get('type'), **node) + for node in nodes] + edges = [Edge(edge.get('id'), edge.get('type'), **edge) for edge in edges] + return Flow(nodes, edges) + + def get_start_node(self): + start_node_list = [node for node in self.nodes if node.id == 'start-node'] + return start_node_list[0] + + def get_search_node(self): + return [node for node in self.nodes if node.type == 'search-dataset-node'] + + def is_valid(self): + """ + 校验工作流数据 + """ + self.is_valid_model_params() + self.is_valid_start_node() + self.is_valid_base_node() + self.is_valid_work_flow() + + @staticmethod + def is_valid_node_params(node: Node): + get_node(node.type)(node, None, None) + + def is_valid_node(self, node: Node): + self.is_valid_node_params(node) + if node.type == 'condition-node': + branch_list = node.properties.get('node_data').get('branch') + for branch in branch_list: + source_anchor_id = f"{node.id}_{branch.get('id')}_right" + edge_list = [edge for edge in self.edges if edge.sourceAnchorId == source_anchor_id] + if len(edge_list) == 0: + raise AppApiException(500, + _('The branch {branch} of the {node} node needs to be connected').format( + node=node.properties.get("stepName"), branch=branch.get("type"))) + + else: + edge_list = [edge for edge in self.edges if edge.sourceNodeId == node.id] + if len(edge_list) == 0 and not end_nodes.__contains__(node.type): + raise AppApiException(500, _("{node} Nodes cannot be considered as end nodes").format( + node=node.properties.get("stepName"))) + + def get_next_nodes(self, node: Node): + edge_list = [edge for edge in self.edges if edge.sourceNodeId == node.id] + node_list = reduce(lambda x, y: [*x, *y], + [[node for node in self.nodes if node.id == edge.targetNodeId] for edge in edge_list], + []) + if len(node_list) == 0 and not end_nodes.__contains__(node.type): + raise AppApiException(500, + _("The next node that does not exist")) + return node_list + + def is_valid_work_flow(self, up_node=None): + if up_node is None: + up_node = self.get_start_node() + self.is_valid_node(up_node) + next_nodes = self.get_next_nodes(up_node) + for next_node in next_nodes: + self.is_valid_work_flow(next_node) + + def is_valid_start_node(self): + start_node_list = [node for node in self.nodes if node.id == 'start-node'] + if len(start_node_list) == 0: + raise AppApiException(500, _('The starting node is required')) + if len(start_node_list) > 1: + raise AppApiException(500, _('There can only be one starting node')) + + def is_valid_model_params(self): + node_list = [node for node in self.nodes if (node.type == 'ai-chat-node' or 
node.type == 'question-node')]
+        for node in node_list:
+            model = QuerySet(Model).filter(id=node.properties.get('node_data', {}).get('model_id')).first()
+            if model is None:
+                raise ValidationError(ErrorDetail(
+                    _('The node {node} model does not exist').format(node=node.properties.get("stepName"))))
+            credential = get_model_credential(model.provider, model.model_type, model.model_name)
+            model_params_setting = node.properties.get('node_data', {}).get('model_params_setting')
+            model_params_setting_form = credential.get_model_params_setting_form(
+                model.model_name)
+            if model_params_setting is None:
+                model_params_setting = model_params_setting_form.get_default_form_data()
+                node.properties.get('node_data', {})['model_params_setting'] = model_params_setting
+            if node.properties.get('status', 200) != 200:
+                raise ValidationError(
+                    ErrorDetail(_("Node {node} is unavailable").format(node=node.properties.get("stepName"))))
+        node_list = [node for node in self.nodes if (node.type == 'function-lib-node')]
+        for node in node_list:
+            function_lib_id = node.properties.get('node_data', {}).get('function_lib_id')
+            if function_lib_id is None:
+                raise ValidationError(ErrorDetail(
+                    _('The library ID of node {node} cannot be empty').format(node=node.properties.get("stepName"))))
+            f_lib = QuerySet(FunctionLib).filter(id=function_lib_id).first()
+            if f_lib is None:
+                raise ValidationError(ErrorDetail(_("The function library for node {node} is not available").format(
+                    node=node.properties.get("stepName"))))
+
+    def is_valid_base_node(self):
+        base_node_list = [node for node in self.nodes if node.id == 'base-node']
+        if len(base_node_list) == 0:
+            raise AppApiException(500, _('Basic information node is required'))
+        if len(base_node_list) > 1:
+            raise AppApiException(500, _('There can only be one basic information node'))
+
+
+class NodeResultFuture:
+    def __init__(self, r, e, status=200):
+        self.r = r
+        self.e = e
+        self.status = status
+
+    def result(self):
+        if self.status == 200:
+            return self.r
+        else:
+            raise self.e
+
+
+def await_result(result, timeout=1):
+    try:
+        result.result(timeout)
+        return False
+    except Exception as e:
+        return True
+
+
+class NodeChunkManage:
+
+    def __init__(self, work_flow):
+        self.node_chunk_list = []
+        self.current_node_chunk = None
+        self.work_flow = work_flow
+
+    def add_node_chunk(self, node_chunk):
+        self.node_chunk_list.append(node_chunk)
+
+    def contains(self, node_chunk):
+        return self.node_chunk_list.__contains__(node_chunk)
+
+    def pop(self):
+        if self.current_node_chunk is None:
+            try:
+                current_node_chunk = self.node_chunk_list.pop(0)
+                self.current_node_chunk = current_node_chunk
+            except IndexError as e:
+                pass
+        if self.current_node_chunk is not None:
+            try:
+                chunk = self.current_node_chunk.chunk_list.pop(0)
+                return chunk
+            except IndexError as e:
+                if self.current_node_chunk.is_end():
+                    self.current_node_chunk = None
+                    if self.work_flow.answer_is_not_empty():
+                        chunk = self.work_flow.base_to_response.to_stream_chunk_response(
+                            self.work_flow.params['chat_id'],
+                            self.work_flow.params['chat_record_id'],
+                            '\n\n', False, 0, 0)
+                        self.work_flow.append_answer('\n\n')
+                        return chunk
+                    return self.pop()
+        return None
+
+
+class WorkflowManage:
+    def __init__(self, flow: Flow, params, work_flow_post_handler: WorkFlowPostHandler,
+                 base_to_response: BaseToResponse = SystemToResponse(), form_data=None, image_list=None,
+                 document_list=None,
+                 audio_list=None,
+                 other_list=None,
+                 start_node_id=None,
+                 start_node_data=None, chat_record=None, child_node=None):
+        if
form_data is None: + form_data = {} + if image_list is None: + image_list = [] + if document_list is None: + document_list = [] + if audio_list is None: + audio_list = [] + if other_list is None: + other_list = [] + self.start_node_id = start_node_id + self.start_node = None + self.form_data = form_data + self.image_list = image_list + self.document_list = document_list + self.audio_list = audio_list + self.other_list = other_list + self.params = params + self.flow = flow + self.context = {} + self.node_chunk_manage = NodeChunkManage(self) + self.work_flow_post_handler = work_flow_post_handler + self.current_node = None + self.current_result = None + self.answer = "" + self.answer_list = [''] + self.status = 200 + self.base_to_response = base_to_response + self.chat_record = chat_record + self.child_node = child_node + self.future_list = [] + self.lock = threading.Lock() + self.field_list = [] + self.global_field_list = [] + self.init_fields() + if start_node_id is not None: + self.load_node(chat_record, start_node_id, start_node_data) + else: + self.node_context = [] + + def init_fields(self): + field_list = [] + global_field_list = [] + for node in self.flow.nodes: + properties = node.properties + node_name = properties.get('stepName') + node_id = node.id + node_config = properties.get('config') + if node_config is not None: + fields = node_config.get('fields') + if fields is not None: + for field in fields: + field_list.append({**field, 'node_id': node_id, 'node_name': node_name}) + global_fields = node_config.get('globalFields') + if global_fields is not None: + for global_field in global_fields: + global_field_list.append({**global_field, 'node_id': node_id, 'node_name': node_name}) + field_list.sort(key=lambda f: len(f.get('node_name')), reverse=True) + global_field_list.sort(key=lambda f: len(f.get('node_name')), reverse=True) + self.field_list = field_list + self.global_field_list = global_field_list + + def append_answer(self, content): + self.answer += content + self.answer_list[-1] += content + + def answer_is_not_empty(self): + return len(self.answer_list[-1]) > 0 + + def load_node(self, chat_record, start_node_id, start_node_data): + self.node_context = [] + self.answer = chat_record.answer_text + self.answer_list = chat_record.answer_text_list + self.answer_list.append('') + for node_details in sorted(chat_record.details.values(), key=lambda d: d.get('index')): + node_id = node_details.get('node_id') + if node_details.get('runtime_node_id') == start_node_id: + def get_node_params(n): + is_result = False + if n.type == 'application-node': + is_result = True + return {**n.properties.get('node_data'), 'form_data': start_node_data, 'node_data': start_node_data, + 'child_node': self.child_node, 'is_result': is_result} + + self.start_node = self.get_node_cls_by_id(node_id, node_details.get('up_node_id_list'), + get_node_params=get_node_params) + self.start_node.valid_args( + {**self.start_node.node_params, 'form_data': start_node_data}, self.start_node.workflow_params) + if self.start_node.type == 'application-node': + application_node_dict = node_details.get('application_node_dict', {}) + self.start_node.context['application_node_dict'] = application_node_dict + self.node_context.append(self.start_node) + continue + + node_id = node_details.get('node_id') + node = self.get_node_cls_by_id(node_id, node_details.get('up_node_id_list')) + node.valid_args(node.node_params, node.workflow_params) + node.save_context(node_details, self) + node.node_chunk.end() + 
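+            # Added note: a restored node's chunk stream is already complete, so it is
+            # closed before the node is re-registered below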
self.node_context.append(node) + + def run(self): + close_old_connections() + language = get_language() + if self.params.get('stream'): + return self.run_stream(self.start_node, None, language) + return self.run_block(language) + + def run_block(self, language='zh'): + """ + 非流式响应 + @return: 结果 + """ + self.run_chain_async(None, None, language) + while self.is_run(): + pass + details = self.get_runtime_details() + message_tokens = sum([row.get('message_tokens') for row in details.values() if + 'message_tokens' in row and row.get('message_tokens') is not None]) + answer_tokens = sum([row.get('answer_tokens') for row in details.values() if + 'answer_tokens' in row and row.get('answer_tokens') is not None]) + answer_text_list = self.get_answer_text_list() + answer_text = '\n\n'.join( + '\n\n'.join([a.get('content') for a in answer]) for answer in + answer_text_list) + answer_list = reduce(lambda pre, _n: [*pre, *_n], answer_text_list, []) + self.work_flow_post_handler.handler(self.params['chat_id'], self.params['chat_record_id'], + answer_text, + self) + return self.base_to_response.to_block_response(self.params['chat_id'], + self.params['chat_record_id'], answer_text, True + , message_tokens, answer_tokens, + _status=status.HTTP_200_OK if self.status == 200 else status.HTTP_500_INTERNAL_SERVER_ERROR, + other_params={'answer_list': answer_list}) + + def run_stream(self, current_node, node_result_future, language='zh'): + """ + 流式响应 + @return: + """ + self.run_chain_async(current_node, node_result_future, language) + return tools.to_stream_response_simple(self.await_result()) + + def is_run(self, timeout=0.5): + future_list_len = len(self.future_list) + try: + r = concurrent.futures.wait(self.future_list, timeout) + if len(r.not_done) > 0: + return True + else: + if future_list_len == len(self.future_list): + return False + else: + return True + except Exception as e: + return True + + def await_result(self): + try: + while self.is_run(): + while True: + chunk = self.node_chunk_manage.pop() + if chunk is not None: + yield chunk + else: + break + while True: + chunk = self.node_chunk_manage.pop() + if chunk is None: + break + yield chunk + finally: + while self.is_run(): + pass + details = self.get_runtime_details() + message_tokens = sum([row.get('message_tokens') for row in details.values() if + 'message_tokens' in row and row.get('message_tokens') is not None]) + answer_tokens = sum([row.get('answer_tokens') for row in details.values() if + 'answer_tokens' in row and row.get('answer_tokens') is not None]) + self.work_flow_post_handler.handler(self.params['chat_id'], self.params['chat_record_id'], + self.answer, + self) + yield self.base_to_response.to_stream_chunk_response(self.params['chat_id'], + self.params['chat_record_id'], + '', + [], + '', True, message_tokens, answer_tokens, {}) + + def run_chain_async(self, current_node, node_result_future, language='zh'): + future = executor.submit(self.run_chain_manage, current_node, node_result_future, language) + self.future_list.append(future) + + def run_chain_manage(self, current_node, node_result_future, language='zh'): + translation.activate(language) + if current_node is None: + start_node = self.get_start_node() + current_node = get_node(start_node.type)(start_node, self.params, self) + self.node_chunk_manage.add_node_chunk(current_node.node_chunk) + # 添加节点 + self.append_node(current_node) + result = self.run_chain(current_node, node_result_future) + if result is None: + return + node_list = self.get_next_node_list(current_node, result) + 
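+        # Added note: fan-out point. A single successor keeps running on this thread;
+        # several successors are submitted to the executor concurrently.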
if len(node_list) == 1: + self.run_chain_manage(node_list[0], None, language) + elif len(node_list) > 1: + sorted_node_run_list = sorted(node_list, key=lambda n: n.node.y) + # 获取到可执行的子节点 + result_list = [{'node': node, 'future': executor.submit(self.run_chain_manage, node, None, language)} for + node in + sorted_node_run_list] + for r in result_list: + self.future_list.append(r.get('future')) + + def run_chain(self, current_node, node_result_future=None): + if node_result_future is None: + node_result_future = self.run_node_future(current_node) + try: + is_stream = self.params.get('stream', True) + result = self.hand_event_node_result(current_node, + node_result_future) if is_stream else self.hand_node_result( + current_node, node_result_future) + return result + except Exception as e: + traceback.print_exc() + return None + + def hand_node_result(self, current_node, node_result_future): + try: + current_result = node_result_future.result() + result = current_result.write_context(current_node, self) + if result is not None: + # 阻塞获取结果 + list(result) + return current_result + except Exception as e: + traceback.print_exc() + self.status = 500 + current_node.get_write_error_context(e) + self.answer += str(e) + finally: + current_node.node_chunk.end() + + def append_node(self, current_node): + for index in range(len(self.node_context)): + n = self.node_context[index] + if current_node.id == n.node.id and current_node.runtime_node_id == n.runtime_node_id: + self.node_context[index] = current_node + return + self.node_context.append(current_node) + + def hand_event_node_result(self, current_node, node_result_future): + runtime_node_id = current_node.runtime_node_id + real_node_id = current_node.runtime_node_id + child_node = {} + view_type = current_node.view_type + try: + current_result = node_result_future.result() + result = current_result.write_context(current_node, self) + if result is not None: + if self.is_result(current_node, current_result): + for r in result: + reasoning_content = '' + content = r + child_node = {} + node_is_end = False + view_type = current_node.view_type + if isinstance(r, dict): + content = r.get('content') + child_node = {'runtime_node_id': r.get('runtime_node_id'), + 'chat_record_id': r.get('chat_record_id') + , 'child_node': r.get('child_node')} + if r.__contains__('real_node_id'): + real_node_id = r.get('real_node_id') + if r.__contains__('node_is_end'): + node_is_end = r.get('node_is_end') + view_type = r.get('view_type') + reasoning_content = r.get('reasoning_content') + chunk = self.base_to_response.to_stream_chunk_response(self.params['chat_id'], + self.params['chat_record_id'], + current_node.id, + current_node.up_node_id_list, + content, False, 0, 0, + {'node_type': current_node.type, + 'runtime_node_id': runtime_node_id, + 'view_type': view_type, + 'child_node': child_node, + 'node_is_end': node_is_end, + 'real_node_id': real_node_id, + 'reasoning_content': reasoning_content}) + current_node.node_chunk.add_chunk(chunk) + chunk = (self.base_to_response + .to_stream_chunk_response(self.params['chat_id'], + self.params['chat_record_id'], + current_node.id, + current_node.up_node_id_list, + '', False, 0, 0, {'node_is_end': True, + 'runtime_node_id': runtime_node_id, + 'node_type': current_node.type, + 'view_type': view_type, + 'child_node': child_node, + 'real_node_id': real_node_id, + 'reasoning_content': ''})) + current_node.node_chunk.add_chunk(chunk) + else: + list(result) + return current_result + except Exception as e: + # 添加节点 + traceback.print_exc() + 
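+            # Added note: surface the node failure to the client as a terminal chunk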
chunk = self.base_to_response.to_stream_chunk_response(self.params['chat_id'], + self.params['chat_record_id'], + current_node.id, + current_node.up_node_id_list, + 'Exception:' + str(e), False, 0, 0, + {'node_is_end': True, + 'runtime_node_id': current_node.runtime_node_id, + 'node_type': current_node.type, + 'view_type': current_node.view_type, + 'child_node': {}, + 'real_node_id': real_node_id}) + current_node.node_chunk.add_chunk(chunk) + current_node.get_write_error_context(e) + self.status = 500 + return None + finally: + current_node.node_chunk.end() + + def run_node_async(self, node): + future = executor.submit(self.run_node, node) + return future + + def run_node_future(self, node): + try: + node.valid_args(node.node_params, node.workflow_params) + result = self.run_node(node) + return NodeResultFuture(result, None, 200) + except Exception as e: + return NodeResultFuture(None, e, 500) + + def run_node(self, node): + result = node.run() + return result + + def is_result(self, current_node, current_node_result): + return current_node.node_params.get('is_result', not self._has_next_node( + current_node, current_node_result)) if current_node.node_params is not None else False + + def get_chunk_content(self, chunk, is_end=False): + return 'data: ' + json.dumps( + {'chat_id': self.params['chat_id'], 'id': self.params['chat_record_id'], 'operate': True, + 'content': chunk, 'is_end': is_end}, ensure_ascii=False) + "\n\n" + + def _has_next_node(self, current_node, node_result: NodeResult | None): + """ + 是否有下一个可运行的节点 + """ + if node_result is not None and node_result.is_assertion_result(): + for edge in self.flow.edges: + if (edge.sourceNodeId == current_node.id and + f"{edge.sourceNodeId}_{node_result.node_variable.get('branch_id')}_right" == edge.sourceAnchorId): + return True + else: + for edge in self.flow.edges: + if edge.sourceNodeId == current_node.id: + return True + + def has_next_node(self, node_result: NodeResult | None): + """ + 是否有下一个可运行的节点 + """ + return self._has_next_node(self.get_start_node() if self.current_node is None else self.current_node, + node_result) + + def get_runtime_details(self): + details_result = {} + for index in range(len(self.node_context)): + node = self.node_context[index] + if self.chat_record is not None and self.chat_record.details is not None: + details = self.chat_record.details.get(node.runtime_node_id) + if details is not None and self.start_node.runtime_node_id != node.runtime_node_id: + details_result[node.runtime_node_id] = details + continue + details = node.get_details(index) + details['node_id'] = node.id + details['up_node_id_list'] = node.up_node_id_list + details['runtime_node_id'] = node.runtime_node_id + details_result[node.runtime_node_id] = details + return details_result + + def get_answer_text_list(self): + result = [] + answer_list = reduce(lambda x, y: [*x, *y], + [n.get_answer_list() for n in self.node_context if n.get_answer_list() is not None], + []) + up_node = None + for index in range(len(answer_list)): + current_answer = answer_list[index] + if len(current_answer.content) > 0: + if up_node is None or current_answer.view_type == 'single_view' or ( + current_answer.view_type == 'many_view' and up_node.view_type == 'single_view'): + result.append([current_answer]) + else: + if len(result) > 0: + exec_index = len(result) - 1 + if isinstance(result[exec_index], list): + result[exec_index].append(current_answer) + else: + result.insert(0, [current_answer]) + up_node = current_answer + if len(result) == 0: + # 如果没有响应 就响应一个空数据 + 
+            return [[]]
+        return [[item.to_dict() for item in r] for r in result]
+
+    def get_next_node(self):
+        """
+        Get the next runnable node
+        """
+        if self.current_node is None:
+            node = self.get_start_node()
+            node_instance = get_node(node.type)(node, self.params, self)
+            return node_instance
+        if self.current_result is not None and self.current_result.is_assertion_result():
+            for edge in self.flow.edges:
+                if (edge.sourceNodeId == self.current_node.id and
+                        f"{edge.sourceNodeId}_{self.current_result.node_variable.get('branch_id')}_right" == edge.sourceAnchorId):
+                    return self.get_node_cls_by_id(edge.targetNodeId)
+        else:
+            for edge in self.flow.edges:
+                if edge.sourceNodeId == self.current_node.id:
+                    return self.get_node_cls_by_id(edge.targetNodeId)
+
+        return None
+
+    @staticmethod
+    def dependent_node(up_node_id, node):
+        if not node.node_chunk.is_end():
+            return False
+        if node.id == up_node_id:
+            if node.type == 'form-node':
+                if node.context.get('form_data', None) is not None:
+                    return True
+                return False
+            return True
+        return False
+
+    def dependent_node_been_executed(self, node_id):
+        """
+        Check whether all of a node's dependency nodes have finished executing
+        @param node_id: id of the node to check
+        @return:
+        """
+        up_node_id_list = [edge.sourceNodeId for edge in self.flow.edges if edge.targetNodeId == node_id]
+        return all([any([self.dependent_node(up_node_id, node) for node in self.node_context]) for up_node_id in
+                    up_node_id_list])
+
+    def get_up_node_id_list(self, node_id):
+        up_node_id_list = [edge.sourceNodeId for edge in self.flow.edges if edge.targetNodeId == node_id]
+        return up_node_id_list
+
+    def get_next_node_list(self, current_node, current_node_result):
+        """
+        Get the list of next executable nodes
+        @param current_node: the node that just finished executing
+        @param current_node_result: that node's execution result
+        @return: list of executable nodes
+        """
+        # Check whether execution should be interrupted
+        if current_node_result.is_interrupt_exec(current_node):
+            return []
+        node_list = []
+        if current_node_result is not None and current_node_result.is_assertion_result():
+            for edge in self.flow.edges:
+                if (edge.sourceNodeId == current_node.id and
+                        f"{edge.sourceNodeId}_{current_node_result.node_variable.get('branch_id')}_right" == edge.sourceAnchorId):
+                    next_node = [node for node in self.flow.nodes if node.id == edge.targetNodeId]
+                    if len(next_node) == 0:
+                        continue
+                    if next_node[0].properties.get('condition', "AND") == 'AND':
+                        if self.dependent_node_been_executed(edge.targetNodeId):
+                            node_list.append(
+                                self.get_node_cls_by_id(edge.targetNodeId,
+                                                        [*current_node.up_node_id_list, current_node.node.id]))
+                    else:
+                        node_list.append(
+                            self.get_node_cls_by_id(edge.targetNodeId,
+                                                    [*current_node.up_node_id_list, current_node.node.id]))
+        else:
+            for edge in self.flow.edges:
+                if edge.sourceNodeId == current_node.id:
+                    next_node = [node for node in self.flow.nodes if node.id == edge.targetNodeId]
+                    if len(next_node) == 0:
+                        continue
+                    if next_node[0].properties.get('condition', "AND") == 'AND':
+                        if self.dependent_node_been_executed(edge.targetNodeId):
+                            node_list.append(
+                                self.get_node_cls_by_id(edge.targetNodeId,
+                                                        [*current_node.up_node_id_list, current_node.node.id]))
+                    else:
+                        node_list.append(
+                            self.get_node_cls_by_id(edge.targetNodeId,
+                                                    [*current_node.up_node_id_list, current_node.node.id]))
+        return node_list
+
+    def get_reference_field(self, node_id: str, fields: List[str]):
+        """
+        @param node_id: node id
+        @param fields: field path
+        @return:
+        """
+        if node_id == 'global':
+            return INode.get_field(self.context, fields)
+        else:
+            return self.get_node_by_id(node_id).get_reference_field(fields)
+
+    def get_workflow_content(self):
+        context = {
+            'global': self.context,
+        }
+
+        for node in self.node_context:
+            context[node.id] = node.context
+        return context
+
+    def reset_prompt(self, prompt: str):
+        placeholder = "{}"
+        for field in self.field_list:
+            globeLabel = f"{field.get('node_name')}.{field.get('value')}"
+            globeValue = f"context.get('{field.get('node_id')}',{placeholder}).get('{field.get('value', '')}','')"
+            prompt = prompt.replace(globeLabel, globeValue)
+        for field in self.global_field_list:
+            globeLabel = f"全局变量.{field.get('value')}"
+            globeLabelNew = f"global.{field.get('value')}"
+            globeValue = f"context.get('global').get('{field.get('value', '')}','')"
+            prompt = prompt.replace(globeLabel, globeValue).replace(globeLabelNew, globeValue)
+        return prompt
+
+    def generate_prompt(self, prompt: str):
+        """
+        Format and render a prompt
+        @param prompt: the raw prompt text
+        @return: the rendered prompt
+        """
+        context = self.get_workflow_content()
+        prompt = self.reset_prompt(prompt)
+        prompt_template = PromptTemplate.from_template(prompt, template_format='jinja2')
+        value = prompt_template.format(context=context)
+        return value
+
+    def get_start_node(self):
+        """
+        Get the start node
+        @return:
+        """
+        start_node_list = [node for node in self.flow.nodes if node.type == 'start-node']
+        return start_node_list[0]
+
+    def get_base_node(self):
+        """
+        Get the base node
+        @return:
+        """
+        base_node_list = [node for node in self.flow.nodes if node.type == 'base-node']
+        return base_node_list[0]
+
+    def get_node_cls_by_id(self, node_id, up_node_id_list=None,
+                           get_node_params=lambda node: node.properties.get('node_data')):
+        for node in self.flow.nodes:
+            if node.id == node_id:
+                node_instance = get_node(node.type)(node,
+                                                    self.params, self, up_node_id_list, get_node_params)
+                return node_instance
+        return None
+
+    def get_node_by_id(self, node_id):
+        for node in self.node_context:
+            if node.id == node_id:
+                return node
+        return None
+
+    def get_node_reference(self, reference_address: Dict):
+        node = self.get_node_by_id(reference_address.get('node_id'))
+        return node.context[reference_address.get('node_field')]
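A note on `get_chunk_content` above: it frames each streamed fragment as one Server-Sent Events `data:` record terminated by a blank line. A minimal, self-contained sketch of the frame a client receives (the ids are made up; the real values come from `self.params`):

    import json

    params = {'chat_id': 'chat-1', 'chat_record_id': 'record-1'}  # hypothetical ids

    def get_chunk_content(chunk, is_end=False):
        # Same shape as the method above: one SSE event per chunk, '\n\n'-terminated.
        return 'data: ' + json.dumps(
            {'chat_id': params['chat_id'], 'id': params['chat_record_id'], 'operate': True,
             'content': chunk, 'is_end': is_end}, ensure_ascii=False) + "\n\n"

    print(get_chunk_content('Hello'), end='')
    # data: {"chat_id": "chat-1", "id": "record-1", "operate": true, "content": "Hello", "is_end": false}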
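The AND-join scheduling in `get_next_node_list` deserves a gloss: a target node marked 'AND' is only scheduled once `dependent_node_been_executed` confirms every incoming edge has a finished source, and the `all(any(...))` pair reads as "for each up-edge, some executed node satisfies it". A minimal sketch of that check, with simplified stand-ins for the flow objects:

    edges = [{'source': 'a', 'target': 'c'}, {'source': 'b', 'target': 'c'}]
    finished = {'a'}  # ids of nodes whose chunk streams have ended

    def dependent_nodes_done(node_id):
        up_ids = [e['source'] for e in edges if e['target'] == node_id]
        return all(any(up_id == done for done in finished) for up_id in up_ids)

    print(dependent_nodes_done('c'))  # False: 'b' has not finished yet
    finished.add('b')
    print(dependent_nodes_done('c'))  # True: both parents of the AND node are done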
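`reset_prompt` rewrites human-readable variable labels (`NodeName.field`, `global.field`, or the Chinese `全局变量.field`) into `context.get(...)` lookups, and `generate_prompt` then renders the result as a jinja2 template over the per-node contexts. A minimal sketch of the round trip, assuming the langchain `PromptTemplate` that `generate_prompt` calls and a made-up node id:

    from langchain_core.prompts import PromptTemplate

    context = {'global': {'question': 'hi'}, 'node-1': {'answer': '42'}}
    raw = 'Answer {{ Search.answer }} to {{ global.question }}'
    # What reset_prompt would produce for a field list mapping Search -> node-1:
    rewritten = (raw
                 .replace('Search.answer', "context.get('node-1',{}).get('answer','')")
                 .replace('global.question', "context.get('global').get('question','')"))
    template = PromptTemplate.from_template(rewritten, template_format='jinja2')
    print(template.format(context=context))  # -> Answer 42 to hi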
diff --git a/apps/application/migrations/0001_initial.py b/apps/application/migrations/0001_initial.py
index 707f303220b..52dadda82fd 100644
--- a/apps/application/migrations/0001_initial.py
+++ b/apps/application/migrations/0001_initial.py
@@ -1,4 +1,5 @@
-# Generated by Django 5.2.1 on 2025-05-27 06:42
+# Generated by Django 5.2 on 2025-05-27 07:50
+from django.db.models import QuerySet
 
 import application.models.application
 import django.contrib.postgres.fields
@@ -9,11 +10,17 @@
 from django.db import migrations, models
 
 
-class Migration(migrations.Migration):
+def insert_default_data(apps, schema_editor):
+    # Create the root folder (it has no parent)
+    QuerySet(application.models.application.ApplicationFolder).create(id='root', name='根目录',
+                                                                      user_id='f0dd8f71-e4ee-11ee-8c84-a8a1595801ab')
+
+class Migration(migrations.Migration):
 
     initial = True
 
     dependencies = [
+        ('knowledge', '0001_initial'),
         ('models_provider', '0001_initial'),
         ('users', '0002_alter_user_nick_name'),
     ]
@@ -24,22 +31,31 @@ class Migration(migrations.Migration):
             fields=[
                 ('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')),
                 ('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')),
-                ('id', models.UUIDField(default=uuid_utils.compat.uuid7, editable=False, primary_key=True, serialize=False, verbose_name='主键id')),
-                ('workspace_id', models.CharField(db_index=True, default='default', max_length=64, verbose_name='工作空间id')),
+                ('id',
+                 models.UUIDField(default=uuid_utils.compat.uuid7, editable=False, primary_key=True, serialize=False,
+                                  verbose_name='主键id')),
+                ('workspace_id',
+                 models.CharField(db_index=True, default='default', max_length=64, verbose_name='工作空间id')),
                 ('is_publish', models.BooleanField(default=False, verbose_name='是否发布')),
                 ('name', models.CharField(max_length=128, verbose_name='应用名称')),
                 ('desc', models.CharField(default='', max_length=512, verbose_name='引用描述')),
                 ('prologue', models.CharField(default='', max_length=40960, verbose_name='开场白')),
                 ('dialogue_number', models.IntegerField(default=0, verbose_name='会话数量')),
-                ('dataset_setting', models.JSONField(default=application.models.application.get_dataset_setting_dict, verbose_name='数据集参数设置')),
-                ('model_setting', models.JSONField(default=application.models.application.get_model_setting_dict, verbose_name='模型参数相关设置')),
+                ('knowledge_setting', models.JSONField(default=application.models.application.get_dataset_setting_dict,
+                                                       verbose_name='数据集参数设置')),
+                ('model_setting', models.JSONField(default=application.models.application.get_model_setting_dict,
+                                                   verbose_name='模型参数相关设置')),
                 ('model_params_setting', models.JSONField(default=dict, verbose_name='模型参数相关设置')),
                 ('tts_model_params_setting', models.JSONField(default=dict, verbose_name='模型参数相关设置')),
                 ('problem_optimization', models.BooleanField(default=False, verbose_name='问题优化')),
                 ('icon', models.CharField(default='/ui/favicon.ico', max_length=256, verbose_name='应用icon')),
                 ('work_flow', models.JSONField(default=dict, verbose_name='工作流数据')),
-                ('type', models.CharField(choices=[('SIMPLE', '简易'), ('WORK_FLOW', '工作流')], default='SIMPLE', max_length=256, verbose_name='应用类型')),
-                ('problem_optimization_prompt', models.CharField(blank=True, default='()里面是用户问题,根据上下文回答揣测用户问题({question}) 要求: 输出一个补全问题,并且放在标签中', max_length=102400, null=True, verbose_name='问题优化提示词')),
+                ('type', models.CharField(choices=[('SIMPLE', '简易'), ('WORK_FLOW', '工作流')], default='SIMPLE',
+                                          max_length=256, verbose_name='应用类型')),
+                ('problem_optimization_prompt', models.CharField(blank=True,
+                                                                 default='()里面是用户问题,根据上下文回答揣测用户问题({question}) 要求: 输出一个补全问题,并且放在标签中',
+                                                                 max_length=102400, null=True,
+                                                                 verbose_name='问题优化提示词')),
                 ('tts_model_enable', models.BooleanField(default=False, verbose_name='语音合成模型是否启用')),
                 ('stt_model_enable', models.BooleanField(default=False, verbose_name='语音识别模型是否启用')),
                 ('tts_type', models.CharField(default='BROWSER', max_length=20, verbose_name='语音播放类型')),
@@ -48,9 +64,14 @@ class Migration(migrations.Migration):
                 ('clean_time', models.IntegerField(default=180, verbose_name='清理时间')),
                 ('file_upload_enable', models.BooleanField(default=False, verbose_name='文件上传是否启用')),
                 ('file_upload_setting', models.JSONField(default=dict, verbose_name='文件上传相关设置')),
-                ('model', models.ForeignKey(blank=True, db_constraint=False, null=True, on_delete=django.db.models.deletion.SET_NULL, to='models_provider.model')),
-                ('stt_model', models.ForeignKey(blank=True, db_constraint=False, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='stt_model_id', to='models_provider.model')),
-                ('tts_model', models.ForeignKey(blank=True, db_constraint=False, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='tts_model_id', to='models_provider.model')),
+                ('model', models.ForeignKey(blank=True, db_constraint=False, null=True,
+                                            on_delete=django.db.models.deletion.SET_NULL, to='models_provider.model')),
+                ('stt_model', models.ForeignKey(blank=True, db_constraint=False, null=True,
+                                                on_delete=django.db.models.deletion.SET_NULL,
+                                                related_name='stt_model_id', to='models_provider.model')),
+                ('tts_model', models.ForeignKey(blank=True, db_constraint=False, null=True,
+                                                on_delete=django.db.models.deletion.SET_NULL,
+                                                related_name='tts_model_id', to='models_provider.model')),
                 ('user', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='users.user')),
             ],
             options={
@@ -62,12 +83,16 @@ class Migration(migrations.Migration):
             fields=[
                 ('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')),
                 ('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')),
-                ('application', models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, primary_key=True, serialize=False, to='application.application', verbose_name='应用id')),
+                ('application',
+                 models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, primary_key=True, serialize=False,
+                                      to='application.application', verbose_name='应用id')),
                 ('access_token', models.CharField(max_length=128, unique=True, verbose_name='用户公开访问 认证token')),
                 ('is_active', models.BooleanField(default=True, verbose_name='是否开启公开访问')),
                 ('access_num', models.IntegerField(default=100, verbose_name='访问次数')),
                 ('white_active', models.BooleanField(default=False, verbose_name='是否开启白名单')),
-                ('white_list', django.contrib.postgres.fields.ArrayField(base_field=models.CharField(blank=True, max_length=128), default=list, size=None, verbose_name='白名单列表')),
+                ('white_list',
+                 django.contrib.postgres.fields.ArrayField(base_field=models.CharField(blank=True, max_length=128),
+                                                           default=list, size=None, verbose_name='白名单列表')),
                 ('show_source', models.BooleanField(default=False, verbose_name='是否显示知识来源')),
                 ('language', models.CharField(default=None, max_length=10, null=True, verbose_name='语言')),
             ],
@@ -80,14 +105,21 @@ class Migration(migrations.Migration):
             fields=[
                 ('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')),
                 ('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')),
-                ('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, verbose_name='主键id')),
+                ('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False,
+                                        verbose_name='主键id')),
                 ('secret_key', models.CharField(max_length=1024, unique=True, verbose_name='秘钥')),
-                ('workspace_id', models.CharField(db_index=True, default='default', max_length=64, verbose_name='工作空间id')),
+                ('workspace_id',
+                 models.CharField(db_index=True, default='default', max_length=64, verbose_name='工作空间id')),
                 ('is_active', models.BooleanField(default=True, verbose_name='是否开启')),
                 ('allow_cross_domain', models.BooleanField(default=False, verbose_name='是否允许跨域')),
-                ('cross_domain_list', django.contrib.postgres.fields.ArrayField(base_field=models.CharField(blank=True, max_length=128), default=list, size=None, verbose_name='跨域列表')),
-                ('application', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='application.application', verbose_name='应用id')),
-                ('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='users.user', verbose_name='用户id')),
+                ('cross_domain_list',
+                 django.contrib.postgres.fields.ArrayField(base_field=models.CharField(blank=True, max_length=128),
+                                                           default=list, size=None, verbose_name='跨域列表')),
+                ('application',
+                 models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='application.application',
+                                   verbose_name='应用id')),
+                ('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='users.user',
+                                           verbose_name='用户id')),
             ],
             options={
                 'db_table': 'application_api_key',
@@ -98,16 +130,21 @@ class Migration(migrations.Migration):
             fields=[
                 ('create_time',
                  models.DateTimeField(auto_now_add=True, verbose_name='创建时间')),
                 ('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')),
-                ('id', models.CharField(editable=False, max_length=64, primary_key=True, serialize=False, verbose_name='主键id')),
+                ('id', models.CharField(editable=False, max_length=64, primary_key=True, serialize=False,
+                                        verbose_name='主键id')),
                 ('name', models.CharField(max_length=64, verbose_name='文件夹名称')),
                 ('desc', models.CharField(blank=True, max_length=200, null=True, verbose_name='描述')),
-                ('workspace_id', models.CharField(db_index=True, default='default', max_length=64, verbose_name='工作空间id')),
+                ('workspace_id',
+                 models.CharField(db_index=True, default='default', max_length=64, verbose_name='工作空间id')),
                 ('lft', models.PositiveIntegerField(editable=False)),
                 ('rght', models.PositiveIntegerField(editable=False)),
                 ('tree_id', models.PositiveIntegerField(db_index=True, editable=False)),
                 ('level', models.PositiveIntegerField(editable=False)),
-                ('parent', mptt.fields.TreeForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.DO_NOTHING, related_name='children', to='application.applicationfolder')),
-                ('user', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='users.user', verbose_name='用户id')),
+                ('parent',
+                 mptt.fields.TreeForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.DO_NOTHING,
+                                            related_name='children', to='application.applicationfolder')),
+                ('user', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='users.user',
+                                           verbose_name='用户id')),
             ],
             options={
                 'db_table': 'application_folder',
@@ -116,18 +153,25 @@ class Migration(migrations.Migration):
         migrations.AddField(
             model_name='application',
             name='folder',
-            field=models.ForeignKey(default='root', on_delete=django.db.models.deletion.DO_NOTHING, to='application.applicationfolder', verbose_name='文件夹id'),
+            field=models.ForeignKey(default='root', on_delete=django.db.models.deletion.DO_NOTHING,
+                                    to='application.applicationfolder', verbose_name='文件夹id'),
         ),
         migrations.CreateModel(
             name='ApplicationKnowledgeMapping',
             fields=[
                 ('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')),
                 ('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')),
-                ('id', models.UUIDField(default=uuid_utils.compat.uuid7, editable=False, primary_key=True, serialize=False, verbose_name='主键id')),
-                ('application', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='application.application')),
+                ('id',
+                 models.UUIDField(default=uuid_utils.compat.uuid7, editable=False, primary_key=True, serialize=False,
+                                  verbose_name='主键id')),
+                ('application',
+                 models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='application.application')),
+                ('knowledge',
+                 models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='knowledge.knowledge')),
             ],
             options={
                 'db_table': 'application_knowledge_mapping',
             },
         ),
+        migrations.RunPython(insert_default_data)
     ]
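One caveat on the seeding above: `migrations.RunPython` is registered without a reverse callable, so the migration cannot be unapplied. A hedged sketch of the reversible form, assuming that simply deleting the seeded row is an acceptable rollback:

    def remove_default_data(apps, schema_editor):
        # Reverse of insert_default_data: drop the seeded root folder.
        QuerySet(application.models.application.ApplicationFolder).filter(id='root').delete()

    # ...and in operations:
    migrations.RunPython(insert_default_data, remove_default_data)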
diff --git a/apps/application/migrations/0002_initial.py b/apps/application/migrations/0002_initial.py
deleted file mode 100644
index fbec8994e60..00000000000
--- a/apps/application/migrations/0002_initial.py
+++ /dev/null
@@ -1,22 +0,0 @@
-# Generated by Django 5.2.1 on 2025-05-27 06:42
-
-import django.db.models.deletion
-from django.db import migrations, models
-
-
-class Migration(migrations.Migration):
-
-    initial = True
-
-    dependencies = [
-        ('application', '0001_initial'),
-        ('knowledge', '0001_initial'),
-    ]
-
-    operations = [
-        migrations.AddField(
-            model_name='applicationknowledgemapping',
-            name='knowledge',
-            field=models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='knowledge.knowledge'),
-        ),
-    ]
diff --git a/apps/application/models/application.py b/apps/application/models/application.py
index 50ee7608729..243c14652dc 100644
--- a/apps/application/models/application.py
+++ b/apps/application/models/application.py
@@ -67,7 +67,7 @@ class Application(AppModelMixin):
     dialogue_number = models.IntegerField(default=0, verbose_name="会话数量")
     user = models.ForeignKey(User, on_delete=models.DO_NOTHING)
     model = models.ForeignKey(Model, on_delete=models.SET_NULL, db_constraint=False, blank=True, null=True)
-    dataset_setting = models.JSONField(verbose_name="数据集参数设置", default=get_dataset_setting_dict)
+    knowledge_setting = models.JSONField(verbose_name="数据集参数设置", default=get_dataset_setting_dict)
     model_setting = models.JSONField(verbose_name="模型参数相关设置", default=get_model_setting_dict)
     model_params_setting = models.JSONField(verbose_name="模型参数相关设置", default=dict)
     tts_model_params_setting = models.JSONField(verbose_name="模型参数相关设置", default=dict)
diff --git a/apps/application/serializers/application.py b/apps/application/serializers/application.py
index 879962290df..abb5a24f927 100644
--- a/apps/application/serializers/application.py
+++ b/apps/application/serializers/application.py
@@ -56,7 +56,7 @@ class ModelKnowledgeAssociation(serializers.Serializer):
     user_id = serializers.UUIDField(required=True, label=_("User ID"))
     model_id = serializers.CharField(required=False, allow_null=True, allow_blank=True, label=_("Model id"))
-    Knowledge_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True,
+    knowledge_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True,
                                                                                                label=_(
                                                                                                    "Knowledge base id")),
                                                    label=_("Knowledge Base List"))
@@ -68,7 +68,7 @@ def is_valid(self, *, raise_exception=True):
         if model_id is not None and len(model_id) > 0:
             if not QuerySet(Model).filter(id=model_id).exists():
                 raise AppApiException(500, f'{_("Model does not exist")}【{model_id}】')
-        knowledge_id_list = list(set(self.data.get('knowledge_id_list')))
+        knowledge_id_list = list(set(self.data.get('knowledge_id_list', [])))
         exist_knowledge_id_list = [str(knowledge.id) for knowledge in
                                    QuerySet(Knowledge).filter(id__in=knowledge_id_list, user_id=user_id)]
         for knowledge_id in knowledge_id_list:
@@ -110,6 +110,7 @@ class WorkflowRequest(serializers.Serializer):
     work_flow = serializers.DictField(required=True, label=_("Workflow Objects"))
     prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=102400,
                                      label=_("Opening remarks"))
+    folder_id = serializers.CharField(required=True, label=_('folder id'))
 
     @staticmethod
     def to_application_model(user_id: str, workspace_id: str, application: Dict):
@@ -123,6 +124,7 @@ def to_application_model(user_id: str, workspace_id: str, application: Dict):
                            name=application.get('name'), desc=application.get('desc'),
                            workspace_id=workspace_id,
+                           folder_id=application.get('folder_id', 'root'),
                            prologue="",
                            dialogue_number=0,
                            user_id=user_id, model_id=None,
@@ -135,7 +137,7 @@ def to_application_model(user_id: str, workspace_id: str, application: Dict):
                            tts_model_id=application.get('tts_model', None),
                            tts_model_enable=application.get('tts_model_enable', False),
                            tts_model_params_setting=application.get('tts_model_params_setting', {}),
-                           tts_type=application.get('tts_type', None),
+                           tts_type=application.get('tts_type', 'BROWSER'),
                            file_upload_enable=application.get('file_upload_enable', False),
                            file_upload_setting=application.get('file_upload_setting', {}),
                            work_flow=default_workflow
@@ -147,6 +149,7 @@ class SimplateRequest(serializers.Serializer):
     desc = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=256, min_length=1,
                                  label=_("application describe"))
+    folder_id = serializers.CharField(required=True, label=_('folder id'))
     model_id = serializers.CharField(required=False, allow_null=True, allow_blank=True,
                                      label=_("Model"))
     dialogue_number = serializers.IntegerField(required=True,
@@ -179,6 +182,20 @@ class SimplateRequest(serializers.Serializer):
     model_params_setting = serializers.DictField(required=False,
                                                  label=_('Model parameters'))
 
+    tts_model_enable = serializers.BooleanField(required=False, label=_('Voice playback enabled'))
+
+    tts_model_id = serializers.UUIDField(required=False, allow_null=True, label=_("Voice playback model ID"))
+
+    tts_type = serializers.CharField(required=False, label=_('Voice playback type'))
+
+    tts_autoplay = serializers.BooleanField(required=False, label=_('Voice playback autoplay'))
+
+    stt_model_enable = serializers.BooleanField(required=False, label=_('Voice recognition enabled'))
+
+    stt_model_id = serializers.UUIDField(required=False, allow_null=True, label=_('Speech recognition model ID'))
+
+    stt_autosend = serializers.BooleanField(required=False, label=_('Voice recognition automatic transmission'))
+
     def is_valid(self, *, user_id=None, raise_exception=False):
         super().is_valid(raise_exception=True)
         ModelKnowledgeAssociation(data={'user_id': user_id, 'model_id': self.data.get('model_id'),
@@ -190,7 +207,8 @@ def to_application_model(user_id: str, application: Dict):
                            prologue=application.get('prologue'),
                            dialogue_number=application.get('dialogue_number', 0),
                            user_id=user_id, model_id=application.get('model_id'),
-                           dataset_setting=application.get('dataset_setting'),
+                           folder_id=application.get('folder_id', 'root'),
+                           knowledge_setting=application.get('knowledge_setting'),
                            model_setting=application.get('model_setting'),
                            problem_optimization=application.get('problem_optimization'),
                            type=ApplicationTypeChoices.SIMPLE,
@@ -198,10 +216,11 @@ def to_application_model(user_id: str, application: Dict):
                            problem_optimization_prompt=application.get('problem_optimization_prompt', None),
                            stt_model_enable=application.get('stt_model_enable', False),
                            stt_model_id=application.get('stt_model', None),
+                           stt_autosend=application.get('stt_autosend', False),
                            tts_model_id=application.get('tts_model', None),
                            tts_model_enable=application.get('tts_model_enable', False),
                            tts_model_params_setting=application.get('tts_model_params_setting', {}),
-                           tts_type=application.get('tts_type', None),
+                           tts_type=application.get('tts_type', 'BROWSER'),
                            file_upload_enable=application.get('file_upload_enable', False),
                            file_upload_setting=application.get('file_upload_setting', {}),
                            work_flow={}
@@ -222,8 +241,10 @@ def insert(self, instance: Dict, with_valid=True):
 
     def insert_workflow(self, instance: Dict):
         self.is_valid(raise_exception=True)
         user_id = self.data.get('user_id')
-        ApplicationCreateSerializer.WorkflowRequest(data=instance).is_valid(raise_exception=True)
-        application_model = ApplicationCreateSerializer.WorkflowRequest.to_application_model(user_id, instance)
+        workspace_id = self.data.get('workspace_id')
+        wq = ApplicationCreateSerializer.WorkflowRequest(data=instance)
+        wq.is_valid(raise_exception=True)
+        application_model = wq.to_application_model(user_id, workspace_id, instance)
         application_model.save()
         # Insert the public-access credential record
         ApplicationAccessToken(application_id=application_model.id,
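The `knowledge_id_list` change above fixes two things at once: the mis-cased `Knowledge_id_list` field name, and a crash path in `is_valid` — without the `[]` default, a request that omits the list makes `set(None)` raise `TypeError`. A two-line illustration:

    data = {}
    print(list(set(data.get('knowledge_id_list', []))))  # [] instead of a crash
    # Without the default: set(data.get('knowledge_id_list')) -> set(None) -> TypeError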
diff --git a/apps/application/serializers/application_folder.py b/apps/application/serializers/application_folder.py
new file mode 100644
index 00000000000..d79015e7499
--- /dev/null
+++ b/apps/application/serializers/application_folder.py
@@ -0,0 +1,21 @@
+from rest_framework import serializers
+
+from application.models import ApplicationFolder
+from knowledge.models import KnowledgeFolder
+
+
+class ApplicationFolderTreeSerializer(serializers.ModelSerializer):
+    children = serializers.SerializerMethodField()
+
+    class Meta:
+        model = ApplicationFolder
+        fields = ['id', 'name', 'desc', 'user_id', 'workspace_id', 'parent_id', 'children']
+
+    def get_children(self, obj):
+        return ApplicationFolderTreeSerializer(obj.get_children(), many=True).data
+
+
+class ApplicationFolderFlatSerializer(serializers.ModelSerializer):
+    class Meta:
+        model = ApplicationFolder
+        fields = ['id', 'name', 'desc', 'user_id', 'workspace_id', 'parent_id']
diff --git a/apps/folders/serializers/folder.py b/apps/folders/serializers/folder.py
index 366b7e8c3d2..b2c86b83aa5 100644
--- a/apps/folders/serializers/folder.py
+++ b/apps/folders/serializers/folder.py
@@ -7,6 +7,7 @@
 from rest_framework import serializers
 
 from application.models.application import Application, ApplicationFolder
+from application.serializers.application_folder import ApplicationFolderTreeSerializer
 from common.constants.permission_constants import Group
 from folders.api.folder import FolderCreateRequest
 from knowledge.models import KnowledgeFolder, Knowledge
@@ -42,9 +43,7 @@ def get_folder_tree_serializer(source):
     if source == Group.TOOL.name:
         return ToolFolderTreeSerializer
     elif source == Group.APPLICATION.name:
-        # todo app folder
-        return None
-        # return ApplicationFolderTreeSerializer
+        return ApplicationFolderTreeSerializer
     elif source == Group.KNOWLEDGE.name:
         return KnowledgeFolderTreeSerializer
     else:
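With `get_folder_tree_serializer` now returning `ApplicationFolderTreeSerializer` for the APPLICATION group, serializing a folder recurses through MPTT's `get_children()` and yields the whole nested tree. A hedged usage sketch, assuming the `root` folder seeded by the migration above:

    from application.models import ApplicationFolder
    from application.serializers.application_folder import ApplicationFolderTreeSerializer

    root = ApplicationFolder.objects.get(id='root')
    tree = ApplicationFolderTreeSerializer(root).data
    # -> {'id': 'root', 'name': '根目录', ..., 'children': [{'id': ..., 'children': [...]}, ...]}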