Skip to content

feat: application flow #3152

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 157 additions & 0 deletions apps/application/chat_pipeline/I_base_chat_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
# coding=utf-8
"""
@project: maxkb
@Author:虎
@file: I_base_chat_pipeline.py
@date:2024/1/9 17:25
@desc:
"""
import time
from abc import abstractmethod
from typing import Type

from rest_framework import serializers

from dataset.models import Paragraph


class ParagraphPipelineModel:

def __init__(self, _id: str, document_id: str, dataset_id: str, content: str, title: str, status: str,
is_active: bool, comprehensive_score: float, similarity: float, dataset_name: str, document_name: str,
hit_handling_method: str, directly_return_similarity: float, meta: dict = None):
self.id = _id
self.document_id = document_id
self.dataset_id = dataset_id
self.content = content
self.title = title
self.status = status,
self.is_active = is_active
self.comprehensive_score = comprehensive_score
self.similarity = similarity
self.dataset_name = dataset_name
self.document_name = document_name
self.hit_handling_method = hit_handling_method
self.directly_return_similarity = directly_return_similarity
self.meta = meta

def to_dict(self):
return {
'id': self.id,
'document_id': self.document_id,
'dataset_id': self.dataset_id,
'content': self.content,
'title': self.title,
'status': self.status,
'is_active': self.is_active,
'comprehensive_score': self.comprehensive_score,
'similarity': self.similarity,
'dataset_name': self.dataset_name,
'document_name': self.document_name,
'meta': self.meta,
}

class builder:
def __init__(self):
self.similarity = None
self.paragraph = {}
self.comprehensive_score = None
self.document_name = None
self.dataset_name = None
self.hit_handling_method = None
self.directly_return_similarity = 0.9
self.meta = {}

def add_paragraph(self, paragraph):
if isinstance(paragraph, Paragraph):
self.paragraph = {'id': paragraph.id,
'document_id': paragraph.document_id,
'dataset_id': paragraph.dataset_id,
'content': paragraph.content,
'title': paragraph.title,
'status': paragraph.status,
'is_active': paragraph.is_active,
}
else:
self.paragraph = paragraph
return self

def add_dataset_name(self, dataset_name):
self.dataset_name = dataset_name
return self

def add_document_name(self, document_name):
self.document_name = document_name
return self

def add_hit_handling_method(self, hit_handling_method):
self.hit_handling_method = hit_handling_method
return self

def add_directly_return_similarity(self, directly_return_similarity):
self.directly_return_similarity = directly_return_similarity
return self

def add_comprehensive_score(self, comprehensive_score: float):
self.comprehensive_score = comprehensive_score
return self

def add_similarity(self, similarity: float):
self.similarity = similarity
return self

def add_meta(self, meta: dict):
self.meta = meta
return self

def build(self):
return ParagraphPipelineModel(str(self.paragraph.get('id')), str(self.paragraph.get('document_id')),
str(self.paragraph.get('dataset_id')),
self.paragraph.get('content'), self.paragraph.get('title'),
self.paragraph.get('status'),
self.paragraph.get('is_active'),
self.comprehensive_score, self.similarity, self.dataset_name,
self.document_name, self.hit_handling_method, self.directly_return_similarity,
self.meta)


class IBaseChatPipelineStep:
def __init__(self):
# 当前步骤上下文,用于存储当前步骤信息
self.context = {}

@abstractmethod
def get_step_serializer(self, manage) -> Type[serializers.Serializer]:
pass

def valid_args(self, manage):
step_serializer_clazz = self.get_step_serializer(manage)
step_serializer = step_serializer_clazz(data=manage.context)
step_serializer.is_valid(raise_exception=True)
self.context['step_args'] = step_serializer.data

def run(self, manage):
"""

:param manage: 步骤管理器
:return: 执行结果
"""
start_time = time.time()
self.context['start_time'] = start_time
# 校验参数,
self.valid_args(manage)
self._run(manage)
self.context['run_time'] = time.time() - start_time

def _run(self, manage):
pass

def execute(self, **kwargs):
pass

def get_details(self, manage, **kwargs):
"""
运行详情
:return: 步骤详情
"""
return None
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The provided code has several areas that need attention:

  1. Imports: The @coding=utf-8 directive is unnecessary and should be removed.
  2. Class Names: Class names like ParagraphPipelineModel, IBaseChatPipelineStep, etc., are not in consistent PascalCase format, which can make them harder to read.
  3. Variable Naming: Variable names such as _id, document_id, dataset_id, etc., are short and don't provide enough context about their purpose.
  4. Docstrings: Most docstrings are missing descriptions of what each method does and how they handle exceptions.
  5. Type Annotations: While type annotations are good practice, ensure they match the actual data types used in the code.
  6. Abstract Method Implementation: The I_BaseChatPipelineStep.run() method contains comments but doesn't actually implement it.
  7. Directly Returning Results: Methods returning results directly within constructors or methods without considering whether they need parameters.

