Added Google Vertex AI
ashwinzyx committed Jul 31, 2024
1 parent 9d6d231 commit 6e39fce
Showing 5 changed files with 15 additions and 3 deletions.
requirements.txt (3 changes: 2 additions, 1 deletion)
@@ -69,4 +69,5 @@ scikit-optimize
 pinecone-client
 pystemmer
 langchain_groq
-langchain-google-genai
+langchain-google-genai
+langchain-google-vertexai
setup.py (3 changes: 2 additions, 1 deletion)
@@ -100,7 +100,8 @@
         'pinecone-client',
         'setuptools',
         'langchain_groq',
-        'langchain-google-genai'
+        'langchain-google-genai',
+        'langchain-google-vertexai'
         # other dependencies
     ],
 )
src/ragbuilder/executor.py (3 changes: 2 additions, 1 deletion)
@@ -68,6 +68,7 @@
 from langchain_groq import ChatGroq
 from langchain_openai import AzureOpenAIEmbeddings, AzureChatOpenAI
 from langchain_google_genai import ChatGoogleGenerativeAI,GoogleGenerativeAIEmbeddings
+from langchain_google_vertexai import ChatVertexAI, VertexAIEmbeddings
 
 # import local modules
 from ragbuilder.langchain_module.retriever.retriever import *
@@ -287,8 +288,8 @@ def __init__(self, val):
         logger.info("Creating RAG object from generated code...(this may take a while in some cases)")
         try:
             #execution os string
-            # logger.info(f"Generated Code\n{self.router}")
             exec(self.router,globals_dict,locals_dict)
+            logger.debug(f"Generated Code\n{self.router}")
 
             #old rag func hooked to eval
             self.rag = locals_dict['rag_pipeline']()
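For context, the hunk above executes a generated source string and then looks up the rag_pipeline factory it defines; the added logger.debug call records that generated source instead of the previously commented-out info log. A minimal sketch of the exec pattern, using a placeholder code string rather than ragbuilder's real generated pipeline:

    # Sketch only: a stand-in for the generated pipeline source held in self.router.
    generated_code = "def rag_pipeline():\n    return 'rag object placeholder'\n"

    globals_dict, locals_dict = {}, {}
    exec(generated_code, globals_dict, locals_dict)  # definitions land in locals_dict
    rag = locals_dict['rag_pipeline']()              # same lookup as in executor.py
    print(rag)                                       # -> rag object placeholder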
src/ragbuilder/langchain_module/embedding_model/embedding.py (5 changes: 5 additions, 0 deletions)
@@ -15,6 +15,7 @@ def getEmbedding(**kwargs):
         raise KeyError("The key 'embedding_model' is missing from the arguments.")
 
     embedding_model = kwargs['embedding_model']
+
     model_owner= embedding_model.split(":")[0]
     model= embedding_model.split(":")[1]
     # Validate the embedding model type
@@ -39,6 +40,10 @@ def getEmbedding(**kwargs):
         logger.info(f"Google Embedding Invoked: {embedding_model}")
         code_string= f"""embedding = GoogleGenerativeAIEmbeddings(model='{model}')"""
         import_string = f"""from langchain_google_genai import GoogleGenerativeAIEmbeddings"""
+    elif model_owner == "GoogleVertexAI":
+        logger.info(f"GoogleVertexAI Embedding Invoked: {embedding_model}")
+        code_string= f"""embedding = VertexAIEmbeddings(model_name='{model}')"""
+        import_string = f"""from langchain_google_vertexai import VertexAIEmbeddings"""
     elif model_owner == "Azure":
         logger.info(f"Azure Embedding Invoked: {embedding_model}")
         code_string= f"""embedding = AzureOpenAIEmbeddings(model='{model}')"""
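As a rough illustration of what the new branch emits (the model name below is an assumed example, and this reproduces the branch's string templates rather than calling getEmbedding itself), an embedding spec of the form "GoogleVertexAI:<model>" splits on ":" and yields an import line plus an instantiation line:

    embedding_model = "GoogleVertexAI:text-embedding-004"  # assumed example spec
    model_owner, model = embedding_model.split(":", 1)

    # Strings the GoogleVertexAI branch would generate for this spec
    import_string = "from langchain_google_vertexai import VertexAIEmbeddings"
    code_string = f"embedding = VertexAIEmbeddings(model_name='{model}')"
    print(import_string)
    print(code_string)  # -> embedding = VertexAIEmbeddings(model_name='text-embedding-004')

Executing the generated line still requires the langchain-google-vertexai package added to requirements.txt and Google Cloud credentials, typically application-default credentials plus a project ID.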
src/ragbuilder/langchain_module/llms/llmConfig.py (4 changes: 4 additions, 0 deletions)
@@ -23,6 +23,10 @@ def getLLM(**kwargs):
         logger.info(f"LLM Code Gen Invoked:Google")
         import_string = f"""from langchain_google_genai import ChatGoogleGenerativeAI"""
         code_string = f"""llm = ChatGoogleGenerativeAI(model='{model}')"""
+    elif model_owner == "GoogleVertexAI":
+        logger.info(f"LLM Code Gen Invoked:GoogleVertexAI")
+        import_string = f"""from langchain_google_vertexai import ChatVertexAI"""
+        code_string = f"""llm = ChatVertexAI(model_name='{model}')"""
     elif model_owner == "OpenAI":
         logger.info(f"LLM Code Gen Invoked: {retrieval_model}")
         import_string = f"""from langchain_openai import ChatOpenAI"""
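The generated LLM line is equivalent to instantiating the chat model directly. A hedged usage sketch, assuming application-default credentials and a model name that are not part of this commit:

    from langchain_google_vertexai import ChatVertexAI

    # Assumes `gcloud auth application-default login` (or equivalent) has been run
    # and a default GCP project is configured.
    llm = ChatVertexAI(model_name="gemini-1.5-pro")  # model name is an assumed example
    response = llm.invoke("Summarize what a RAG pipeline does in one sentence.")
    print(response.content)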
