feat: add full-text and hybrid search modes
shaohuzhang1 authored Apr 22, 2024
1 parent 8fe1a14 commit c89ae29
Showing 20 changed files with 581 additions and 141 deletions.
@@ -6,15 +6,16 @@
@date:2024/1/9 18:10
@desc: search the knowledge base
"""
import re
from abc import abstractmethod
from typing import List, Type

from django.core import validators
from rest_framework import serializers

from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep, ParagraphPipelineModel
from application.chat_pipeline.pipeline_manage import PipelineManage
from common.util.field_message import ErrMessage
from dataset.models import Paragraph


class ISearchDatasetStep(IBaseChatPipelineStep):
@@ -38,6 +39,10 @@ class InstanceSerializer(serializers.Serializer):
# similarity, between 0 and 1
similarity = serializers.FloatField(required=True, max_value=1, min_value=0,
error_messages=ErrMessage.float("similarity"))
search_mode = serializers.CharField(required=True, validators=[
validators.RegexValidator(regex=re.compile("^(embedding|keywords|blend)$"),
message="search_mode only supports embedding|keywords|blend", code=500)
], error_messages=ErrMessage.char("search mode"))

def get_step_serializer(self, manage: PipelineManage) -> Type[InstanceSerializer]:
return self.InstanceSerializer
@@ -50,6 +55,7 @@ def _run(self, manage: PipelineManage):
@abstractmethod
def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
search_mode: str = None,
**kwargs) -> List[ParagraphPipelineModel]:
"""
Note on the user question and the expanded question: if an expanded question exists, use it for the query; otherwise query with the user's original question.
@@ -60,6 +66,7 @@ def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
:param exclude_document_id_list: document ids to exclude
:param exclude_paragraph_id_list: paragraph ids to exclude
:param padding_problem_text: the expanded question
:param search_mode: search mode
:return: list of paragraphs
"""
pass
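A note on the search_mode validator above: `|` has the lowest precedence in a regex, so the ungrouped form `^embedding|keywords|blend$` would anchor only the first and last branches and accept strings such as `embedding-foo`; the grouped form `^(embedding|keywords|blend)$` anchors the whole alternation. A quick self-contained check:

```python
import re

loose = re.compile("^embedding|keywords|blend$")     # alternation swallows the anchors
strict = re.compile("^(embedding|keywords|blend)$")  # anchors apply to every branch

print(bool(loose.search("embedding-foo")))   # True  - false positive
print(bool(loose.search("xx-keywords-yy")))  # True  - false positive
print(bool(strict.search("embedding-foo")))  # False
print(bool(strict.search("blend")))          # True
```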
@@ -17,20 +17,22 @@
from common.db.search import native_search
from common.util.file_util import get_file_content
from dataset.models import Paragraph
from embedding.models import SearchMode
from smartdoc.conf import PROJECT_DIR


class BaseSearchDatasetStep(ISearchDatasetStep):

def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
search_mode: str = None,
**kwargs) -> List[ParagraphPipelineModel]:
exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text
embedding_model = EmbeddingModel.get_embedding_model()
embedding_value = embedding_model.embed_query(exec_problem_text)
vector = VectorStore.get_embedding_vector()
embedding_list = vector.query(embedding_value, dataset_id_list, exclude_document_id_list,
exclude_paragraph_id_list, True, top_n, similarity)
embedding_list = vector.query(exec_problem_text, embedding_value, dataset_id_list, exclude_document_id_list,
exclude_paragraph_id_list, True, top_n, similarity, SearchMode(search_mode))
if embedding_list is None:
return []
paragraph_list = self.list_paragraph([row.get('paragraph_id') for row in embedding_list], vector)
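`SearchMode` is imported from `embedding.models`, whose definition is not shown in this diff view. Judging by the three values the validators accept, it is presumably an enum along these lines (a hypothetical sketch, not the actual MaxKB code):

```python
from enum import Enum

class SearchMode(Enum):
    # hypothetical reconstruction -- the real class lives in apps/embedding/models,
    # which this diff view does not show
    embedding = 'embedding'  # vector-similarity search only
    keywords = 'keywords'    # PostgreSQL full-text search over search_vector
    blend = 'blend'          # hybrid: merge vector and full-text results
```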
2 changes: 1 addition & 1 deletion apps/application/models/application.py
@@ -19,7 +19,7 @@


def get_dataset_setting_dict():
return {'top_n': 3, 'similarity': 0.6, 'max_paragraph_char_number': 5000}
return {'top_n': 3, 'similarity': 0.6, 'max_paragraph_char_number': 5000, 'search_mode': 'embedding'}


def get_model_setting_dict():
13 changes: 12 additions & 1 deletion apps/application/serializers/application_serializers.py
@@ -8,12 +8,13 @@
"""
import hashlib
import os
import re
import uuid
from functools import reduce
from typing import Dict

from django.contrib.postgres.fields import ArrayField
from django.core import cache
from django.core import cache, validators
from django.core import signing
from django.db import transaction, models
from django.db.models import QuerySet
@@ -32,6 +33,7 @@
from common.util.file_util import get_file_content
from dataset.models import DataSet, Document
from dataset.serializers.common_serializers import list_paragraph
from embedding.models import SearchMode
from setting.models import AuthOperate
from setting.models.model_management import Model
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
@@ -77,6 +79,10 @@ class DatasetSettingSerializer(serializers.Serializer):
error_messages=ErrMessage.float("相识度"))
max_paragraph_char_number = serializers.IntegerField(required=True, min_value=500, max_value=10000,
error_messages=ErrMessage.integer("maximum quoted characters"))
search_mode = serializers.CharField(required=True, validators=[
validators.RegexValidator(regex=re.compile("^(embedding|keywords|blend)$"),
message="search_mode only supports embedding|keywords|blend", code=500)
], error_messages=ErrMessage.char("search mode"))


class ModelSettingSerializer(serializers.Serializer):
@@ -291,6 +297,10 @@ class HitTest(serializers.Serializer):
error_messages=ErrMessage.integer("topN"))
similarity = serializers.FloatField(required=True, max_value=1, min_value=0,
error_messages=ErrMessage.float("relevance"))
search_mode = serializers.CharField(required=True, validators=[
validators.RegexValidator(regex=re.compile("^(embedding|keywords|blend)$"),
message="search_mode only supports embedding|keywords|blend", code=500)
], error_messages=ErrMessage.char("search mode"))

def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
@@ -312,6 +322,7 @@ def hit_test(self):
hit_list = vector.hit_test(self.data.get('query_text'), dataset_id_list, exclude_document_id_list,
self.data.get('top_number'),
self.data.get('similarity'),
SearchMode(self.data.get('search_mode')),
EmbeddingModel.get_embedding_model())
hit_dict = reduce(lambda x, y: {**x, **y}, [{hit.get('paragraph_id'): hit} for hit in hit_list], {})
p_list = list_paragraph([h.get('paragraph_id') for h in hit_list])
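The `reduce` above folds the per-hit dicts into a single dict keyed by paragraph_id; a standalone illustration of the same expression (field names beyond paragraph_id are illustrative):

```python
from functools import reduce

hit_list = [{'paragraph_id': 'p1', 'similarity': 0.91},
            {'paragraph_id': 'p2', 'similarity': 0.74}]
hit_dict = reduce(lambda x, y: {**x, **y}, [{hit.get('paragraph_id'): hit} for hit in hit_list], {})
print(hit_dict['p1'])  # {'paragraph_id': 'p1', 'similarity': 0.91}
```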
8 changes: 5 additions & 3 deletions apps/application/serializers/chat_message_serializers.py
@@ -77,6 +77,8 @@ def to_base_pipeline_manage_params(self):
'model_id': self.application.model.id if self.application.model is not None else None,
'problem_optimization': self.application.problem_optimization,
'stream': True,
'search_mode': self.application.dataset_setting.get('search_mode', 'embedding')

}

@@ -184,9 +186,9 @@ def chat(self):
pipeline_manage_builder.append_step(BaseResetProblemStep)
# build the pipeline manager
pipeline_message = (pipeline_manage_builder.append_step(BaseSearchDatasetStep)
.append_step(BaseGenerateHumanMessageStep)
.append_step(BaseChatStep)
.build())
exclude_paragraph_id_list = []
# for a repeated question, optionally exclude paragraphs that were already retrieved
if re_chat:
2 changes: 2 additions & 0 deletions apps/application/swagger_api/application_api.py
@@ -161,6 +161,8 @@ def get_request_body_api():
default=0.6),
'max_paragraph_char_number': openapi.Schema(type=openapi.TYPE_NUMBER, title='maximum quoted characters',
description="maximum quoted characters", default=3000),
'search_mode': openapi.Schema(type=openapi.TYPE_STRING, title='search mode',
description="embedding|keywords|blend", default='embedding'),
}
)

3 changes: 2 additions & 1 deletion apps/application/views/application_views.py
@@ -343,7 +343,8 @@ def get(self, request: Request, application_id: str):
ApplicationSerializer.HitTest(data={'id': application_id, 'user_id': request.user.id,
"query_text": request.query_params.get("query_text"),
"top_number": request.query_params.get("top_number"),
'similarity': request.query_params.get('similarity')}).hit_test(
'similarity': request.query_params.get('similarity'),
'search_mode': request.query_params.get('search_mode')}).hit_test(
))

class Operate(APIView):
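For reference, the hit-test endpoint can now be exercised with the new parameter; a sketch using `requests`, where the host, URL path, and auth header are assumptions (the route definition is not shown here) and only the query parameters come from the view above:

```python
import requests

# hypothetical host, path, and auth -- adjust to the deployed MaxKB instance
response = requests.get(
    "http://localhost:8080/api/application/<application_id>/hit_test",
    params={
        "query_text": "how do I reset my password",
        "top_number": 3,
        "similarity": 0.6,
        "search_mode": "blend",  # embedding | keywords | blend
    },
    headers={"Authorization": "<api key>"},
)
print(response.json())
```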
7 changes: 7 additions & 0 deletions apps/common/swagger_api/common_api.py
@@ -33,6 +33,13 @@ def get_request_params_api():
default=0.6,
required=True,
description='relevance'),
openapi.Parameter(name='search_mode',
in_=openapi.IN_QUERY,
type=openapi.TYPE_STRING,
default="embedding",
required=True,
description='search mode: embedding|keywords|blend'
)
]

@staticmethod
107 changes: 107 additions & 0 deletions apps/common/util/ts_vecto_util.py
@@ -0,0 +1,107 @@
# coding=utf-8
"""
@project: maxkb
@Author:虎
@file: ts_vecto_util.py
@date:2024/4/16 15:26
@desc:
"""
import re
import uuid
from typing import List

import jieba
from jieba import analyse

from common.util.split_model import group_by

# candidate placeholder characters: ASCII '&' (38) through 'S' (83)
jieba_word_list_cache = [chr(item) for item in range(38, 84)]

# register every '#x#' placeholder so jieba keeps it as a single token
for jieba_word in jieba_word_list_cache:
jieba.add_word('#' + jieba_word + '#')
# r"(?i)\b(?:https?|ftp|tcp|file)://[^\s]+\b",
# substrings that must not be segmented
# r'"([^"]*)"'
word_pattern_list = [r"v\d+.\d+.\d+",
r"[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}"]

remove_chars = '\n , :\'<>!@#¥%……&*()!@#$%^&*(): ;,/"./-'


def get_word_list(text: str):
result = []
for pattern in word_pattern_list:
word_list = re.findall(pattern, text)
for child_list in word_list:
for word in child_list if isinstance(child_list, tuple) else [child_list]:
# ts_vector entries cannot contain ':', so split such words apart
if ':' in word:
result.extend(word.split(':'))
else:
result.append(word)
return result


def replace_word(word_dict, text: str):
# replace each protected word with its placeholder key, escaping regex metacharacters
for key in word_dict:
text = re.sub('(?<!#)' + re.escape(word_dict[key]) + '(?!#)', key, text)
return text


def get_word_key(text: str, use_word_list):
# pick a placeholder character that occurs neither in the text nor among the used keys
for j_word in jieba_word_list_cache:
if j_word not in text and j_word not in use_word_list:
return j_word
# fall back to a uuid-based token when all cached characters are taken
j_word = str(uuid.uuid1())
jieba.add_word(j_word)
return j_word


def to_word_dict(word_list: List, text: str):
# map placeholder keys like '#&#' to the protected words they stand for;
# track the bare keys so two words never share one placeholder
word_dict = {}
used_keys = set()
for word in word_list:
key = get_word_key(text, used_keys)
used_keys.add(key)
word_dict['#' + key + '#'] = word
return word_dict


def get_key_by_word_dict(key, word_dict):
# translate a placeholder back to its original word; pass other tokens through
return word_dict.get(key, key)


def to_ts_vector(text: str):
# collect substrings that must not be segmented (versions, emails, ...)
word_list = get_word_list(text)
# build the placeholder -> word mapping
word_dict = to_word_dict(word_list, text)
# shield the protected words behind placeholders
text = replace_word(word_dict, text)
# segment with jieba, then map placeholders back to the original words
result = jieba.tokenize(text, mode='search')
result_ = [{'word': get_key_by_word_dict(item[0], word_dict), 'index': item[1]} for item in result]
result_group = group_by(result_, lambda r: r['word'])
return " ".join(
[f"{key.lower()}:{','.join([str(item['index'] + 1) for item in result_group[key]][:20])}" for key in
result_group if
key not in remove_chars and len(key.strip()) > 0])


def to_query(text: str):
# collect substrings that must not be segmented
word_list = get_word_list(text)
# build the placeholder -> word mapping
word_dict = to_word_dict(word_list, text)
# shield the protected words behind placeholders
text = replace_word(word_dict, text)
extract_tags = analyse.extract_tags(text, topK=5, withWeight=True, allowPOS=('ns', 'n', 'vn', 'v', 'eng'))
result = " ".join([get_key_by_word_dict(word, word_dict) for word, score in extract_tags if
word not in remove_chars])
# drop the temporary entries from the jieba dictionary
for word in word_list:
jieba.del_word(word)
return result
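To make the protect/replace/segment round trip concrete, a small usage sketch; the printed values are indicative only, since the exact tokens depend on jieba's dictionary:

```python
from common.util.ts_vecto_util import to_ts_vector, to_query

text = "MaxKB v1.2.0 supports hybrid search, contact admin@example.com"
# the version number and the email match word_pattern_list, so they are
# shielded from jieba and come back as single tokens
print(to_ts_vector(text))  # a "token:positions" string, e.g. "maxkb:1 v1.2.0:7 ..."
print(to_query(text))      # a space-separated keyword string for the full-text query
```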
6 changes: 6 additions & 0 deletions apps/dataset/serializers/dataset_serializers.py
@@ -37,6 +37,7 @@
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, ProblemParagraphMapping
from dataset.serializers.common_serializers import list_paragraph, MetaSerializer
from dataset.serializers.document_serializers import DocumentSerializers, DocumentInstanceSerializer
from embedding.models import SearchMode
from setting.models import AuthOperate
from smartdoc.conf import PROJECT_DIR

@@ -457,6 +458,10 @@ class HitTest(ApiMixin, serializers.Serializer):
error_messages=ErrMessage.char("响应Top"))
similarity = serializers.FloatField(required=True, max_value=1, min_value=0,
error_messages=ErrMessage.char("similarity"))
search_mode = serializers.CharField(required=True, validators=[
validators.RegexValidator(regex=re.compile("^(embedding|keywords|blend)$"),
message="search_mode only supports embedding|keywords|blend", code=500)
], error_messages=ErrMessage.char("search mode"))

def is_valid(self, *, raise_exception=True):
super().is_valid(raise_exception=True)
Expand All @@ -474,6 +479,7 @@ def hit_test(self):
hit_list = vector.hit_test(self.data.get('query_text'), [self.data.get('id')], exclude_document_id_list,
self.data.get('top_number'),
self.data.get('similarity'),
SearchMode(self.data.get('search_mode')),
EmbeddingModel.get_embedding_model())
hit_dict = reduce(lambda x, y: {**x, **y}, [{hit.get('paragraph_id'): hit} for hit in hit_list], {})
p_list = list_paragraph([h.get('paragraph_id') for h in hit_list])
3 changes: 2 additions & 1 deletion apps/dataset/views/dataset.py
@@ -111,7 +111,8 @@ def get(self, request: Request, dataset_id: str):
DataSetSerializers.HitTest(data={'id': dataset_id, 'user_id': request.user.id,
"query_text": request.query_params.get("query_text"),
"top_number": request.query_params.get("top_number"),
'similarity': request.query_params.get('similarity')}).hit_test(
'similarity': request.query_params.get('similarity'),
'search_mode': request.query_params.get('search_mode')}).hit_test(
))

class Operate(APIView):
54 changes: 54 additions & 0 deletions apps/embedding/migrations/0002_embedding_search_vector.py
@@ -0,0 +1,54 @@
# Generated by Django 4.1.13 on 2024-04-16 11:43

import django.contrib.postgres.search
from django.db import migrations

from common.util.common import sub_array
from common.util.ts_vecto_util import to_ts_vector
from dataset.models import Status
from embedding.models import Embedding


def update_embedding_search_vector(embedding, paragraph_list):
# locate the paragraph this embedding row belongs to and recompute its ts_vector
paragraphs = [paragraph for paragraph in paragraph_list if paragraph.id == embedding.get('paragraph')]
if len(paragraphs) > 0:
content = paragraphs[0].title + paragraphs[0].content
return Embedding(id=embedding.get('id'), search_vector=to_ts_vector(content))
return Embedding(id=embedding.get('id'), search_vector="")


def save_keywords(apps, schema_editor):
document_model = apps.get_model("dataset", "Document")
embedding_model = apps.get_model("embedding", "Embedding")
paragraph_model = apps.get_model('dataset', 'Paragraph')
db_alias = schema_editor.connection.alias
document_list = document_model.objects.using(db_alias).all()
for document in document_list:
document.status = Status.embedding
document.save()
paragraph_list = paragraph_model.objects.using(db_alias).filter(document=document).all()
embedding_list = embedding_model.objects.using(db_alias).filter(document=document).values('id', 'search_vector',
'paragraph')
embedding_update_list = [update_embedding_search_vector(embedding, paragraph_list) for embedding
in embedding_list]
# update in chunks of 20 so one bad row does not abort the whole backfill
child_array = sub_array(embedding_update_list, 20)
for c in child_array:
try:
embedding_model.objects.using(db_alias).bulk_update(c, ['search_vector'])
except Exception as e:
print(e)


class Migration(migrations.Migration):
dependencies = [
('embedding', '0001_initial'),
]

operations = [
migrations.AddField(
model_name='embedding',
name='search_vector',
field=django.contrib.postgres.search.SearchVectorField(default='', verbose_name='word segmentation'),
),
migrations.RunPython(save_keywords)
]
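The migration only adds and backfills the column; the query side is not part of this diff view. As a hedged sketch under that assumption, the full-text leg of a `keywords`/`blend` search over the new field could look something like this in Django ORM terms (the function name and the `simple` config are assumptions, not MaxKB's actual code):

```python
from django.contrib.postgres.search import SearchQuery, SearchRank
from django.db.models import F

from embedding.models import Embedding

# hypothetical sketch -- the real query construction lives in the embedding
# vector store, which this diff view does not show
def keyword_hits(query_text: str, top_n: int):
    query = SearchQuery(query_text, search_type="websearch", config="simple")
    return (Embedding.objects
            .annotate(rank=SearchRank(F("search_vector"), query))
            .filter(search_vector=query)
            .order_by("-rank")
            .values("paragraph_id", "rank")[:top_n])
```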