Skip to content

Commit

Permalink
Merge pull request PromtEngineer#702 from BBC-Esq/update-langchain-em…
Browse files Browse the repository at this point in the history
…bedding-classes

automatic correct langchain library
  • Loading branch information
PromtEngineer authored Feb 3, 2024
2 parents 040c69a + 747a9b4 commit 8450efc
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 16 deletions.
43 changes: 30 additions & 13 deletions ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,26 +153,43 @@ def main(device_type):
logging.info(f"Loaded {len(documents)} documents from {SOURCE_DIRECTORY}")
logging.info(f"Split into {len(texts)} chunks of text")

# Create embeddings
embeddings = HuggingFaceInstructEmbeddings(
model_name=EMBEDDING_MODEL_NAME,
model_kwargs={"device": device_type},
)
# change the embedding type here if you are running into issues.
# These are much smaller embeddings and will work for most appications
# If you use HuggingFaceEmbeddings, make sure to also use the same in the
# run_localGPT.py file.

# embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
"""
(1) Chooses an appropriate langchain library based on the enbedding model name. Matching code is contained within fun_localGPT.py.
(2) Provides additional arguments for instructor and BGE models to improve results, pursuant to the instructions contained on
their respective huggingface repository, project page or github repository.
"""

if "instructor" in EMBEDDING_MODEL_NAME:
return HuggingFaceInstructEmbeddings(
model_name=EMBEDDING_MODEL_NAME,
model_kwargs={"device": compute_device},
embed_instruction='Represent the document for retrieval:',
query_instruction='Represent the question for retrieving supporting documents:'
)

elif "bge" in EMBEDDING_MODEL_NAME:
query_instruction = 'Represent this sentence for searching relevant passages:'

return HuggingFaceBgeEmbeddings(
model_name=EMBEDDING_MODEL_NAME,
model_kwargs={"device": compute_device},
query_instruction='Represent this sentence for searching relevant passages:'
)

else:

return HuggingFaceEmbeddings(
model_name=EMBEDDING_MODEL_NAME,
model_kwargs={"device": compute_device},
)

db = Chroma.from_documents(
texts,
embeddings,
persist_directory=PERSIST_DIRECTORY,
client_settings=CHROMA_SETTINGS,
)



if __name__ == "__main__":
logging.basicConfig(
Expand Down
30 changes: 27 additions & 3 deletions run_localGPT.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,33 @@ def retrieval_qa_pipline(device_type, use_history, promptTemplate_type="llama"):
- The QA system retrieves relevant documents using the retriever and then answers questions based on those documents.
"""

embeddings = HuggingFaceInstructEmbeddings(model_name=EMBEDDING_MODEL_NAME, model_kwargs={"device": device_type})
# uncomment the following line if you used HuggingFaceEmbeddings in the ingest.py
# embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
"""
(1) Chooses an appropriate langchain library based on the enbedding model name. Matching code is contained within ingest.py.
(2) Provides additional arguments for instructor and BGE models to improve results, pursuant to the instructions contained on
their respective huggingface repository, project page or github repository.
"""

if "instructor" in EMBEDDING_MODEL_NAME:
return HuggingFaceInstructEmbeddings(
model_name=EMBEDDING_MODEL_NAME,
model_kwargs={"device": compute_device},
embed_instruction='Represent the document for retrieval:',
query_instruction='Represent the question for retrieving supporting documents:'
)

elif "bge" in EMBEDDING_MODEL_NAME:
return HuggingFaceBgeEmbeddings(
model_name=EMBEDDING_MODEL_NAME,
model_kwargs={"device": compute_device},
query_instruction='Represent this sentence for searching relevant passages:'
)

else:
return HuggingFaceEmbeddings(
model_name=EMBEDDING_MODEL_NAME,
model_kwargs={"device": compute_device},
)

# load the vectorstore
db = Chroma(
Expand Down

0 comments on commit 8450efc

Please sign in to comment.