feat: application flow #3152
Merged
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,157 @@ | ||
# coding=utf-8 | ||
""" | ||
@project: maxkb | ||
@Author:虎 | ||
@file: I_base_chat_pipeline.py | ||
@date:2024/1/9 17:25 | ||
@desc: | ||
""" | ||
import time | ||
from abc import abstractmethod | ||
from typing import Type | ||
|
||
from rest_framework import serializers | ||
|
||
from dataset.models import Paragraph | ||
|
||
|
||
class ParagraphPipelineModel: | ||
|
||
def __init__(self, _id: str, document_id: str, dataset_id: str, content: str, title: str, status: str, | ||
is_active: bool, comprehensive_score: float, similarity: float, dataset_name: str, document_name: str, | ||
hit_handling_method: str, directly_return_similarity: float, meta: dict = None): | ||
self.id = _id | ||
self.document_id = document_id | ||
self.dataset_id = dataset_id | ||
self.content = content | ||
self.title = title | ||
self.status = status | ||
self.is_active = is_active | ||
self.comprehensive_score = comprehensive_score | ||
self.similarity = similarity | ||
self.dataset_name = dataset_name | ||
self.document_name = document_name | ||
self.hit_handling_method = hit_handling_method | ||
self.directly_return_similarity = directly_return_similarity | ||
self.meta = meta | ||
|
||
def to_dict(self): | ||
return { | ||
'id': self.id, | ||
'document_id': self.document_id, | ||
'dataset_id': self.dataset_id, | ||
'content': self.content, | ||
'title': self.title, | ||
'status': self.status, | ||
'is_active': self.is_active, | ||
'comprehensive_score': self.comprehensive_score, | ||
'similarity': self.similarity, | ||
'dataset_name': self.dataset_name, | ||
'document_name': self.document_name, | ||
'meta': self.meta, | ||
} | ||
|
||
class builder: | ||
def __init__(self): | ||
self.similarity = None | ||
self.paragraph = {} | ||
self.comprehensive_score = None | ||
self.document_name = None | ||
self.dataset_name = None | ||
self.hit_handling_method = None | ||
self.directly_return_similarity = 0.9 | ||
self.meta = {} | ||
|
||
def add_paragraph(self, paragraph): | ||
if isinstance(paragraph, Paragraph): | ||
self.paragraph = {'id': paragraph.id, | ||
'document_id': paragraph.document_id, | ||
'dataset_id': paragraph.dataset_id, | ||
'content': paragraph.content, | ||
'title': paragraph.title, | ||
'status': paragraph.status, | ||
'is_active': paragraph.is_active, | ||
} | ||
else: | ||
self.paragraph = paragraph | ||
return self | ||
|
||
def add_dataset_name(self, dataset_name): | ||
self.dataset_name = dataset_name | ||
return self | ||
|
||
def add_document_name(self, document_name): | ||
self.document_name = document_name | ||
return self | ||
|
||
def add_hit_handling_method(self, hit_handling_method): | ||
self.hit_handling_method = hit_handling_method | ||
return self | ||
|
||
def add_directly_return_similarity(self, directly_return_similarity): | ||
self.directly_return_similarity = directly_return_similarity | ||
return self | ||
|
||
def add_comprehensive_score(self, comprehensive_score: float): | ||
self.comprehensive_score = comprehensive_score | ||
return self | ||
|
||
def add_similarity(self, similarity: float): | ||
self.similarity = similarity | ||
return self | ||
|
||
def add_meta(self, meta: dict): | ||
self.meta = meta | ||
return self | ||
|
||
def build(self): | ||
return ParagraphPipelineModel(str(self.paragraph.get('id')), str(self.paragraph.get('document_id')), | ||
str(self.paragraph.get('dataset_id')), | ||
self.paragraph.get('content'), self.paragraph.get('title'), | ||
self.paragraph.get('status'), | ||
self.paragraph.get('is_active'), | ||
self.comprehensive_score, self.similarity, self.dataset_name, | ||
self.document_name, self.hit_handling_method, self.directly_return_similarity, | ||
self.meta) | ||
|
||
|
||
class IBaseChatPipelineStep: | ||
def __init__(self): | ||
# Context of the current step, used to store information about this step | ||
self.context = {} | ||
|
||
@abstractmethod | ||
def get_step_serializer(self, manage) -> Type[serializers.Serializer]: | ||
pass | ||
|
||
def valid_args(self, manage): | ||
step_serializer_clazz = self.get_step_serializer(manage) | ||
step_serializer = step_serializer_clazz(data=manage.context) | ||
step_serializer.is_valid(raise_exception=True) | ||
self.context['step_args'] = step_serializer.data | ||
|
||
def run(self, manage): | ||
""" | ||
|
||
:param manage: the step manager | ||
:return: the execution result | ||
""" | ||
start_time = time.time() | ||
self.context['start_time'] = start_time | ||
# Validate the arguments | ||
self.valid_args(manage) | ||
self._run(manage) | ||
self.context['run_time'] = time.time() - start_time | ||
|
||
def _run(self, manage): | ||
pass | ||
|
||
def execute(self, **kwargs): | ||
pass | ||
|
||
def get_details(self, manage, **kwargs): | ||
""" | ||
Run details | ||
:return: details of this step | ||
""" | ||
return None | ||
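The nested `builder` in this file follows a fluent builder pattern: each `add_*` method stores a value and returns `self`, and `build()` assembles the final object. The sketch below illustrates that pattern only. `ParagraphHit` and `ParagraphHitBuilder` are hypothetical, trimmed-down stand-ins with three fields, not the classes from this PR, and they avoid the Django `Paragraph` model so the example runs on its own.

```python
from dataclasses import dataclass, asdict
from typing import Optional


@dataclass
class ParagraphHit:
    """Hypothetical, trimmed-down stand-in for ParagraphPipelineModel."""
    id: str
    content: Optional[str]
    similarity: Optional[float]
    dataset_name: Optional[str]


class ParagraphHitBuilder:
    """Hypothetical fluent builder mirroring the add_* / build() shape in the diff."""

    def __init__(self):
        self.paragraph = {}
        self.similarity = None
        self.dataset_name = None

    def add_paragraph(self, paragraph: dict):
        self.paragraph = paragraph
        return self  # returning self is what makes chaining possible

    def add_similarity(self, similarity: float):
        self.similarity = similarity
        return self

    def add_dataset_name(self, dataset_name: str):
        self.dataset_name = dataset_name
        return self

    def build(self) -> ParagraphHit:
        # As in the diff, the id is passed through str() before construction.
        return ParagraphHit(str(self.paragraph.get('id')),
                            self.paragraph.get('content'),
                            self.similarity,
                            self.dataset_name)


hit = (ParagraphHitBuilder()
       .add_paragraph({'id': 42, 'content': 'How to reset a password'})
       .add_similarity(0.87)
       .add_dataset_name('faq')
       .build())
print(asdict(hit))
```

One consequence of wrapping `.get('id')` in `str()`, both here and in the diff's `build()`, is that a missing id becomes the string `'None'` rather than `None`, so callers that test for a missing id may want to check before building.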
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# coding=utf-8 | ||
""" | ||
@project: maxkb | ||
@Author:虎 | ||
@file: __init__.py.py | ||
@date:2024/1/9 17:23 | ||
@desc: | ||
""" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
# coding=utf-8 | ||
""" | ||
@project: maxkb | ||
@Author:虎 | ||
@file: pipeline_manage.py | ||
@date:2024/1/9 17:40 | ||
@desc: | ||
""" | ||
import time | ||
from functools import reduce | ||
from typing import List, Type, Dict | ||
|
||
from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep | ||
from common.handle.base_to_response import BaseToResponse | ||
from common.handle.impl.response.system_to_response import SystemToResponse | ||
|
||
|
||
class PipelineManage: | ||
def __init__(self, step_list: List[Type[IBaseChatPipelineStep]], | ||
base_to_response: BaseToResponse = SystemToResponse()): | ||
# Step executors | ||
self.step_list = [step() for step in step_list] | ||
# Shared context | ||
self.context = {'message_tokens': 0, 'answer_tokens': 0} | ||
self.base_to_response = base_to_response | ||
|
||
def run(self, context: Dict = None): | ||
self.context['start_time'] = time.time() | ||
if context is not None: | ||
for key, value in context.items(): | ||
self.context[key] = value | ||
for step in self.step_list: | ||
step.run(self) | ||
|
||
def get_details(self): | ||
return reduce(lambda x, y: {**x, **y}, [{item.get('step_type'): item} for item in | ||
filter(lambda r: r is not None, | ||
[row.get_details(self) for row in self.step_list])], {}) | ||
|
||
def get_base_to_response(self): | ||
return self.base_to_response | ||
|
||
class builder: | ||
def __init__(self): | ||
self.step_list: List[Type[IBaseChatPipelineStep]] = [] | ||
self.base_to_response = SystemToResponse() | ||
|
||
def append_step(self, step: Type[IBaseChatPipelineStep]): | ||
self.step_list.append(step) | ||
return self | ||
|
||
def add_base_to_response(self, base_to_response: BaseToResponse): | ||
self.base_to_response = base_to_response | ||
return self | ||
|
||
def build(self): | ||
return PipelineManage(step_list=self.step_list, base_to_response=self.base_to_response) |
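`get_details` in this file collects each step's details, drops steps that return `None`, keys the rest by `step_type`, and merges them with `reduce`. A minimal sketch of that merge, assuming a hypothetical `FakeStep` class and made-up `step_type` values (`search_step`, `chat_step`) that do not come from this PR:

```python
from functools import reduce
from typing import Dict, List, Optional


class FakeStep:
    """Hypothetical stand-in for a pipeline step; not a class from the PR."""

    def __init__(self, details: Optional[Dict]):
        self._details = details

    def get_details(self, manage) -> Optional[Dict]:
        return self._details


def merge_details(steps: List[FakeStep], manage=None) -> Dict:
    # Collect each step's details, drop steps that report nothing (None),
    # then key every remaining dict by its 'step_type' and merge them.
    rows = [step.get_details(manage) for step in steps]
    keyed = [{row.get('step_type'): row} for row in rows if row is not None]
    return reduce(lambda acc, item: {**acc, **item}, keyed, {})


steps = [
    FakeStep({'step_type': 'search_step', 'run_time': 0.12}),
    FakeStep(None),  # a step with no details is skipped
    FakeStep({'step_type': 'chat_step', 'run_time': 1.5}),
]
details = merge_details(steps)
print(details)
```

Because the merge is a plain dict union, two steps that report the same `step_type` would collapse into one entry, with the later step overwriting the earlier one.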
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# coding=utf-8 | ||
""" | ||
@project: maxkb | ||
@Author:虎 | ||
@file: __init__.py.py | ||
@date:2024/1/9 18:23 | ||
@desc: | ||
""" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# coding=utf-8 | ||
""" | ||
@project: maxkb | ||
@Author:虎 | ||
@file: __init__.py.py | ||
@date:2024/1/9 18:23 | ||
@desc: | ||
""" |
110 changes: 110 additions & 0 deletions
apps/application/chat_pipeline/step/chat_step/i_chat_step.py
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
# coding=utf-8 | ||
""" | ||
@project: maxkb | ||
@Author:虎 | ||
@file: i_chat_step.py | ||
@date:2024/1/9 18:17 | ||
@desc: Chat | ||
""" | ||
from abc import abstractmethod | ||
from typing import Type, List | ||
|
||
from django.utils.translation import gettext_lazy as _ | ||
from langchain.chat_models.base import BaseChatModel | ||
from langchain.schema import BaseMessage | ||
from rest_framework import serializers | ||
|
||
from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep, ParagraphPipelineModel | ||
from application.chat_pipeline.pipeline_manage import PipelineManage | ||
from application.serializers.application_serializers import NoReferencesSetting | ||
from common.field.common import InstanceField | ||
from common.util.field_message import ErrMessage | ||
|
||
|
||
class ModelField(serializers.Field): | ||
def to_internal_value(self, data): | ||
if not isinstance(data, BaseChatModel): | ||
self.fail(_('Model type error'), value=data) | ||
return data | ||
|
||
def to_representation(self, value): | ||
return value | ||
|
||
|
||
class MessageField(serializers.Field): | ||
def to_internal_value(self, data): | ||
if not isinstance(data, BaseMessage): | ||
self.fail(_('Message type error'), value=data) | ||
return data | ||
|
||
def to_representation(self, value): | ||
return value | ||
|
||
|
||
class PostResponseHandler: | ||
@abstractmethod | ||
def handler(self, chat_id, chat_record_id, paragraph_list: List[ParagraphPipelineModel], problem_text: str, | ||
answer_text, | ||
manage, step, padding_problem_text: str = None, client_id=None, **kwargs): | ||
pass | ||
|
||
|
||
class IChatStep(IBaseChatPipelineStep): | ||
class InstanceSerializer(serializers.Serializer): | ||
# Conversation message list | ||
message_list = serializers.ListField(required=True, child=MessageField(required=True), | ||
error_messages=ErrMessage.list(_("Conversation list"))) | ||
model_id = serializers.UUIDField(required=False, allow_null=True, error_messages=ErrMessage.uuid(_("Model id"))) | ||
# Paragraph list | ||
paragraph_list = serializers.ListField(error_messages=ErrMessage.list(_("Paragraph List"))) | ||
# Conversation ID | ||
chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid(_("Conversation ID"))) | ||
# User question | ||
problem_text = serializers.CharField(required=True, error_messages=ErrMessage.char(_("User Questions"))) | ||
# Post-response handler | ||
post_response_handler = InstanceField(model_type=PostResponseHandler, | ||
error_messages=ErrMessage.base(_("Post-processor"))) | ||
# Completed (padded) question | ||
padding_problem_text = serializers.CharField(required=False, | ||
error_messages=ErrMessage.base(_("Completion Question"))) | ||
# Whether to stream the output | ||
stream = serializers.BooleanField(required=False, error_messages=ErrMessage.base(_("Streaming Output"))) | ||
client_id = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Client id"))) | ||
client_type = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Client Type"))) | ||
# Settings for when no referenced segments are found | ||
no_references_setting = NoReferencesSetting(required=True, | ||
error_messages=ErrMessage.base(_("No reference segment settings"))) | ||
|
||
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid(_("User ID"))) | ||
|
||
model_setting = serializers.DictField(required=True, allow_null=True, | ||
error_messages=ErrMessage.dict(_("Model settings"))) | ||
|
||
model_params_setting = serializers.DictField(required=False, allow_null=True, | ||
error_messages=ErrMessage.dict(_("Model parameter settings"))) | ||
|
||
def is_valid(self, *, raise_exception=False): | ||
super().is_valid(raise_exception=True) | ||
message_list: List = self.initial_data.get('message_list') | ||
for message in message_list: | ||
if not isinstance(message, BaseMessage): | ||
raise Exception(_("message type error")) | ||
|
||
def get_step_serializer(self, manage: PipelineManage) -> Type[serializers.Serializer]: | ||
return self.InstanceSerializer | ||
|
||
def _run(self, manage: PipelineManage): | ||
chat_result = self.execute(**self.context['step_args'], manage=manage) | ||
manage.context['chat_result'] = chat_result | ||
|
||
@abstractmethod | ||
def execute(self, message_list: List[BaseMessage], | ||
chat_id, problem_text, | ||
post_response_handler: PostResponseHandler, | ||
model_id: str = None, | ||
user_id: str = None, | ||
paragraph_list=None, | ||
manage: PipelineManage = None, | ||
padding_problem_text: str = None, stream: bool = True, client_id=None, client_type=None, | ||
no_references_setting=None, model_params_setting=None, model_setting=None, **kwargs): | ||
pass |
The provided code has several areas that need attention:
- The # coding=utf-8 directive is unnecessary and should be removed.
- ParagraphPipelineModel, IBaseChatPipelineStep, etc., are not in consistent PascalCase format, which can make them harder to read.
- _id, document_id, dataset_id, etc., are short and don't provide enough context about their purpose.
- The IBaseChatPipelineStep.run() method contains comments but is not actually implemented.
Here are some specific suggestions:
General Refactoring
1. Rename Variables
# Change variable names to descriptive names
2. Improve Docstrings
3. Use Context Dictionary Consistently
Ensure that all classes have access to a context dictionary through inheritance if necessary.
Code Changes
Constructor Updates
Update constructor parameter names to be more descriptive.
Correct Serialization
Ensure serialization logic handles cases where None values might occur.
Implement run() Method
Implement the abstract run() method by calling other necessary parts.
By making these changes, you'll improve the clarity, maintainability, and robustness of the given codebase according to Python guidelines.
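On the run() point, the diff's `run()` reads as a template method: it records a start time, validates arguments, delegates the actual work to `_run()`, and then stores the elapsed time, so subclasses are expected to override `_run()` rather than `run()`. A minimal sketch of that shape, assuming a simplified `BaseStep` with a hand-written check in place of the DRF serializer (`BaseStep`, `EchoStep`, and the `problem_text` check are illustrative, not code from this PR):

```python
import time


class BaseStep:
    """Simplified stand-in for IBaseChatPipelineStep (no Django, no serializers)."""

    def __init__(self):
        self.context = {}

    def valid_args(self, manage_context: dict) -> None:
        # Hypothetical validation; the PR uses a DRF serializer here instead.
        if 'problem_text' not in manage_context:
            raise ValueError('problem_text is required')
        self.context['step_args'] = dict(manage_context)

    def run(self, manage_context: dict) -> None:
        # Template method: timing and validation live here, the work lives in _run().
        start_time = time.time()
        self.context['start_time'] = start_time
        self.valid_args(manage_context)
        self._run(manage_context)
        self.context['run_time'] = time.time() - start_time

    def _run(self, manage_context: dict) -> None:
        raise NotImplementedError


class EchoStep(BaseStep):
    def _run(self, manage_context: dict) -> None:
        # Write the step's result back into the shared context.
        manage_context['chat_result'] = self.context['step_args']['problem_text'].upper()


shared = {'problem_text': 'hello'}
step = EchoStep()
step.run(shared)
print(shared['chat_result'], step.context['run_time'])
```

Raising `NotImplementedError` in the base `_run()` is a design choice of this sketch; the diff's base `_run()` simply does nothing.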