Commit 1f522dc

feat: add no-references paragraph setting (无引用分段设置)

shaohuzhang1 authored Apr 24, 2024
1 parent a3f4710 commit 1f522dc
Showing 12 changed files with 218 additions and 48 deletions.
7 changes: 5 additions & 2 deletions apps/application/chat_pipeline/step/chat_step/i_chat_step.py
@@ -15,9 +15,9 @@

 from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep, ParagraphPipelineModel
 from application.chat_pipeline.pipeline_manage import PipelineManage
+from application.serializers.application_serializers import NoReferencesSetting
 from common.field.common import InstanceField
 from common.util.field_message import ErrMessage
-from dataset.models import Paragraph


 class ModelField(serializers.Field):
@@ -70,6 +70,8 @@ class InstanceSerializer(serializers.Serializer):
         stream = serializers.BooleanField(required=False, error_messages=ErrMessage.base("流式输出"))
         client_id = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端id"))
         client_type = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端类型"))
+        # Setting for when no referenced paragraphs are retrieved
+        no_references_setting = NoReferencesSetting(required=True, error_messages=ErrMessage.base("无引用分段设置"))

         def is_valid(self, *, raise_exception=False):
             super().is_valid(raise_exception=True)
@@ -92,5 +94,6 @@ def execute(self, message_list: List[BaseMessage],
                 chat_model: BaseChatModel = None,
                 paragraph_list=None,
                 manage: PipelineManage = None,
-                padding_problem_text: str = None, stream: bool = True, client_id=None, client_type=None, **kwargs):
+                padding_problem_text: str = None, stream: bool = True, client_id=None, client_type=None,
+                no_references_setting=None, **kwargs):
        pass
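
Note: the new no_references_setting field carries a small dict validated by the NoReferencesSetting serializer added later in this commit. A minimal sketch of the two accepted shapes (the designated-answer text below is illustrative, not from the commit):

# Mode 1: let the model answer anyway; 'value' is the prompt template.
ai_questioning = {'status': 'ai_questioning', 'value': '{question}'}

# Mode 2: reply with a fixed, operator-defined answer; 'value' is the reply
# text (this example string is hypothetical).
designated_answer = {'status': 'designated_answer',
                     'value': 'Sorry, no relevant content was found.'}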
@@ -17,7 +17,8 @@
 from django.http import StreamingHttpResponse
 from langchain.chat_models.base import BaseChatModel
 from langchain.schema import BaseMessage
-from langchain.schema.messages import BaseMessageChunk, HumanMessage, AIMessage
+from langchain.schema.messages import HumanMessage, AIMessage
+from langchain_core.messages import AIMessageChunk

 from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
 from application.chat_pipeline.pipeline_manage import PipelineManage
@@ -47,7 +48,8 @@ def event_content(response,
                   message_list: List[BaseMessage],
                   problem_text: str,
                   padding_problem_text: str = None,
-                  client_id=None, client_type=None):
+                  client_id=None, client_type=None,
+                  is_ai_chat: bool = None):
     all_text = ''
     try:
         for chunk in response:
@@ -56,8 +58,12 @@
                 'content': chunk.content, 'is_end': False}) + "\n\n"

         # Count tokens
-        request_token = chat_model.get_num_tokens_from_messages(message_list)
-        response_token = chat_model.get_num_tokens(all_text)
+        if is_ai_chat:
+            request_token = chat_model.get_num_tokens_from_messages(message_list)
+            response_token = chat_model.get_num_tokens(all_text)
+        else:
+            request_token = 0
+            response_token = 0
         step.context['message_tokens'] = request_token
         step.context['answer_tokens'] = response_token
         current_time = time.time()
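
Note: the new is_ai_chat flag gates token accounting so that only genuine model calls are metered. A condensed sketch of the rule introduced above (the helper name is ours, not the commit's):

def count_tokens(chat_model, message_list, all_text, is_ai_chat):
    # Only genuine model calls are metered; echoed paragraphs and
    # designated answers record zero tokens.
    if is_ai_chat:
        return (chat_model.get_num_tokens_from_messages(message_list),
                chat_model.get_num_tokens(all_text))
    return 0, 0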
@@ -88,15 +94,16 @@ def execute(self, message_list: List[BaseMessage],
                 padding_problem_text: str = None,
                 stream: bool = True,
                 client_id=None, client_type=None,
+                no_references_setting=None,
                 **kwargs):
         if stream:
             return self.execute_stream(message_list, chat_id, problem_text, post_response_handler, chat_model,
                                        paragraph_list,
-                                       manage, padding_problem_text, client_id, client_type)
+                                       manage, padding_problem_text, client_id, client_type, no_references_setting)
         else:
             return self.execute_block(message_list, chat_id, problem_text, post_response_handler, chat_model,
                                       paragraph_list,
-                                      manage, padding_problem_text, client_id, client_type)
+                                      manage, padding_problem_text, client_id, client_type, no_references_setting)

     def get_details(self, manage, **kwargs):
         return {
@@ -127,19 +134,26 @@ def execute_stream(self, message_list: List[BaseMessage],
                        paragraph_list=None,
                        manage: PipelineManage = None,
                        padding_problem_text: str = None,
-                       client_id=None, client_type=None):
+                       client_id=None, client_type=None,
+                       no_references_setting=None):
+        is_ai_chat = False
         # Call the model
         if chat_model is None:
             chat_result = iter(
-                [BaseMessageChunk(content=paragraph.title + "\n" + paragraph.content) for paragraph in paragraph_list])
+                [AIMessageChunk(content=paragraph.title + "\n" + paragraph.content) for paragraph in paragraph_list])
         else:
-            chat_result = chat_model.stream(message_list)
+            if (paragraph_list is None or len(paragraph_list) == 0) and no_references_setting.get(
+                    'status') == 'designated_answer':
+                chat_result = iter([AIMessageChunk(content=no_references_setting.get('value'))])
+            else:
+                chat_result = chat_model.stream(message_list)
+                is_ai_chat = True

         chat_record_id = uuid.uuid1()
         r = StreamingHttpResponse(
             streaming_content=event_content(chat_result, chat_id, chat_record_id, paragraph_list,
                                             post_response_handler, manage, self, chat_model, message_list, problem_text,
-                                            padding_problem_text, client_id, client_type),
+                                            padding_problem_text, client_id, client_type, is_ai_chat),
             content_type='text/event-stream;charset=utf-8')

         r['Cache-Control'] = 'no-cache'
@@ -153,16 +167,26 @@ def execute_block(self, message_list: List[BaseMessage],
                       paragraph_list=None,
                       manage: PipelineManage = None,
                       padding_problem_text: str = None,
-                      client_id=None, client_type=None):
+                      client_id=None, client_type=None, no_references_setting=None):
+        is_ai_chat = False
         # Call the model
         if chat_model is None:
             chat_result = AIMessage(
                 content="\n\n".join([paragraph.title + "\n" + paragraph.content for paragraph in paragraph_list]))
         else:
-            chat_result = chat_model.invoke(message_list)
+            if (paragraph_list is None or len(paragraph_list) == 0) and no_references_setting.get(
+                    'status') == 'designated_answer':
+                chat_result = AIMessage(content=no_references_setting.get('value'))
+            else:
+                chat_result = chat_model.invoke(message_list)
+                is_ai_chat = True
         chat_record_id = uuid.uuid1()
-        request_token = chat_model.get_num_tokens_from_messages(message_list)
-        response_token = chat_model.get_num_tokens(chat_result.content)
+        if is_ai_chat:
+            request_token = chat_model.get_num_tokens_from_messages(message_list)
+            response_token = chat_model.get_num_tokens(chat_result.content)
+        else:
+            request_token = 0
+            response_token = 0
         self.context['message_tokens'] = request_token
         self.context['answer_tokens'] = response_token
         current_time = time.time()
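
Note: execute_stream and execute_block now share the same three-way decision about where the answer comes from. Paraphrased as a standalone sketch (the function name and return labels are ours):

def choose_answer_source(chat_model, paragraph_list, no_references_setting):
    if chat_model is None:
        # No model configured: stream the retrieved paragraphs back verbatim.
        return 'echo_paragraphs'
    if (paragraph_list is None or len(paragraph_list) == 0) \
            and no_references_setting.get('status') == 'designated_answer':
        # Nothing retrieved and a fixed reply is configured: skip the LLM.
        return 'designated_answer'
    # Normal path: call the LLM (and meter its tokens).
    return 'model'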
@@ -15,9 +15,9 @@
 from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep, ParagraphPipelineModel
 from application.chat_pipeline.pipeline_manage import PipelineManage
 from application.models import ChatRecord
+from application.serializers.application_serializers import NoReferencesSetting
 from common.field.common import InstanceField
 from common.util.field_message import ErrMessage
-from dataset.models import Paragraph


 class IGenerateHumanMessageStep(IBaseChatPipelineStep):
@@ -39,6 +39,8 @@ class InstanceSerializer(serializers.Serializer):
         prompt = serializers.CharField(required=True, error_messages=ErrMessage.char("提示词"))
         # Completed (padded) question
         padding_problem_text = serializers.CharField(required=False, error_messages=ErrMessage.char("补齐问题"))
+        # Setting for when no referenced paragraphs are retrieved
+        no_references_setting = NoReferencesSetting(required=True, error_messages=ErrMessage.base("无引用分段设置"))

     def get_step_serializer(self, manage: PipelineManage) -> Type[serializers.Serializer]:
         return self.InstanceSerializer
@@ -56,6 +58,7 @@ def execute(self,
                 max_paragraph_char_number: int,
                 prompt: str,
                 padding_problem_text: str = None,
+                no_references_setting=None,
                 **kwargs) -> List[BaseMessage]:
         """
@@ -67,6 +70,7 @@
         :param prompt: prompt template
         :param padding_problem_text: user-revised question text
         :param kwargs: other parameters
+        :param no_references_setting: setting for when no referenced paragraphs are retrieved
         :return:
         """
         pass
@@ -6,7 +6,7 @@
 @date:2024/1/10 17:50
 @desc:
 """
-from typing import List
+from typing import List, Dict

 from langchain.schema import BaseMessage, HumanMessage

@@ -26,22 +26,31 @@ def execute(self, problem_text: str,
                 max_paragraph_char_number: int,
                 prompt: str,
                 padding_problem_text: str = None,
+                no_references_setting=None,
                 **kwargs) -> List[BaseMessage]:
+        prompt = prompt if no_references_setting.get('status') == 'designated_answer' else no_references_setting.get(
+            'value')
         exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text
         start_index = len(history_chat_record) - dialogue_number
         history_message = [[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()]
                            for index in
                            range(start_index if start_index > 0 else 0, len(history_chat_record))]
         return [*flat_map(history_message),
-                self.to_human_message(prompt, exec_problem_text, max_paragraph_char_number, paragraph_list)]
+                self.to_human_message(prompt, exec_problem_text, max_paragraph_char_number, paragraph_list,
+                                      no_references_setting)]

     @staticmethod
     def to_human_message(prompt: str,
                          problem: str,
                          max_paragraph_char_number: int,
-                         paragraph_list: List[ParagraphPipelineModel]):
+                         paragraph_list: List[ParagraphPipelineModel],
+                         no_references_setting: Dict):
         if paragraph_list is None or len(paragraph_list) == 0:
-            return HumanMessage(content=prompt.format(**{'data': "<data></data>", 'question': problem}))
+            if no_references_setting.get('status') == 'ai_questioning':
+                return HumanMessage(
+                    content=no_references_setting.get('value').format(**{'question': problem}))
+            else:
+                return HumanMessage(content=prompt.format(**{'data': "", 'question': problem}))
         temp_data = ""
         data_list = []
         for p in paragraph_list:
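
Note: when nothing was retrieved and status is 'ai_questioning', the user's question is formatted into no_references_setting['value'] in place of the normal RAG prompt. A worked example with the default '{question}' template (the question text is illustrative):

no_refs = {'status': 'ai_questioning', 'value': '{question}'}
question = 'How do I reset my password?'
content = no_refs.get('value').format(**{'question': question})
# content == 'How do I reset my password?': the bare question goes to the model.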
6 changes: 5 additions & 1 deletion apps/application/models/application.py
@@ -19,7 +19,11 @@


 def get_dataset_setting_dict():
-    return {'top_n': 3, 'similarity': 0.6, 'max_paragraph_char_number': 5000, 'search_mode': 'embedding'}
+    return {'top_n': 3, 'similarity': 0.6, 'max_paragraph_char_number': 5000, 'search_mode': 'embedding',
+            'no_references_setting': {
+                'status': 'ai_questioning',
+                'value': '{question}'
+            }}


 def get_model_setting_dict():
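
Note: get_dataset_setting_dict stays a module-level callable, presumably so it can serve as a Django JSONField default; Django re-invokes the callable per row rather than sharing one mutable dict. A sketch of how such a default is typically wired (this field declaration is an assumption, it does not appear in this diff):

from django.db import models

class Application(models.Model):  # assumed host model, abbreviated
    dataset_setting = models.JSONField(default=get_dataset_setting_dict)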
18 changes: 17 additions & 1 deletion apps/application/serializers/application_serializers.py
@@ -73,6 +73,18 @@ class Meta:
         fields = "__all__"


+class NoReferencesChoices(models.TextChoices):
+    """Answer mode when no referenced paragraphs are found"""
+    ai_questioning = 'ai_questioning', 'ai回答'
+    designated_answer = 'designated_answer', '指定回答'
+
+
+class NoReferencesSetting(serializers.Serializer):
+    status = serializers.ChoiceField(required=True, choices=NoReferencesChoices.choices,
+                                     error_messages=ErrMessage.char("无引用状态"))
+    value = serializers.CharField(required=True, error_messages=ErrMessage.char("提示词"))
+
+
 class DatasetSettingSerializer(serializers.Serializer):
     top_n = serializers.FloatField(required=True, max_value=100, min_value=1,
                                    error_messages=ErrMessage.float("引用分段数"))
@@ -85,6 +97,8 @@ class DatasetSettingSerializer(serializers.Serializer):
                                             message="类型只支持register|reset_password", code=500)
                                     ], error_messages=ErrMessage.char("检索模式"))

+    no_references_setting = NoReferencesSetting(required=True, error_messages=ErrMessage.base("无引用分段设置"))
+

 class ModelSettingSerializer(serializers.Serializer):
     prompt = serializers.CharField(required=True, max_length=2048, error_messages=ErrMessage.char("提示词"))
@@ -383,7 +397,9 @@ def reset_application(application: Dict):
         application['multiple_rounds_dialogue'] = True if application.get('dialogue_number') > 0 else False
         del application['dialogue_number']
         if 'dataset_setting' in application:
-            application['dataset_setting'] = {**application['dataset_setting'], 'search_mode': 'embedding'}
+            application['dataset_setting'] = {'search_mode': 'embedding', 'no_references_setting': {
+                'status': 'ai_questioning',
+                'value': '{question}'}, **application['dataset_setting']}
         return application
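
Note: a quick validation sketch for the new serializer (payloads are hypothetical). The second payload fails because status is restricted to the two NoReferencesChoices values:

ok = NoReferencesSetting(data={'status': 'designated_answer',
                               'value': 'No relevant content found.'})
assert ok.is_valid()

bad = NoReferencesSetting(data={'status': 'always_answer', 'value': 'x'})
assert not bad.is_valid()  # 'always_answer' is not in NoReferencesChoices

Also worth noting: in reset_application the injected defaults now come first in the dict literal, so stored dataset_setting keys override them; older applications keep their saved values and only gain the new key when it is absent.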
6 changes: 5 additions & 1 deletion apps/application/serializers/chat_message_serializers.py
@@ -78,7 +78,11 @@ def to_base_pipeline_manage_params(self):
             'problem_optimization': self.application.problem_optimization,
             'stream': True,
             'search_mode': self.application.dataset_setting.get(
-                'search_mode') if 'search_mode' in self.application.dataset_setting else 'embedding'
+                'search_mode') if 'search_mode' in self.application.dataset_setting else 'embedding',
+            'no_references_setting': self.application.dataset_setting.get(
+                'no_references_setting') if 'no_references_setting' in self.application.dataset_setting else {
+                'status': 'ai_questioning',
+                'value': '{question}'}

         }
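
Note: the conditional fallback above is equivalent to dict.get with a default argument; a slightly tighter form of the same logic inside to_base_pipeline_manage_params (a sketch, not the shipped code):

no_refs = self.application.dataset_setting.get(
    'no_references_setting',
    {'status': 'ai_questioning', 'value': '{question}'})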
12 changes: 12 additions & 0 deletions apps/application/swagger_api/application_api.py
@@ -176,6 +176,18 @@ def get_request_body_api():
                                           description="最多引用字符数", default=3000),
             'search_mode': openapi.Schema(type=openapi.TYPE_STRING, title='检索模式',
                                           description="embedding|keywords|blend", default='embedding'),
+            'no_references_setting': openapi.Schema(type=openapi.TYPE_OBJECT, title='无引用分段设置',
+                                                    required=['status', 'value'],
+                                                    properties={
+                                                        'status': openapi.Schema(type=openapi.TYPE_STRING,
+                                                                                 title="状态",
+                                                                                 description="ai作答:ai_questioning,指定回答:designated_answer",
+                                                                                 default='ai_questioning'),
+                                                        'value': openapi.Schema(type=openapi.TYPE_STRING,
+                                                                                 title="值",
+                                                                                 description="ai作答:就是提示词,指定回答:就是指定回答内容",
+                                                                                 default='{question}'),
+                                                    }),
         }
     )
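
Note: an example dataset_setting request-body fragment matching the schema above (all values illustrative):

dataset_setting = {
    'top_n': 3,
    'similarity': 0.6,
    'max_paragraph_char_number': 5000,
    'search_mode': 'embedding',
    'no_references_setting': {
        'status': 'designated_answer',
        # Fixed reply shown to the user when retrieval comes back empty.
        'value': 'Sorry, nothing relevant was found in the knowledge base.'
    },
}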
16 changes: 16 additions & 0 deletions ui/src/styles/element-plus.scss
@@ -335,3 +335,19 @@
 .auto-tooltip-popper {
   max-width: 500px;
 }
+
+// Radio style: one option per row
+.radio-block {
+  width: 100%;
+  display: block;
+  .el-radio {
+    align-items: flex-start;
+    height: 100%;
+    width: 100%;
+  }
+  .el-radio__label {
+    width: 100%;
+    margin-top: -8px;
+    line-height: 30px;
+  }
+}
18 changes: 3 additions & 15 deletions ui/src/views/application-overview/component/EditAvatarDialog.vue
@@ -1,6 +1,6 @@
 <template>
   <el-dialog title="设置 Logo" v-model="dialogVisible">
-    <el-radio-group v-model="radioType" class="card__block mb-16">
+    <el-radio-group v-model="radioType" class="radio-block mb-16">
       <div>
         <el-radio value="default">
           <p>默认 Logo</p>
@@ -14,7 +14,7 @@
           />
         </el-radio>
       </div>
-      <div>
+      <div class="mt-8">
         <el-radio value="custom">
           <p>自定义上传</p>
           <div class="flex mt-8">
@@ -126,16 +126,4 @@ function submit() {
 defineExpose({ open })
 </script>
-<style lang="scss" scope>
-.card__block {
-  width: 100%;
-  display: block;
-  .el-radio {
-    align-items: flex-start;
-    height: 100%;
-  }
-  .el-radio__inner {
-    margin-top: 3px;
-  }
-}
-</style>
+<style lang="scss" scope></style>
6 changes: 5 additions & 1 deletion ui/src/views/application/CreateAndSetting.vue
@@ -322,7 +322,11 @@ const applicationForm = ref<ApplicationFormType>({
     top_n: 3,
     similarity: 0.6,
     max_paragraph_char_number: 5000,
-    search_mode: 'embedding'
+    search_mode: 'embedding',
+    no_references_setting: {
+      status: 'ai_questioning',
+      value: '{question}'
+    }
   },
   model_setting: {
     prompt: defaultPrompt