llama2 chat with documents
sudarshan-koirala authored Aug 6, 2023
1 parent 2681101 commit 6791573
Showing 6 changed files with 342 additions and 0 deletions.
41 changes: 41 additions & 0 deletions README.md
@@ -0,0 +1,41 @@
# chat-with-website
A simple Chainlit app for chatting with your documents.

### Chat with your documents 🚀
- [Llama-2-7B-Chat GGML model](https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML/blob/main/llama-2-7b-chat.ggmlv3.q8_0.bin) from Hugging Face as the large language model (see the download sketch below)
- [LangChain](https://python.langchain.com/docs/get_started/introduction.html) as the framework for the LLM
- [Chainlit](https://docs.chainlit.io/overview) for the chat UI
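
`main.py` loads the model from `model/llama-2-7b-chat.ggmlv3.q8_0.bin`, so the GGML file needs to be downloaded into the `model/` folder first. A minimal sketch, assuming `huggingface_hub` is available (it is pulled in as a dependency of `transformers`):
```
from huggingface_hub import hf_hub_download

# download the quantized Llama 2 chat model into the model/ folder
hf_hub_download(
    repo_id="TheBloke/Llama-2-7B-Chat-GGML",
    filename="llama-2-7b-chat.ggmlv3.q8_0.bin",
    local_dir="model",
)
```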

## System Requirements

You must have Python 3.9 or later installed. Earlier versions of Python may not work.

---

## Steps to Replicate

1. Fork this repository and create a codespace in GitHub as I showed you in the YouTube video, or clone it locally.
```
git clone https://github.com/sudarshan-koirala/chat-with-website.git
cd chat-with-website
```

2. Rename example.env to .env with `cp example.env .env` and add your Hugging Face Hub API token as follows. You can create a token from this [URL](https://huggingface.co/settings/tokens). You need a Hugging Face account if you don't have one already.
```
HUGGINGFACEHUB_API_TOKEN=your_huggingface_api_token
```

3. Create a virtualenv and activate it
```
python3 -m venv .venv && source .venv/bin/activate
```

4. Run the following command in the terminal to install the necessary Python packages:
```
pip install -r requirements.txt
```

5. Run the following commands in your terminal to ingest your documents (placed in the `data/` folder) and start the chat UI:
```
python3 ingest.py
chainlit run main.py
```
15 changes: 15 additions & 0 deletions chainlit.md
@@ -0,0 +1,15 @@
# Welcome to Chainlit! 🚀🤖

Hi! This is a simple Chainlit app for chatting with your documents.

## Useful Links 🔗

**YouTube Playlists**
Get started with [Chainlit Playlist](https://youtube.com/playlist?list=PLz-qytj7eIWWNnbCRxflmRbYI02jZeG0k) 🎥
Get started with [LangChain Playlist](https://youtube.com/playlist?list=PLz-qytj7eIWVd1a5SsQ1dzOjVDHdgC1Ck) 🎥

Happy coding! 💻😊

## If you want to support
- [Buy me a coffee](https://ko-fi.com/datasciencebasics) and [Patreon](https://www.patreon.com/datasciencebasics)

1 change: 1 addition & 0 deletions example.env
@@ -0,0 +1 @@
HUGGINGFACEHUB_API_TOKEN=hf_*****
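
`python-dotenv` is listed in requirements.txt for loading this file. A minimal sketch (not part of this commit) of reading the token from `.env`:
```
import os

from dotenv import load_dotenv

load_dotenv()  # reads the .env file in the current working directory
hf_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
print("Token loaded:", hf_token is not None)
```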
62 changes: 62 additions & 0 deletions ingest.py
@@ -0,0 +1,62 @@
import os

from langchain.document_loaders import (
    DirectoryLoader,
    PyPDFLoader,
    TextLoader,
    UnstructuredMarkdownLoader,
)
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma

ABS_PATH: str = os.path.dirname(os.path.abspath(__file__))
DB_DIR: str = os.path.join(ABS_PATH, "db")


# Create vector database
def create_vector_database():
"""
Creates a vector database using document loaders and embeddings.
This function loads data from PDF, markdown and text files in the 'data/' directory,
splits the loaded documents into chunks, transforms them into embeddings using HuggingFace,
and finally persists the embeddings into a Chroma vector database.
"""
# Initialize loaders for different file types
pdf_loader = DirectoryLoader("data/", glob="**/*.pdf", loader_cls=PyPDFLoader)
markdown_loader = DirectoryLoader(
"data/", glob="**/*.md", loader_cls=UnstructuredMarkdownLoader
)
text_loader = DirectoryLoader("data/", glob="**/*.txt", loader_cls=TextLoader)

all_loaders = [pdf_loader, markdown_loader, text_loader]

# Load documents from all loaders
loaded_documents = []
for loader in all_loaders:
loaded_documents.extend(loader.load())

# Split loaded documents into chunks
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=40)
chunked_documents = text_splitter.split_documents(loaded_documents)

# Initialize HuggingFace embeddings
huggingface_embeddings = HuggingFaceEmbeddings(
model_name="sentence-transformers/all-MiniLM-L6-v2",
model_kwargs={"device": "cpu"},
)

# Create and persist a Chroma vector database from the chunked documents
vector_database = Chroma.from_documents(
documents=chunked_documents,
embedding=huggingface_embeddings,
persist_directory=DB_DIR,
)

vector_database.persist()


if __name__ == "__main__":
create_vector_database()
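
Once `ingest.py` has been run, the persisted database can be sanity-checked with a short script. An illustrative sketch (not part of the commit) using the same LangChain APIs; the query string is made up:
```
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma

# re-open the Chroma database persisted by ingest.py
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    model_kwargs={"device": "cpu"},
)
db = Chroma(persist_directory="db", embedding_function=embeddings)

# print the top-3 chunks most similar to a sample question
for doc in db.similarity_search("What is this document about?", k=3):
    print(doc.metadata.get("source"), "->", doc.page_content[:80])
```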
212 changes: 212 additions & 0 deletions main.py
@@ -0,0 +1,212 @@
import os

import chainlit as cl
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import CTransformers
from langchain.prompts import PromptTemplate
from langchain.vectorstores import Chroma

prompt_template = """Use the following pieces of context to answer the users question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
ALWAYS return a "SOURCES" part in your answer.
The "SOURCES" part should be a reference to the source of the document from which you got your answer.
xample of your response should be:
Context: {context}
Question: {question}
Only return the helpful answer below and nothing else.
Helpful answer:
"""


def set_custom_prompt():
"""
Prompt template for QA retrieval for each vectorstore
"""
prompt = PromptTemplate(
template=prompt_template, input_variables=["context", "question"]
)
return prompt
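
For illustration only (not in the original file), the template can be rendered directly to see what the chain will send to the model; the context and question here are made up:
```
# preview the rendered prompt with dummy values
prompt = set_custom_prompt()
print(prompt.format(context="Llama 2 is a family of open LLMs.", question="What is Llama 2?"))
```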


def create_retrieval_qa_chain(llm, prompt, db):
"""
Creates a Retrieval Question-Answering (QA) chain using a given language model, prompt, and database.
This function initializes a RetrievalQA object with a specific chain type and configurations,
and returns this QA chain. The retriever is set up to return the top 3 results (k=3).
Args:
llm (any): The language model to be used in the RetrievalQA.
prompt (str): The prompt to be used in the chain type.
db (any): The database to be used as the retriever.
Returns:
RetrievalQA: The initialized QA chain.
"""
qa_chain = RetrievalQA.from_chain_type(
llm=llm,
chain_type="stuff",
retriever=db.as_retriever(search_kwargs={"k": 3}),
return_source_documents=True,
chain_type_kwargs={"prompt": prompt},
)
return qa_chain


def load_model(
model_path="model/llama-2-7b-chat.ggmlv3.q8_0.bin",
model_type="llama",
max_new_tokens=512,
temperature=0.5,
):
"""
Load a locally downloaded model.
Parameters:
model_path (str): The path to the model to be loaded.
model_type (str): The type of the model.
max_new_tokens (int): The maximum number of new tokens for the model.
temperature (float): The temperature parameter for the model.
Returns:
CTransformers: The loaded model.
Raises:
FileNotFoundError: If the model file does not exist.
SomeOtherException: If the model file is corrupt.
"""
if not os.path.exists(model_path):
raise FileNotFoundError(f"No model file found at {model_path}")

# Additional error handling could be added here for corrupt files, etc.

llm = CTransformers(
model=model_path,
model_type=model_type,
max_new_tokens=max_new_tokens, # type: ignore
temperature=temperature, # type: ignore
)

return llm
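
A quick smoke test of the local model, as a sketch: it assumes the GGML file is already present under `model/` and uses a made-up prompt.
```
# call the CTransformers LLM directly, outside the QA chain
llm = load_model()
print(llm("Q: What is retrieval-augmented generation? A:"))
```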


def create_retrieval_qa_bot(
model_name="sentence-transformers/all-MiniLM-L6-v2",
persist_dir="./db",
device="cpu",
):
"""
This function creates a retrieval-based question-answering bot.
Parameters:
model_name (str): The name of the model to be used for embeddings.
persist_dir (str): The directory to persist the database.
device (str): The device to run the model on (e.g., 'cpu', 'cuda').
Returns:
RetrievalQA: The retrieval-based question-answering bot.
Raises:
FileNotFoundError: If the persist directory does not exist.
SomeOtherException: If there is an issue with loading the embeddings or the model.
"""

if not os.path.exists(persist_dir):
raise FileNotFoundError(f"No directory found at {persist_dir}")

try:
embeddings = HuggingFaceEmbeddings(
model_name=model_name,
model_kwargs={"device": device},
)
except Exception as e:
raise Exception(
f"Failed to load embeddings with model name {model_name}: {str(e)}"
)

db = Chroma(persist_directory=persist_dir, embedding_function=embeddings)

try:
llm = load_model() # Assuming this function exists and works as expected
except Exception as e:
raise Exception(f"Failed to load model: {str(e)}")

qa_prompt = (
set_custom_prompt()
) # Assuming this function exists and works as expected

try:
qa = create_retrieval_qa_chain(
llm=llm, prompt=qa_prompt, db=db
) # Assuming this function exists and works as expected
except Exception as e:
raise Exception(f"Failed to create retrieval QA chain: {str(e)}")

return qa


def retrieve_bot_answer(query):
"""
Retrieves the answer to a given query using a QA bot.
This function creates an instance of a QA bot, passes the query to it,
and returns the bot's response.
Args:
query (str): The question to be answered by the QA bot.
Returns:
dict: The QA bot's response, typically a dictionary with response details.
"""
qa_bot_instance = create_retrieval_qa_bot()
bot_response = qa_bot_instance({"query": query})
return bot_response
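
Outside the Chainlit UI, the same pipeline can be exercised from a plain script; the question below is made up:
```
# ad-hoc query against the ingested documents
response = retrieve_bot_answer("What are the main topics covered in the documents?")
print(response["result"])            # the model's answer
print(response["source_documents"])  # chunks returned by the retriever
```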


@cl.on_chat_start
async def initialize_bot():
"""
Initializes the bot when a new chat starts.
This asynchronous function creates a new instance of the retrieval QA bot,
sends a welcome message, and stores the bot instance in the user's session.
"""
qa_chain = create_retrieval_qa_bot()
welcome_message = cl.Message(content="Starting the bot...")
await welcome_message.send()
welcome_message.content = (
"Hi, Welcome to Chat With Documents using Llama2 and LangChain."
)
await welcome_message.update()

cl.user_session.set("chain", qa_chain)


@cl.on_message
async def process_chat_message(message):
    """
    Processes incoming chat messages.
    This asynchronous function retrieves the QA bot instance from the user's session,
    sets up a callback handler for the bot's response, and executes the bot's
    call method with the given message and callback. The bot's answer and source
    documents are then extracted from the response.
    """
    qa_chain = cl.user_session.get("chain")
    callback_handler = cl.AsyncLangchainCallbackHandler(
        stream_final_answer=True, answer_prefix_tokens=["FINAL", "ANSWER"]
    )
    callback_handler.answer_reached = True
    response = await qa_chain.acall(message, callbacks=[callback_handler])
    answer = response["result"]
    source_documents = response["source_documents"]

    if source_documents:
        answer += "\nSources: " + str(source_documents)
    else:
        answer += "\nNo sources found"

    await cl.Message(content=answer).send()
11 changes: 11 additions & 0 deletions requirements.txt
@@ -0,0 +1,11 @@
langchain
chainlit
python-dotenv
chromadb
torch
transformers
sentence_transformers
unstructured
pypdf
ctransformers
Markdown
