forked from sudarshan-koirala/llama2-chat-with-documents
Commit 6791573 (1 parent: 2681101). Showing 6 changed files with 342 additions and 0 deletions.
@@ -0,0 +1,41 @@
# llama2-chat-with-documents
Simple Chainlit app to interact with your documents.

### Chat with your documents 🚀
- [Huggingface model](https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML/blob/main/llama-2-7b-chat.ggmlv3.q8_0.bin) as the Large Language Model
- [LangChain](https://python.langchain.com/docs/get_started/introduction.html) as the LLM framework
- [Chainlit](https://docs.chainlit.io/overview) for the chat UI

## System Requirements

You must have Python 3.9 or later installed. Earlier versions of Python may not work.

---

## Steps to Replicate

1. Fork this repository and create a codespace in GitHub as shown in the YouTube video, OR clone it locally.
```
git clone https://github.com/sudarshan-koirala/llama2-chat-with-documents.git
cd llama2-chat-with-documents
```

2. Rename `example.env` to `.env` with `cp example.env .env` and add your Hugging Face Hub API token as follows. You can create a token at this [URL](https://huggingface.co/settings/tokens). You need a Hugging Face account if you don't have one already.
```
HUGGINGFACEHUB_API_TOKEN=your_huggingface_hub_api_token
```

3. Create a virtualenv and activate it:
```
python3 -m venv .venv && source .venv/bin/activate
```

4. Run the following command in the terminal to install the necessary Python packages:
```
pip install -r requirements.txt
```

5. Run the following command in your terminal to start the chat UI (build the vector database first; see the note after these steps):
```
chainlit run main.py
```
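
Note: this diff shows the file contents without their names, so `ingest.py` (the ingestion script) and `main.py` (the Chainlit app) are assumed here. Before the first run, download the GGML model file linked above into `model/` and build the vector database, for example:

```python
# Build the Chroma vector store from the files in data/ before launching the UI.
# Assumes the ingestion script from this commit is saved as ingest.py.
from ingest import create_vector_database

create_vector_database()  # persists embeddings to the db/ directory next to the script
```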
@@ -0,0 +1,15 @@
# Welcome to Chainlit! 🚀🤖

Hi! This is a simple Chainlit app to chat with your documents.

## Useful Links 🔗

**Youtube Playlists**
Get started with the [Chainlit Playlist](https://youtube.com/playlist?list=PLz-qytj7eIWWNnbCRxflmRbYI02jZeG0k) 🎥
Get started with the [LangChain Playlist](https://youtube.com/playlist?list=PLz-qytj7eIWVd1a5SsQ1dzOjVDHdgC1Ck) 🎥

Happy coding! 💻😊

## If you want to support
- [Buy me a coffee](https://ko-fi.com/datasciencebasics) and [Patreon](https://www.patreon.com/datasciencebasics)
@@ -0,0 +1 @@
HUGGINGFACEHUB_API_TOKEN=hf_*****
@@ -0,0 +1,62 @@
import os

from langchain.document_loaders import (
    DirectoryLoader,
    PyPDFLoader,
    TextLoader,
    UnstructuredMarkdownLoader,
)
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma

ABS_PATH: str = os.path.dirname(os.path.abspath(__file__))
DB_DIR: str = os.path.join(ABS_PATH, "db")


# Create vector database
def create_vector_database():
    """
    Creates a vector database using document loaders and embeddings.

    This function loads data from PDF, markdown and text files in the 'data/' directory,
    splits the loaded documents into chunks, transforms them into embeddings using HuggingFace,
    and finally persists the embeddings into a Chroma vector database.
    """
    # Initialize loaders for different file types
    pdf_loader = DirectoryLoader("data/", glob="**/*.pdf", loader_cls=PyPDFLoader)
    markdown_loader = DirectoryLoader(
        "data/", glob="**/*.md", loader_cls=UnstructuredMarkdownLoader
    )
    text_loader = DirectoryLoader("data/", glob="**/*.txt", loader_cls=TextLoader)

    all_loaders = [pdf_loader, markdown_loader, text_loader]

    # Load documents from all loaders
    loaded_documents = []
    for loader in all_loaders:
        loaded_documents.extend(loader.load())

    # Split loaded documents into chunks
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=40)
    chunked_documents = text_splitter.split_documents(loaded_documents)

    # Initialize HuggingFace embeddings
    huggingface_embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        model_kwargs={"device": "cpu"},
    )

    # Create and persist a Chroma vector database from the chunked documents
    vector_database = Chroma.from_documents(
        documents=chunked_documents,
        embedding=huggingface_embeddings,
        persist_directory=DB_DIR,
    )

    vector_database.persist()


if __name__ == "__main__":
    create_vector_database()
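
Once the script has run, the persisted store can be reloaded and queried directly, which is a quick way to check that ingestion worked. A minimal sketch (not part of the commit), using the same embedding model and `db` directory as above:

```python
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma

# Must match the embedding model used at ingestion time.
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    model_kwargs={"device": "cpu"},
)
db = Chroma(persist_directory="db", embedding_function=embeddings)

# Retrieve the 3 most similar chunks, mirroring the k=3 retriever in the app below.
for doc in db.similarity_search("What is this document about?", k=3):
    print(doc.metadata.get("source"), "->", doc.page_content[:80])
```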
@@ -0,0 +1,212 @@
import os

import chainlit as cl
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import CTransformers
from langchain.prompts import PromptTemplate
from langchain.vectorstores import Chroma

prompt_template = """Use the following pieces of context to answer the user's question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
ALWAYS return a "SOURCES" part in your answer.
The "SOURCES" part should be a reference to the source of the document from which you got your answer.
Context: {context}
Question: {question}
Only return the helpful answer below and nothing else.
Helpful answer:
"""


def set_custom_prompt():
    """
    Prompt template for QA retrieval for each vectorstore
    """
    prompt = PromptTemplate(
        template=prompt_template, input_variables=["context", "question"]
    )
    return prompt
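
# Illustration (not part of the original commit): the RetrievalQA chain fills the two
# input variables with the retrieved chunks and the user's question, roughly:
#   set_custom_prompt().format(context="<retrieved chunks>", question="<user question>")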


def create_retrieval_qa_chain(llm, prompt, db):
    """
    Creates a Retrieval Question-Answering (QA) chain using a given language model, prompt, and database.

    This function initializes a RetrievalQA object with a specific chain type and configurations,
    and returns this QA chain. The retriever is set up to return the top 3 results (k=3).

    Args:
        llm (any): The language model to be used in the RetrievalQA.
        prompt (str): The prompt to be used in the chain type.
        db (any): The database to be used as the retriever.

    Returns:
        RetrievalQA: The initialized QA chain.
    """
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=db.as_retriever(search_kwargs={"k": 3}),
        return_source_documents=True,
        chain_type_kwargs={"prompt": prompt},
    )
    return qa_chain


def load_model(
    model_path="model/llama-2-7b-chat.ggmlv3.q8_0.bin",
    model_type="llama",
    max_new_tokens=512,
    temperature=0.5,
):
    """
    Load a locally downloaded model.

    Parameters:
        model_path (str): The path to the model to be loaded.
        model_type (str): The type of the model.
        max_new_tokens (int): The maximum number of new tokens for the model.
        temperature (float): The temperature parameter for the model.

    Returns:
        CTransformers: The loaded model.

    Raises:
        FileNotFoundError: If the model file does not exist.
    """
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"No model file found at {model_path}")

    # Additional error handling could be added here for corrupt files, etc.

    llm = CTransformers(
        model=model_path,
        model_type=model_type,
        max_new_tokens=max_new_tokens,  # type: ignore
        temperature=temperature,  # type: ignore
    )

    return llm
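
# Note (not part of the original commit): the GGML weights are not included in the repo;
# download llama-2-7b-chat.ggmlv3.q8_0.bin from the TheBloke/Llama-2-7B-Chat-GGML page
# linked in the README and place it under model/ before calling load_model().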


def create_retrieval_qa_bot(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    persist_dir="./db",
    device="cpu",
):
    """
    This function creates a retrieval-based question-answering bot.

    Parameters:
        model_name (str): The name of the model to be used for embeddings.
        persist_dir (str): The directory where the database is persisted.
        device (str): The device to run the model on (e.g., 'cpu', 'cuda').

    Returns:
        RetrievalQA: The retrieval-based question-answering bot.

    Raises:
        FileNotFoundError: If the persist directory does not exist.
        Exception: If there is an issue with loading the embeddings or the model.
    """
    if not os.path.exists(persist_dir):
        raise FileNotFoundError(f"No directory found at {persist_dir}")

    try:
        embeddings = HuggingFaceEmbeddings(
            model_name=model_name,
            model_kwargs={"device": device},
        )
    except Exception as e:
        raise Exception(
            f"Failed to load embeddings with model name {model_name}: {str(e)}"
        )

    db = Chroma(persist_directory=persist_dir, embedding_function=embeddings)

    try:
        llm = load_model()
    except Exception as e:
        raise Exception(f"Failed to load model: {str(e)}")

    qa_prompt = set_custom_prompt()

    try:
        qa = create_retrieval_qa_chain(llm=llm, prompt=qa_prompt, db=db)
    except Exception as e:
        raise Exception(f"Failed to create retrieval QA chain: {str(e)}")

    return qa


def retrieve_bot_answer(query):
    """
    Retrieves the answer to a given query using a QA bot.

    This function creates an instance of a QA bot, passes the query to it,
    and returns the bot's response.

    Args:
        query (str): The question to be answered by the QA bot.

    Returns:
        dict: The QA bot's response, typically a dictionary with response details.
    """
    qa_bot_instance = create_retrieval_qa_bot()
    bot_response = qa_bot_instance({"query": query})
    return bot_response
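
# Example usage (a sketch, not part of the original commit):
#   response = retrieve_bot_answer("What is this document about?")
#   print(response["result"])
#   print(response["source_documents"])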


@cl.on_chat_start
async def initialize_bot():
    """
    Initializes the bot when a new chat starts.

    This asynchronous function creates a new instance of the retrieval QA bot,
    sends a welcome message, and stores the bot instance in the user's session.
    """
    qa_chain = create_retrieval_qa_bot()
    welcome_message = cl.Message(content="Starting the bot...")
    await welcome_message.send()
    welcome_message.content = (
        "Hi, Welcome to Chat With Documents using Llama2 and LangChain."
    )
    await welcome_message.update()

    cl.user_session.set("chain", qa_chain)


@cl.on_message
async def process_chat_message(message):
    """
    Processes incoming chat messages.

    This asynchronous function retrieves the QA bot instance from the user's session,
    sets up a callback handler for the bot's response, and executes the bot's
    call method with the given message and callback. The bot's answer and source
    documents are then extracted from the response.
    """
    qa_chain = cl.user_session.get("chain")
    callback_handler = cl.AsyncLangchainCallbackHandler(
        stream_final_answer=True, answer_prefix_tokens=["FINAL", "ANSWER"]
    )
    callback_handler.answer_reached = True
    response = await qa_chain.acall(message, callbacks=[callback_handler])
    # The chain returns the answer under "result" and the retrieved chunks
    # under "source_documents" on the response dict itself.
    answer = response["result"]
    source_documents = response["source_documents"]

    if source_documents:
        answer += "\nSources: " + str(source_documents)
    else:
        answer += "\nNo sources found"

    await cl.Message(content=answer).send()
@@ -0,0 +1,11 @@
langchain
chainlit
python-dotenv
chromadb
torch
transformers
sentence_transformers
unstructured
pypdf
ctransformers
Markdown