Added cpu support
Added support for CPU. The --device_type flag can now be used to run both ingest.py and run_localGPT.py on the CPU.
PromtEngineer committed May 29, 2023
1 parent 0b19928 commit 5a7375b
Showing 5 changed files with 32 additions and 14 deletions.
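For reference, a minimal sketch of the CLI pattern this commit introduces (the flag name, its 'gpu' default, and the help text come from the diffs below; the invocation lines are illustrative and assume click is installed):

# Sketch of the new flag; invocations are illustrative:
#   python ingest.py --device_type cpu
#   python run_localGPT.py --device_type cpu
import click

@click.command()
@click.option('--device_type', default='gpu', help='device to run on, select gpu or cpu')
def main(device_type):
    # The commit treats 'cpu'/'CPU' as CPU and everything else as CUDA.
    device = 'cpu' if device_type in ['cpu', 'CPU'] else 'cuda'
    click.echo(f"Running on: {device}")

if __name__ == '__main__':
    main()

Omitting the flag keeps the previous behaviour, since the default 'gpu' maps to 'cuda'.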
Binary file added __pycache__/constants.cpython-310.pyc
6 changes: 4 additions & 2 deletions constants.py
@@ -3,10 +3,12 @@
 from chromadb.config import Settings
 
 # load_dotenv()
-PERSIST_DIRECTORY = os.path.dirname(os.path.realpath(__file__))
+ROOT_DIRECTORY = os.path.dirname(os.path.realpath(__file__))
 
 # Define the folder for storing database
-SOURCE_DIRECTORY = f"{PERSIST_DIRECTORY}/SOURCE_DOCUMENTS"
+SOURCE_DIRECTORY = f"{ROOT_DIRECTORY}/SOURCE_DOCUMENTS"
+
+PERSIST_DIRECTORY = f"{ROOT_DIRECTORY}/DB"
 
 # Define the Chroma settings
 CHROMA_SETTINGS = Settings(
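Net effect of this hunk: the repository root keeps its old value under the new name ROOT_DIRECTORY, and the Chroma database now persists to a dedicated DB subfolder instead of the repository root itself. A sketch of the resulting layout (paths illustrative):

# <repo_root>/SOURCE_DOCUMENTS  -> documents to ingest
# <repo_root>/DB                -> Chroma persistence directory (new location)
import os
ROOT_DIRECTORY = os.path.dirname(os.path.realpath(__file__))
SOURCE_DIRECTORY = f"{ROOT_DIRECTORY}/SOURCE_DOCUMENTS"
PERSIST_DIRECTORY = f"{ROOT_DIRECTORY}/DB"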
13 changes: 10 additions & 3 deletions ingest.py
@@ -1,4 +1,5 @@
 import os
+import click
 from typing import List
 
 from langchain.document_loaders import TextLoader, PDFMinerLoader, CSVLoader
@@ -22,12 +23,18 @@ def load_single_document(file_path: str) -> Document:
 
 def load_documents(source_dir: str) -> List[Document]:
     # Loads all documents from source documents directory
     # docs_path = f"/privateGPT/{source_dir}" # replace this with the absolute path of the source_documents folder
     all_files = os.listdir(source_dir)
     return [load_single_document(f"{source_dir}/{file_path}") for file_path in all_files if file_path[-4:] in ['.txt', '.pdf', '.csv'] ]
 
 
-def main():
+@click.command()
+@click.option('--device_type', default='gpu', help='device to run on, select gpu or cpu')
+def main(device_type, ):
     # load the instructorEmbeddings
+    if device_type in ['cpu', 'CPU']:
+        device='cpu'
+    else:
+        device='cuda'
 
     # Load documents and split in chunks
     print(f"Loading documents from {SOURCE_DIRECTORY}")
@@ -39,7 +46,7 @@ def main():
 
     # Create embeddings
     embeddings = HuggingFaceInstructEmbeddings(model_name="hkunlp/instructor-xl",
-                                               model_kwargs={"device": "cuda"})
+                                               model_kwargs={"device": device})
 
     db = Chroma.from_documents(texts, embeddings, persist_directory=PERSIST_DIRECTORY, client_settings=CHROMA_SETTINGS)
     db.persist()
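One caveat with the mapping above: --device_type gpu still resolves to 'cuda' even on machines without a CUDA device. A more defensive variant (illustrative only, not part of this commit) could fall back automatically via torch:

# Illustrative fallback; assumes torch is installed (the embedding stack already requires it).
import torch

def resolve_device(device_type: str) -> str:
    # Honor an explicit CPU request; otherwise use CUDA only when it is actually available.
    if device_type.lower() == 'cpu' or not torch.cuda.is_available():
        return 'cpu'
    return 'cuda'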
3 changes: 2 additions & 1 deletion requirements.txt
@@ -10,4 +10,5 @@ huggingface_hub
 transformers
 protobuf==3.20.0
 accelerate
-bitsandbytes
+bitsandbytes
+click
24 changes: 16 additions & 8 deletions run_localGPT.py
@@ -1,11 +1,11 @@
 from langchain.chains import RetrievalQA
-from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+# from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 from langchain.vectorstores import Chroma
 from langchain.embeddings import HuggingFaceInstructEmbeddings
 from langchain.llms import HuggingFacePipeline
-from constants import CHROMA_SETTINGS, SOURCE_DIRECTORY, PERSIST_DIRECTORY
-from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig, pipeline
-
+from constants import CHROMA_SETTINGS, PERSIST_DIRECTORY
+from transformers import LlamaTokenizer, LlamaForCausalLM, pipeline
+import click
 
 from constants import CHROMA_SETTINGS
 
@@ -39,16 +39,24 @@ def load_model():
 
     return local_llm
 
 
-def main():
+@click.command()
+@click.option('--device_type', default='gpu', help='device to run on, select gpu or cpu')
+def main(device_type, ):
     # load the instructorEmbeddings
+    if device_type in ['cpu', 'CPU']:
+        device='cpu'
+    else:
+        device='cuda'
+
+    print(f"Running on: {device}")
+
     embeddings = HuggingFaceInstructEmbeddings(model_name="hkunlp/instructor-xl",
-                                               model_kwargs={"device": "cuda"})
+                                               model_kwargs={"device": device})
     # load the vectorstore
     db = Chroma(persist_directory=PERSIST_DIRECTORY, embedding_function=embeddings, client_settings=CHROMA_SETTINGS)
     retriever = db.as_retriever()
     # Prepare the LLM
-    callbacks = [StreamingStdOutCallbackHandler()]
+    # callbacks = [StreamingStdOutCallbackHandler()]
     # load the LLM for generating Natural Language responses.
     llm = load_model()
     qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents=True)
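The hunk ends before main() actually uses the chain, but since qa is built with return_source_documents=True, a hypothetical continuation (a sketch; the query string is made up) would look like:

# Sketch: querying the RetrievalQA chain built above.
query = "What does the --device_type flag do?"
res = qa(query)                        # langchain chains of this era are callable
answer = res['result']
docs = res['source_documents']         # available because return_source_documents=True
print(answer)
for doc in docs:
    print(doc.metadata.get('source'))  # each Document carries its source path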
