Mistral-7B Support Added
- Support for Mistral-7B added to localGPT.
- Bug fix in the API code: it no longer deletes the existing DB.
- Optimized the Streamlit UI.
PromtEngineer committed Oct 2, 2023
1 parent 15e9648 commit d7fef20
Showing 6 changed files with 100 additions and 53 deletions.
8 changes: 7 additions & 1 deletion constants.py
@@ -5,6 +5,8 @@

# https://python.langchain.com/en/latest/modules/indexes/document_loaders/examples/excel.html?highlight=xlsx#microsoft-excel
from langchain.document_loaders import CSVLoader, PDFMinerLoader, TextLoader, UnstructuredExcelLoader, Docx2txtLoader
from langchain.document_loaders import UnstructuredFileLoader


# load_dotenv()
ROOT_DIRECTORY = os.path.dirname(os.path.realpath(__file__))
@@ -44,7 +46,8 @@
".txt": TextLoader,
".md": TextLoader,
".py": TextLoader,
".pdf": PDFMinerLoader,
# ".pdf": PDFMinerLoader,
".pdf": UnstructuredFileLoader,
".csv": CSVLoader,
".xls": UnstructuredExcelLoader,
".xlsx": UnstructuredExcelLoader,
@@ -98,6 +101,9 @@
MODEL_ID = "TheBloke/Llama-2-7b-Chat-GGUF"
MODEL_BASENAME = "llama-2-7b-chat.Q4_K_M.gguf"

# MODEL_ID = "TheBloke/Mistral-7B-Instruct-v0.1-GGUF"
# MODEL_BASENAME = "mistral-7b-instruct-v0.1.Q8_0.gguf"

# MODEL_ID = "TheBloke/Llama-2-70b-Chat-GGUF"
# MODEL_BASENAME = "llama-2-70b-chat.Q4_K_M.gguf"

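To switch localGPT to Mistral-7B, comment out the active Llama-2-7B pair in constants.py and uncomment the Mistral pair added above. A minimal sketch of the result (the GGUF file name is the one from this diff; other quantizations from the same repo would be selected the same way):

# constants.py -- sketch: select Mistral-7B instead of Llama-2-7B
# MODEL_ID = "TheBloke/Llama-2-7b-Chat-GGUF"
# MODEL_BASENAME = "llama-2-7b-chat.Q4_K_M.gguf"
MODEL_ID = "TheBloke/Mistral-7B-Instruct-v0.1-GGUF"
MODEL_BASENAME = "mistral-7b-instruct-v0.1.Q8_0.gguf"
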
63 changes: 33 additions & 30 deletions localGPT_UI.py
@@ -11,7 +11,6 @@
from langchain.memory import ConversationBufferMemory



def model_memory():
# Adding history to the model.
template = """Use the following pieces of context to answer the question at the end. If you don't know the answer,\
@@ -28,33 +27,40 @@ def model_memory():

return prompt, memory


# Sidebar contents
with st.sidebar:
st.title('🤗💬 Converse with your Data')
st.markdown('''
st.title("🤗💬 Converse with your Data")
st.markdown(
"""
## About
This app is an LLM-powered chatbot built using:
- [Streamlit](https://streamlit.io/)
- [LangChain](https://python.langchain.com/)
- [LocalGPT](https://github.com/PromtEngineer/localGPT)
''')
"""
)
add_vertical_space(5)
st.write('Made with ❤️ by [Prompt Engineer](https://youtube.com/@engineerprompt)')

st.write("Made with ❤️ by [Prompt Engineer](https://youtube.com/@engineerprompt)")

DEVICE_TYPE = "cuda" if torch.cuda.is_available() else "cpu"

if torch.backends.mps.is_available():
DEVICE_TYPE = "mps"
elif torch.cuda.is_available():
DEVICE_TYPE = "cuda"
else:
DEVICE_TYPE = "cpu"


if "result" not in st.session_state:
# Run the document ingestion process.
run_langest_commands = ["python", "ingest.py"]
run_langest_commands.append("--device_type")
run_langest_commands.append(DEVICE_TYPE)
# if "result" not in st.session_state:
# # Run the document ingestion process.
# run_langest_commands = ["python", "ingest.py"]
# run_langest_commands.append("--device_type")
# run_langest_commands.append(DEVICE_TYPE)

result = subprocess.run(run_langest_commands, capture_output=True)
st.session_state.result = result
# result = subprocess.run(run_langest_commands, capture_output=True)
# st.session_state.result = result

# Define the retriever
# load the vectorstore
@@ -79,41 +85,38 @@ def model_memory():
st.session_state["LLM"] = LLM




if "QA" not in st.session_state:

prompt, memory = model_memory()

QA = RetrievalQA.from_chain_type(
llm=LLM,
chain_type="stuff",
retriever=RETRIEVER,
llm=LLM,
chain_type="stuff",
retriever=RETRIEVER,
return_source_documents=True,
chain_type_kwargs={"prompt": prompt, "memory": memory},
)
st.session_state["QA"] = QA

st.title('LocalGPT App 💬')
# Create a text input box for the user
prompt = st.text_input('Input your prompt here')
st.title("LocalGPT App 💬")
# Create a text input box for the user
prompt = st.text_input("Input your prompt here")
# while True:

# If the user hits enter
# If the user hits enter
if prompt:
# Then pass the prompt to the LLM
response = st.session_state["QA"](prompt)
answer, docs = response["result"], response["source_documents"]
# ...and write it out to the screen
st.write(answer)

# With a streamlit expander
with st.expander('Document Similarity Search'):
# With a streamlit expander
with st.expander("Document Similarity Search"):
# Find the relevant pages
search = st.session_state.DB.similarity_search_with_score(prompt)
search = st.session_state.DB.similarity_search_with_score(prompt)
# Write out the first
for i, doc in enumerate(search):
for i, doc in enumerate(search):
# print(doc)
st.write(f"Source Document # {i+1} : {doc[0].metadata['source'].split('/')[-1]}")
st.write(doc[0].page_content)
st.write("--------------------------------")
st.write(doc[0].page_content)
st.write("--------------------------------")
25 changes: 24 additions & 1 deletion prompt_template_utils.py
@@ -33,7 +33,30 @@ def get_prompt_template(system_prompt=system_prompt, promptTemplate_type=None, h

prompt_template = B_INST + SYSTEM_PROMPT + instruction + E_INST
prompt = PromptTemplate(input_variables=["context", "question"], template=prompt_template)

elif promptTemplate_type == "mistral":
B_INST, E_INST = "<s>[INST] ", " [/INST]"
if history:
prompt_template = (
B_INST
+ system_prompt
+ """
Context: {history} \n {context}
User: {question}"""
+ E_INST
)
prompt = PromptTemplate(input_variables=["history", "context", "question"], template=prompt_template)
else:
prompt_template = (
B_INST
+ system_prompt
+ """
Context: {context}
User: {question}"""
+ E_INST
)
prompt = PromptTemplate(input_variables=["context", "question"], template=prompt_template)
else:
# change this based on the model you have selected.
if history:
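The new "mistral" branch wraps the system prompt, context, and question in Mistral's <s>[INST] ... [/INST] format. A usage sketch, assuming the helper returns the prompt and memory objects as the calling code elsewhere in the repo suggests:

# Sketch: build the Mistral-style prompt without chat history.
from prompt_template_utils import get_prompt_template

prompt, memory = get_prompt_template(promptTemplate_type="mistral", history=False)
# prompt.template should look roughly like:
# <s>[INST] <system prompt>
# Context: {context}
# User: {question} [/INST]
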
1 change: 1 addition & 0 deletions requirements.txt
@@ -13,6 +13,7 @@ protobuf==3.20.3; sys_platform == 'darwin' and platform_machine == 'arm64'
auto-gptq==0.2.2
docx2txt
unstructured
unstructured[pdf]

# Utilities
urllib3==1.26.6
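The unstructured[pdf] extra backs the constants.py change that maps ".pdf" to UnstructuredFileLoader instead of PDFMinerLoader. A quick sketch of that loader in isolation (the file path is a placeholder):

# Sketch: load a PDF with the loader now mapped to ".pdf" in DOCUMENT_MAP.
from langchain.document_loaders import UnstructuredFileLoader

docs = UnstructuredFileLoader("SOURCE_DOCUMENTS/example.pdf").load()
print(len(docs), docs[0].page_content[:200])
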
12 changes: 10 additions & 2 deletions run_localGPT.py
Expand Up @@ -197,7 +197,15 @@ def retrieval_qa_pipline(device_type, use_history, promptTemplate_type="llama"):
is_flag=True,
help="Use history (Default is False)",
)
def main(device_type, show_sources, use_history):
@click.option(
"--model_type",
default="llama",
type=click.Choice(
["llama", "mistral", "non_llama"],
),
help="model type, llama, mistral or non_llama",
)
def main(device_type, show_sources, use_history, model_type):
"""
Implements the main information retrieval task for a localGPT.
Expand Down Expand Up @@ -226,7 +234,7 @@ def main(device_type, show_sources, use_history):
if not os.path.exists(MODELS_PATH):
os.mkdir(MODELS_PATH)

qa = retrieval_qa_pipline(device_type, use_history, promptTemplate_type="llama")
qa = retrieval_qa_pipline(device_type, use_history, promptTemplate_type=model_type)
# Interactive questions and answers
while True:
query = input("\nEnter a query: ")
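The new --model_type option is forwarded to retrieval_qa_pipline as the prompt template type, so the Mistral prompt format can be selected from the command line, e.g. python run_localGPT.py --device_type cuda --model_type mistral. Inside main() that amounts to roughly the following (the device value and leaving use_history at its default of False are assumptions):

# Sketch of what the CLI invocation above resolves to inside main().
qa = retrieval_qa_pipline(device_type="cuda", use_history=False, promptTemplate_type="mistral")
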
44 changes: 25 additions & 19 deletions run_localGPT_API.py
Expand Up @@ -18,7 +18,13 @@

from constants import CHROMA_SETTINGS, EMBEDDING_MODEL_NAME, PERSIST_DIRECTORY, MODEL_ID, MODEL_BASENAME

DEVICE_TYPE = "cuda" if torch.cuda.is_available() else "cpu"
if torch.backends.mps.is_available():
DEVICE_TYPE = "mps"
elif torch.cuda.is_available():
DEVICE_TYPE = "cuda"
else:
DEVICE_TYPE = "cpu"

SHOW_SOURCES = True
logging.info(f"Running on: {DEVICE_TYPE}")
logging.info(f"Display Source Documents set to: {SHOW_SOURCES}")
@@ -27,24 +33,24 @@

# uncomment the following line if you used HuggingFaceEmbeddings in the ingest.py
# EMBEDDINGS = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
if os.path.exists(PERSIST_DIRECTORY):
try:
shutil.rmtree(PERSIST_DIRECTORY)
except OSError as e:
print(f"Error: {e.filename} - {e.strerror}.")
else:
print("The directory does not exist")

run_langest_commands = ["python", "ingest.py"]
if DEVICE_TYPE == "cpu":
run_langest_commands.append("--device_type")
run_langest_commands.append(DEVICE_TYPE)

result = subprocess.run(run_langest_commands, capture_output=True)
if result.returncode != 0:
raise FileNotFoundError(
"No files were found inside SOURCE_DOCUMENTS, please put a starter file inside before starting the API!"
)
# if os.path.exists(PERSIST_DIRECTORY):
# try:
# shutil.rmtree(PERSIST_DIRECTORY)
# except OSError as e:
# print(f"Error: {e.filename} - {e.strerror}.")
# else:
# print("The directory does not exist")

# run_langest_commands = ["python", "ingest.py"]
# if DEVICE_TYPE == "cpu":
# run_langest_commands.append("--device_type")
# run_langest_commands.append(DEVICE_TYPE)

# result = subprocess.run(run_langest_commands, capture_output=True)
# if result.returncode != 0:
# raise FileNotFoundError(
# "No files were found inside SOURCE_DOCUMENTS, please put a starter file inside before starting the API!"
# )

# load the vectorstore
DB = Chroma(
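With the deletion and re-ingestion block commented out, the API keeps whatever vector store already exists and simply loads it on startup. A sketch of that behaviour (the keyword arguments are assumptions based on the constants imported at the top of the file):

# The API now loads the persisted Chroma store instead of wiping PERSIST_DIRECTORY.
DB = Chroma(
    persist_directory=PERSIST_DIRECTORY,
    embedding_function=EMBEDDINGS,
    client_settings=CHROMA_SETTINGS,
)
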
