-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathDocChat.py
101 lines (83 loc) · 2.91 KB
/
DocChat.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
from langchain.embeddings import SentenceTransformerEmbeddings #HuggingFaceInstructEmbeddings
from langchain.vectorstores import FAISS
import os
import copy
import pprint
#import google.generativeai as palm
from langchain.llms import GooglePalm
from langchain import PromptTemplate
from langchain.chains import RetrievalQA
import streamlit as st
import warnings
warnings.filterwarnings("ignore")
@st.cache_resource
def getapi():
    """Read the Google PaLM API key from ./API.txt (cached by Streamlit).

    Returns:
        str: the full contents of API.txt, including any trailing newline.
    """
    # Context manager guarantees the handle is closed; the original
    # `open(...).read()` leaked the file object. `read()` already returns
    # str, so the redundant str() wrapper is dropped.
    with open("API.txt", "r", encoding="utf-8") as key_file:
        return key_file.read()
# Module-level API key, loaded once at import time (getapi is cached).
PALM_API=getapi()
#palm.configure(api_key=PALM_API)
@st.cache_resource
def getmodel():
    """Build and cache the retrieval-QA chain.

    Loads the local FAISS index with MiniLM sentence embeddings, wraps it
    as a top-10 retriever, and combines it with a Google PaLM LLM in a
    'refine' RetrievalQA chain that also returns its source documents.
    """
    embedder = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
    vector_db = FAISS.load_local("faiss", embedder)
    doc_retriever = vector_db.as_retriever(search_kwargs={'k': 10})
    #prompt=getprompt()
    palm_llm = GooglePalm(
        google_api_key=PALM_API,
        temperature=0,
        max_output_tokens=512,
    )
    qa_chain = RetrievalQA.from_chain_type(
        llm=palm_llm,
        chain_type='refine',
        retriever=doc_retriever,
        return_source_documents=True,
        #chain_type_kwargs={'prompt': prompt},
        verbose=True,
    )
    return qa_chain
@st.cache_resource
def getprompt():
    """Return the cached PromptTemplate for retrieval QA.

    The template takes `context` (retrieved passages) and `question`
    (the user query) and instructs the model to answer in points, or to
    say it does not know. Currently unused (see commented-out call in
    getmodel).
    """
    qa_template = """Use the information to elaborate in points about the user's query.
If user mentions something not in the 'Context', just answer that you don't know.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
Context: {context}
Query: {question}
Only return the helpful answer below and nothing else.
Helpful answer:
"""
    return PromptTemplate(
        template=qa_template,
        input_variables=['context', 'question'],
    )
def parseresult(result):
    """Return a deep copy of a RetrievalQA result with page numbers added.

    Args:
        result: mapping with at least a 'source_documents' key holding
            objects whose `.metadata` dict contains a 'page' entry.

    Returns:
        A deep copy of `result` with an extra 'source_pages' key: the list
        of page numbers, in source-document order. The input is untouched.
    """
    parsed = copy.deepcopy(result)
    # Comprehension replaces the original append loop; the second deepcopy
    # and the `del sourcepage,result` were no-ops (they only dropped local
    # bindings to freshly-built / copied values) and are removed.
    parsed['source_pages'] = [
        doc.metadata['page'] for doc in parsed['source_documents']
    ]
    return parsed
def getsources(result):
    """Return the metadata of each source document, formatted as a string.

    Args:
        result: mapping with a 'source_documents' key of objects exposing
            a `.metadata` attribute.

    Returns:
        list[str]: one `str(metadata)` per source document, in order.
    """
    return [f"{doc.metadata}" for doc in result['source_documents']]
# ---- Streamlit app body -------------------------------------------------
st.title('Query Docs')
prompt = st.sidebar.text_input("Enter query")

# Build (or fetch from Streamlit's cache) the QA chain up front so every
# query reuses the same model/retriever.
try:
    llm = getmodel()
except Exception:
    # Narrowed from a bare `except:` so SystemExit / KeyboardInterrupt
    # still propagate instead of being swallowed.
    st.write("CANNOT LOAD MODEL OR DATABASE")
    #print("ERROR LOADING MODEL OR DATABASE")

if prompt:
    # A query beginning with "exit" terminates the app process.
    # (startswith replaces the equivalent `prompt.find("exit")==0`.)
    if prompt.startswith("exit"):
        import sys
        sys.exit()
    try:
        result = parseresult(llm(prompt))
        sources = getsources(result)
        result = result["result"]
    except Exception:
        # Narrowed from a bare `except:`. Also reached when the model
        # failed to load above (llm would be undefined -> NameError).
        result = "Error in retrieving! \n You can try reframing your query, if it doesnt work there may be something broken. \n :/ "
        sources = []
        print(">>>>>>>>>>>>><<<<<<<<<<<<<<<<<")
    st.header("Result")
    st.write(result)
    st.header("Sources")
    st.write(sources)