from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from transformers.utils import logging

from rag.src.data_processing import Data_process
from rag.src.config.config import prompt_template

logger = logging.get_logger(__name__)


class EmoLLMRAG(object):
    """
    EmoLLM RAG pipeline:
    1. Embed the query.
    2. Retrieve matching documents from the vector DB.
    3. Rerank the retrieved results (optional).
    4. Pass the query and the retrieved content to the LLM.

    A usage sketch appears at the end of this file.
    """

    def __init__(self, model, retrieval_num=3, rerank_flag=False, select_num=3) -> None:
        """
        Initialize the pipeline with a model instance.

        data_processing_obj: Data_process object handling embedding and reranking
        vectorstores: the vector DB, loaded (or rebuilt if missing) via the embedding module
        prompt_template: the predefined template (system prompt included) that
            forms the final input to the LLM
        """
        self.model = model
        self.data_processing_obj = Data_process()
        self.vectorstores = self._load_vector_db()
        self.prompt_template = prompt_template
        self.retrieval_num = retrieval_num
        self.rerank_flag = rerank_flag
        self.select_num = select_num

    def _load_vector_db(self):
        """
        Load the vector DB through the interface exposed by the embedding module.
        """
        vectorstores = self.data_processing_obj.load_vector_db()
        return vectorstores

    def get_retrieval_content(self, query) -> list:
        """
        Input: the user query
        Output: the retrieved content, reranked when rerank_flag is set
        """
        content = []
        documents = self.vectorstores.similarity_search(query, k=self.retrieval_num)
        for doc in documents:
            content.append(doc.page_content)
        # If reranking is enabled, rerank the retrieved documents and
        # rebuild `content` from the top `select_num` results
        if self.rerank_flag:
            documents, _ = self.data_processing_obj.rerank(documents, self.select_num)
            content = []
            for doc in documents:
                content.append(doc)
        logger.info(f'Retrieval data: {content}')
        return content

    def generate_answer(self, query, content) -> str:
        """
        Input: the user query and the retrieved content
        Output: the model's generated answer
        """
        # Build the prompt template.
        # The first version carries no history, so the system prompt is
        # folded directly into the template.
        prompt = PromptTemplate(
            template=self.prompt_template,
            input_variables=["query", "content"],
        )
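        # Illustrative only: the real template lives in rag/src/config/config.py,
        # but whatever it says, it must expose the {query} and {content}
        # placeholders declared above, along the lines of:
        #   "...system prompt...\nReference: {content}\nQuestion: {query}\nAnswer:"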
        # Define the chain; parse the output to a plain string
        rag_chain = prompt | self.model | StrOutputParser()
        # Run the chain
        generation = rag_chain.invoke(
            {
                "query": query,
                "content": content,
            }
        )
        return generation

    def main(self, query) -> str:
        """
        Input: the user query
        Output: the LLM's generated answer

        Defines the end-to-end RAG pipeline and orchestrates the modules.
        TODO:
            add a RAGAS evaluation step
        """
        content = self.get_retrieval_content(query)
        response = self.generate_answer(query, content)
        return response
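

# Usage sketch (not part of the original file): a minimal way to exercise the
# chain wiring end to end. It assumes the vector DB has already been built by
# the data-processing module and that langchain_community is installed; it
# swaps in LangChain's FakeListLLM stub for a real model, so retrieval runs
# for real while the "answer" is canned.
if __name__ == "__main__":
    from langchain_community.llms import FakeListLLM

    stub_model = FakeListLLM(responses=["This is a placeholder answer."])
    rag_obj = EmoLLMRAG(stub_model, retrieval_num=3, rerank_flag=False)
    print(rag_obj.main("I have been sleeping badly lately. What should I do?"))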