Commit

added ability to chunk by token instead of char
snova-imranr committed Feb 29, 2024
1 parent f23d810 commit 3c46ae5
Showing 1 changed file with 32 additions and 5 deletions.
37 changes: 32 additions & 5 deletions vectordb/vector_db.py
@@ -20,7 +20,7 @@
 
 from langchain_community.document_loaders import DirectoryLoader
 from langchain_community.embeddings import HuggingFaceInstructEmbeddings
-from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain.text_splitter import RecursiveCharacterTextSplitter, CharacterTextSplitter
 from langchain_community.vectorstores import FAISS, Chroma, Qdrant
 
 EMBEDDING_MODEL = "hkunlp/instructor-large"
@@ -68,8 +68,8 @@ def get_text_chunks(self, docs: list, chunk_size: int, chunk_overlap: int, meta_
         Args:
             docs (list): list of documents or texts. If no metadata is passed, this parameter is a list of documents.
                 If metadata is passed, this parameter is a list of texts.
-            chunk_size (int): chunk size in number of tokens
-            chunk_overlap (int): chunk overlap in number of tokens
+            chunk_size (int): chunk size in number of characters
+            chunk_overlap (int): chunk overlap in number of characters
             metadata (list, optional): list of metadata in dictionary format. Defaults to None.
 
         Returns:
@@ -91,6 +91,30 @@ def get_text_chunks(self, docs: list, chunk_size: int, chunk_overlap: int, meta_
 
         return chunks
 
+    def get_token_chunks(self, docs: list, chunk_size: int, chunk_overlap: int, tokenizer) -> list:
+        """Gets token chunks, splitting documents by token count rather than character count.
+
+        Args:
+            docs (list): list of documents
+            chunk_size (int): chunk size in number of tokens
+            chunk_overlap (int): chunk overlap in number of tokens
+            tokenizer: Hugging Face tokenizer used to measure chunk lengths
+
+        Returns:
+            list: list of documents
+        """
+
+        text_splitter = CharacterTextSplitter.from_huggingface_tokenizer(
+            tokenizer, chunk_size=chunk_size, chunk_overlap=chunk_overlap
+        )
+
+        logger.info("Splitter: splitting documents")
+        chunks = text_splitter.split_documents(docs)
+
+        logger.info(f"Total {len(chunks)} chunks created")
+
+        return chunks
+
     def load_embedding_model(self) -> HuggingFaceInstructEmbeddings:
         """Loads embedding model
@@ -203,11 +227,14 @@ def update_vdb(self, chunks: list, embeddings: HuggingFaceInstructEmbeddings, db
 
         return vector_store
 
-    def create_vdb(self, input_path, chunk_size, chunk_overlap, db_type, output_db=None, recursive=False):
+    def create_vdb(self, input_path, chunk_size, chunk_overlap, db_type, output_db=None, recursive=False, tokenizer=None):
 
         docs = self.load_files(input_path, recursive=recursive)
 
-        chunks = self.get_text_chunks(docs, chunk_size, chunk_overlap)
+        if tokenizer is None:
+            chunks = self.get_text_chunks(docs, chunk_size, chunk_overlap)
+        else:
+            chunks = self.get_token_chunks(docs, chunk_size, chunk_overlap, tokenizer)
 
         embeddings = self.load_embedding_model()
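For context, a minimal usage sketch (not part of this commit) of how the new tokenizer parameter switches create_vdb from character-based to token-based chunking. The class name VectorDb, its no-argument constructor, the input path, the db_type value, and the chunk sizes are assumptions for illustration; only the method signatures come from the diff above.

# Usage sketch, not part of this commit. VectorDb is a hypothetical name for
# the class in vectordb/vector_db.py that holds the methods shown above.
from transformers import AutoTokenizer

from vectordb.vector_db import VectorDb  # hypothetical class name

vdb = VectorDb()

# Previous behavior: no tokenizer, so chunk_size/chunk_overlap count characters.
vdb.create_vdb("./data", chunk_size=1000, chunk_overlap=200, db_type="faiss")

# New behavior: passing a Hugging Face tokenizer routes chunking through
# get_token_chunks, so chunk_size/chunk_overlap count tokens instead.
tokenizer = AutoTokenizer.from_pretrained("hkunlp/instructor-large")
vdb.create_vdb("./data", chunk_size=256, chunk_overlap=32, db_type="faiss", tokenizer=tokenizer)

One design note: CharacterTextSplitter.from_huggingface_tokenizer keeps the splitter's default "\n\n" separator and only swaps in a token-counting length function, so chunks still break on paragraph boundaries; RecursiveCharacterTextSplitter.from_huggingface_tokenizer would additionally fall back to finer separators when a single paragraph exceeds the token budget.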