Here are some specific suggestions:

General Refactoring

1. Rename Variables

# Change variable names with descriptive names

2. Improve Docstrings

"""
    @project: maxkb
    @Author:虎
    @file: I_base_chat_pipeline.py
    @date:2024/1/9 17:25
    @desc:
        This file defines the base interface for chat pipeline steps.
"""

3. Use Context Dictionary Consistently

Ensure that all classes have access to a context dictionary through inheritance if necessary.

Code Changes

Constructor Updates

Update constructor parameter names to be more descriptive.

class ParagraphPipelineModel:
    def __init__(
        self,
        paragraph_id: str,
        document_id: int,
        dataset_id: int,
        content: str,
        title: str,
        status: str,
        is_active: bool,
        comprehensive_score: float,
        similarity: float,
        dataset_name: str,
        document_name: str,
        hit_handling_method: str,
        directly_return_similarity: float,
        meta: dict = None,
    ) -> None:
        super().__init__()
        self._paragraph_id = paragraph_id
        self._document_id = document_id
        self._dataset_id = dataset_id
        self._content = content
        self.title = title
        self.status = status,
        self.is_active = is_active
        self.comprehensive_score = comprehensive_score
        self.similarity = similarity
        self.dataset_name = dataset_name
        self.document_name = document_name
        self.hit_handling_method = hit_handling_method
        self.directly_return_similarity = directly_return_similarity
        self.meta = meta

Correct Serialization

Ensure serialization logic handles cases where None values might occur.

def to_dict(self) -> Dict[str, Union[str, int, float, bool, list]]:
    return {
        'id': str(self._paragraph_id),
        'document_id': self._document_id,
        'dataset_id': self._dataset_id,
        'content': self._content,
        'title': self.title,
        'status': self.status,
        'is_active': self.is_active,
        'comprehensive_score': self.comprehensive_score,
        'similarity': self.similarity,
        'dataset_name': self.dataset_name,
        'document_name': self.document_name,
        'hit_handling_method': self.hit_handling_method,
        'directly_return_similarity': self.directly_return_similarity,
        'meta': self.meta,
    }

Implement run() Method

Implement the abstract run() method by calling other necessary parts.

def run(self, manage) -> Any:
    """
        Execute this step.
        :return Execution result.
        """
    start_time = time.time()
    self.context['start_time'] = start_time
    
    # Check arguments validity
    args_validated = False
    try:
        self.valid_args(manage)
        args_validated = True
    except serializers.ValidationError:
        manage.set_error_status("Argument validation failed.")
    
    # Run the main flow
    if args_validated:
        self._run(manage)
        self.context['run_time'] = time.time() - start_time
    else:
        manage.set_error_status("Failed to validate argument.")

By making these changes, you'll improve the clarity, maintainability, and robustness of the given codebase according to Python guidelines.

8 changes: 8 additions & 0 deletions apps/application/chat_pipeline/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# coding=utf-8
"""
@project: maxkb
@Author:虎
@file: __init__.py.py
@date:2024/1/9 17:23
@desc:
"""
57 changes: 57 additions & 0 deletions apps/application/chat_pipeline/pipeline_manage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# coding=utf-8
"""
@project: maxkb
@Author:虎
@file: pipeline_manage.py
@date:2024/1/9 17:40
@desc:
"""
import time
from functools import reduce
from typing import List, Type, Dict

from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep
from common.handle.base_to_response import BaseToResponse
from common.handle.impl.response.system_to_response import SystemToResponse


class PipelineManage:
def __init__(self, step_list: List[Type[IBaseChatPipelineStep]],
base_to_response: BaseToResponse = SystemToResponse()):
# 步骤执行器
self.step_list = [step() for step in step_list]
# 上下文
self.context = {'message_tokens': 0, 'answer_tokens': 0}
self.base_to_response = base_to_response

def run(self, context: Dict = None):
self.context['start_time'] = time.time()
if context is not None:
for key, value in context.items():
self.context[key] = value
for step in self.step_list:
step.run(self)

def get_details(self):
return reduce(lambda x, y: {**x, **y}, [{item.get('step_type'): item} for item in
filter(lambda r: r is not None,
[row.get_details(self) for row in self.step_list])], {})

def get_base_to_response(self):
return self.base_to_response

class builder:
def __init__(self):
self.step_list: List[Type[IBaseChatPipelineStep]] = []
self.base_to_response = SystemToResponse()

