Skip to content

Commit

Permalink
Merge pull request sambanova#106 from snova-jorgep/snova-jorge/enterp…
Browse files Browse the repository at this point in the history
…rise-knowledge-retriever

feat: sambanova's langchain integration usage in entreprise knowledge retriever
  • Loading branch information
snova-amitk authored May 1, 2024
2 parents e563f39 + 51f1ca1 commit 88cd851
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 7 deletions.
2 changes: 2 additions & 0 deletions enterprise_knowledge_retriever/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
streamlit==1.25.0
pydantic==2.7.0
pydantic_core==2.18.1
langchain-community==0.0.35
langchain-core==0.1.47
langchain==0.1.16
sentence_transformers==2.2.2
instructorembedding==1.0.1
Expand Down
12 changes: 6 additions & 6 deletions enterprise_knowledge_retriever/src/document_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
from PyPDF2 import PdfReader
from vectordb.vector_db import VectorDb
from data_extraction.src.multi_column import column_boxes
from utils.sambanova_endpoint import SambaNovaEndpoint, SambaverseEndpoint
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate, load_prompt
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import UnstructuredPDFLoader, TextLoader
from langchain_community.llms.sambanova import SambaStudio, Sambaverse

current_dir = os.path.dirname(os.path.abspath(__file__))
kit_dir = os.path.abspath(os.path.join(current_dir, ".."))
Expand Down Expand Up @@ -42,7 +42,7 @@ def get_config_info(self):
with open(CONFIG_PATH, 'r') as yaml_file:
config = yaml.safe_load(yaml_file)
api_info = config["api"]
llm_info = config["llm"]
_info = config["llm"]
embedding_model_info = config["embedding_model"]
retrieval_info = config["retrieval"]
loaders = config["loaders"]
Expand Down Expand Up @@ -218,7 +218,7 @@ def get_qa_retrieval_chain(self, vectorstore):
"""
Generate a qa_retrieval chain using a language model.
This function uses a language model, specifically a SambaNovaEndpoint, to generate a qa_retrieval chain
This function uses a language model, specifically a SambaNova LLM, to generate a qa_retrieval chain
based on the input vector store of text chunks.
Parameters:
Expand All @@ -230,7 +230,7 @@ def get_qa_retrieval_chain(self, vectorstore):
"""

if self.api_info == "sambaverse":
llm = SambaverseEndpoint(
llm = Sambaverse(
sambaverse_model_name=self.llm_info["sambaverse_model_name"],
sambaverse_api_key=os.getenv("SAMBAVERSE_API_KEY"),
model_kwargs={
Expand All @@ -247,7 +247,7 @@ def get_qa_retrieval_chain(self, vectorstore):
)

elif self.api_info == "sambastudio":
llm = SambaNovaEndpoint(
llm = SambaStudio(
model_kwargs={
"do_sample": False,
"temperature": self.llm_info["temperature"],
Expand Down Expand Up @@ -282,7 +282,7 @@ def get_conversational_qa_retrieval_chain(self, vectorstore):
"""
Generate a conversational retrieval qa chain using a language model.
This function uses a language model, specifically a SambaNovaEndpoint, to generate a conversational_qa_retrieval chain
This function uses a language model, specifically a SambaNova LLM, to generate a conversational_qa_retrieval chain
based on the chat history and the relevant retrieved content from the input vector store of text chunks.
Parameters:
Expand Down
2 changes: 1 addition & 1 deletion utils/sambanova_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from langchain_core.pydantic_v1 import Extra, root_validator
from langchain_core.utils import get_from_dict_or_env
from langchain.callbacks.base import BaseCallbackHandler # type: ignore
from langchain.embeddings.base import Embeddings
from langchain_core.embeddings import Embeddings


class SVEndpointHandler:
Expand Down

0 comments on commit 88cd851

Please sign in to comment.