Skip to content

Commit

Permalink
refactor: Refactor document loading and add device selection
Browse files: browse the repository at this point in the history.
Signed-off-by: teleprint-me <[email protected]>
  • Loading branch information
teleprint-me committed Jun 4, 2023
1 parent 672e34f commit a4c0fee
Showing 1 changed file with 38 additions and 39 deletions.
77 changes: 38 additions & 39 deletions ingest.py
Original file line number Diff line number Diff line change
@@ -1,55 +1,46 @@
import os
import click
from typing import List

from langchain.document_loaders import TextLoader, PDFMinerLoader, CSVLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
import click
from langchain.docstore.document import Document
from constants import CHROMA_SETTINGS, SOURCE_DIRECTORY, PERSIST_DIRECTORY
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma

from constants import CHROMA_SETTINGS, DOCUMENT_MAP, PERSIST_DIRECTORY, SOURCE_DIRECTORY


def load_single_document(file_path: str) -> Document:
    """Load a single document from a file path.

    The loader class is looked up in ``DOCUMENT_MAP`` using the file's
    extension exactly as returned by ``os.path.splitext`` (leading dot
    included, case preserved).

    Args:
        file_path: Path of the file to load.

    Returns:
        The first document produced by the selected loader.

    Raises:
        ValueError: If ``DOCUMENT_MAP`` has no loader for the extension.
    """
    file_extension = os.path.splitext(file_path)[1]
    loader_class = DOCUMENT_MAP.get(file_extension)
    # Guard clause: fail early when the extension has no registered loader.
    if not loader_class:
        raise ValueError("Document type is undefined")
    loader = loader_class(file_path)
    # Only the first element of the loader's result is returned.
    return loader.load()[0]



def load_documents(source_dir: str) -> List[Document]:
    """Load all supported documents from the source documents directory.

    Only the direct entries of ``source_dir`` are considered. An entry is
    loaded when its extension (as returned by ``os.path.splitext``) is a
    key of ``DOCUMENT_MAP``; every other entry is ignored.

    Args:
        source_dir: Directory whose entries should be loaded.

    Returns:
        One document per supported entry, in ``os.listdir`` order.
    """
    all_files = os.listdir(source_dir)
    return [
        load_single_document(os.path.join(source_dir, file_path))
        for file_path in all_files
        # Membership is tested on the mapping itself; ``.keys()`` is redundant.
        if os.path.splitext(file_path)[1] in DOCUMENT_MAP
    ]


@click.command()
@click.option('--device_type', default='cuda', help='device to run on, select gpu, cpu or mps')
def main(device_type, ):
# load the instructorEmbeddings
if device_type in ['cpu', 'CPU']:
device='cpu'
elif device_type in ['mps', 'MPS']:
device='mps'
else:
device='cuda'


# Load documents and split in chunks
@click.option(
"--device_type",
default="cuda",
type=click.Choice(["cpu", "cuda", "ipu", "xpu", "mkldnn", "opengl", "opencl", "ideep", "hip", "ve", "fpga", "ort", "xla", "lazy", "vulkan", "mps", "meta", "hpu", "mtia"]),
help="Device to run on. (Default is cuda)",
)
def main(device_type):
# Load documents and split in chunks
print(f"Loading documents from {SOURCE_DIRECTORY}")
documents = load_documents(SOURCE_DIRECTORY)
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
Expand All @@ -58,13 +49,21 @@ def main(device_type, ):
print(f"Split into {len(texts)} chunks of text")

# Create embeddings
embeddings = HuggingFaceInstructEmbeddings(model_name="hkunlp/instructor-xl",
model_kwargs={"device": device})

db = Chroma.from_documents(texts, embeddings, persist_directory=PERSIST_DIRECTORY, client_settings=CHROMA_SETTINGS)
embeddings = HuggingFaceInstructEmbeddings(
model_name="hkunlp/instructor-xl",
model_kwargs={"device": device_type},
)

db = Chroma.from_documents(
texts,
embeddings,
persist_directory=PERSIST_DIRECTORY,
client_settings=CHROMA_SETTINGS,
)
db.persist()
db = None


# Script entry point: run the ingestion command exactly once.
if __name__ == "__main__":
    main()

0 comments on commit a4c0fee

Please sign in to comment.