def append_step(self, step: Type[IBaseChatPipelineStep]):
self.step_list.append(step)
return self

def add_base_to_response(self, base_to_response: BaseToResponse):
self.base_to_response = base_to_response
return self

def build(self):
return PipelineManage(step_list=self.step_list, base_to_response=self.base_to_response)
8 changes: 8 additions & 0 deletions apps/application/chat_pipeline/step/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# coding=utf-8
"""
@project: maxkb
@Author:虎
@file: __init__.py.py
@date:2024/1/9 18:23
@desc:
"""
8 changes: 8 additions & 0 deletions apps/application/chat_pipeline/step/chat_step/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# coding=utf-8
"""
@project: maxkb
@Author:虎
@file: __init__.py.py
@date:2024/1/9 18:23
@desc:
"""
110 changes: 110 additions & 0 deletions apps/application/chat_pipeline/step/chat_step/i_chat_step.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# coding=utf-8
"""
@project: maxkb
@Author:虎
@file: i_chat_step.py
@date:2024/1/9 18:17
@desc: 对话
"""
from abc import abstractmethod
from typing import Type, List

from django.utils.translation import gettext_lazy as _
from langchain.chat_models.base import BaseChatModel
from langchain.schema import BaseMessage
from rest_framework import serializers

from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep, ParagraphPipelineModel
from application.chat_pipeline.pipeline_manage import PipelineManage
from application.serializers.application_serializers import NoReferencesSetting
from common.field.common import InstanceField
from common.util.field_message import ErrMessage


class ModelField(serializers.Field):
def to_internal_value(self, data):
if not isinstance(data, BaseChatModel):
self.fail(_('Model type error'), value=data)
return data

def to_representation(self, value):
return value


class MessageField(serializers.Field):
def to_internal_value(self, data):
if not isinstance(data, BaseMessage):
self.fail(_('Message type error'), value=data)
return data

def to_representation(self, value):
return value


class PostResponseHandler:
@abstractmethod
def handler(self, chat_id, chat_record_id, paragraph_list: List[ParagraphPipelineModel], problem_text: str,
answer_text,
manage, step, padding_problem_text: str = None, client_id=None, **kwargs):
pass


class IChatStep(IBaseChatPipelineStep):
class InstanceSerializer(serializers.Serializer):
# 对话列表
message_list = serializers.ListField(required=True, child=MessageField(required=True),
error_messages=ErrMessage.list(_("Conversation list")))
model_id = serializers.UUIDField(required=False, allow_null=True, error_messages=ErrMessage.uuid(_("Model id")))
# 段落列表
paragraph_list = serializers.ListField(error_messages=ErrMessage.list(_("Paragraph List")))
# 对话id
chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid(_("Conversation ID")))
# 用户问题
problem_text = serializers.CharField(required=True, error_messages=ErrMessage.uuid(_("User Questions")))
# 后置处理器
post_response_handler = InstanceField(model_type=PostResponseHandler,
error_messages=ErrMessage.base(_("Post-processor")))
# 补全问题
padding_problem_text = serializers.CharField(required=False,
error_messages=ErrMessage.base(_("Completion Question")))
# 是否使用流的形式输出
stream = serializers.BooleanField(required=False, error_messages=ErrMessage.base(_("Streaming Output")))
client_id = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Client id")))
client_type = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Client Type")))
# 未查询到引用分段
no_references_setting = NoReferencesSetting(required=True,
error_messages=ErrMessage.base(_("No reference segment settings")))

user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid(_("User ID")))

model_setting = serializers.DictField(required=True, allow_null=True,
error_messages=ErrMessage.dict(_("Model settings")))

model_params_setting = serializers.DictField(required=False, allow_null=True,
error_messages=ErrMessage.dict(_("Model parameter settings")))

def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
message_list: List = self.initial_data.get('message_list')
for message in message_list:
if not isinstance(message, BaseMessage):
raise Exception(_("message type error"))

def get_step_serializer(self, manage: PipelineManage) -> Type[serializers.Serializer]:
return self.InstanceSerializer

def _run(self, manage: PipelineManage):
chat_result = self.execute(**self.context['step_args'], manage=manage)
manage.context['chat_result'] = chat_result

@abstractmethod
def execute(self, message_list: List[BaseMessage],
chat_id, problem_text,
post_response_handler: PostResponseHandler,
model_id: str = None,
user_id: str = None,
paragraph_list=None,
manage: PipelineManage = None,
padding_problem_text: str = None, stream: bool = True, client_id=None, client_type=None,
no_references_setting=None, model_params_setting=None, model_setting=None, **kwargs):
pass
Loading
Loading